diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bfba4b6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.DS_Store +**/.DS_Store \ No newline at end of file diff --git a/training/CODE_OF_CONDUCT.md b/training/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..3232ed6 --- /dev/null +++ b/training/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/training/CONTRIBUTING.md b/training/CONTRIBUTING.md new file mode 100644 index 0000000..7766384 --- /dev/null +++ b/training/CONTRIBUTING.md @@ -0,0 +1,173 @@ +# Contributing to torchtune +We want to make contributing to this project as easy and transparent as possible. + +  + +## Dev install +You should first [fork](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo) the torchtune repository +and then clone your forked repository. Make sure to keep your fork in sync with the torchtune repository over time. + +```git clone https://github.com//torchtune.git``` + +Then navigate into the newly cloned repo and install dependencies needed for development. + +**Step 1:** [Install PyTorch](https://pytorch.org/get-started/locally/). torchtune is tested with the latest stable PyTorch release as well as the preview nightly version. + + +**Step 2:** Install all the additional dependencies and dev dependencies in the local repo: + +``` +cd torchtune +pip install -e ".[dev]" +``` + +  + +## Contributing workflow +We actively welcome your pull requests. + +1. Create your new branch from `main` in your forked repo, with a name describing the work you're completing e.g. `add-feature-x`. +2. If you've added code that should be tested, add tests. Ensure all tests pass. See the [testing section](#testing) for more information. +3. If you've changed APIs, [update the documentation](#updating-documentation). +4. Make sure your [code lints](#coding-style). +5. If you haven't already, complete the [Contributor License Agreement ("CLA")](#contributor-license-agreement-cla) + +  + +## Testing +torchtune contains three different types of tests: unit tests, recipe tests, and regression tests. These tests are distinguished by their complexity and the resources they require to run. Recipe tests and regression tests are explicitly marked via pytest.mark decorators and both require S3 access to download the requisite assets. + +- **Unit tests** + - These should be minimal tests runnable without remote access. (No large models, no downloading weights). Unit tests should be under [tests/torchtune](https://github.com/pytorch/torchtune/tree/main/tests/torchtune). + - All unit tests can be run via ```pytest tests```. +- **Recipe tests** + - These are relatively small-scale integration tests for running our recipes. These include + both single-device recipes and distributed recipes. In the latter case, tests should be marked with the `@gpu_test` decorator to indicate how many GPUs they need to run. + - Recipe tests require remote access as (small) model weights will be downloaded from S3 to run them. + - Recipe tests are found under [tests/recipes](https://github.com/pytorch/torchtune/tree/main/tests/recipes) and should be marked with the `@pytest.mark.integration_test` decorator. + - To run only recipe tests, you can run `pytest tests -m integration_test`. +- **Regression tests** + - These are the most heavyweight tests in the repo. They involve building a full model (i.e. 7B size or larger), then running some finetune and/or evaluation via a combination of tune CLI commands. Whereas an individual recipe test runtime is generally still O(seconds), integration tests should be O(minutes) or greater. Like recipe tests, regression tests also require S3 access. + - Regression tests are found under [tests/regression_tests](https://github.com/pytorch/torchtune/tree/main/tests/regression_tests) and should be marked with the `@pytest.mark.slow_integration_test` decorator. + - To run only regression tests, you can use the command `pytest tests -m slow_integration_test`. + +Whenever running tests in torchtune, favor using the command line flags as much as possible (e.g. run `pytest tests -m integration_test` over `pytest tests/recipes`). This is because (a) the default behavior is to run unit tests only (so you will miss recipe tests without the flag), and (b) using the flags ensures pytest will automatically download any remote assets needed for your test run. + +Note that the above flags can be combined with other pytest flags, so e.g. `pytest tests -m integration_test -k 'test_loss'` will run only recipe tests matching the substring `test_loss`. + +  + +## Updating documentation +Each API and class should be clearly documented. Well-documented code is easier to review and understand/extend. All documentation is contained in the [docs directory](docs/source): + +* All files following the pattern `api_ref_*` document top-level APIs. +* All files under the [deep dives directory](docs/source/deep_dives) contain "deep-dive" tutorials +* All files under the [tutorials directory](docs/source/tutorials) contain regular tutorials + +Documentation is written in [RST](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) format. + +### Adding a new class/method to the API References +Once you've added an API that is meant to be exposed publically, you should add it to the appropriate rst file. For example, any new API within the [configs/](torchtune/configs) +directory should be added to `api_ref_configs.rst`, [data/](torchtune/data) should be added to `api_ref_data.rst`, [datasets](torchtune/datasets) should be added to +`api_ref_datasets.rst`, and so on. To add, it's as simple as adding the name of the exposed API somewhere in the appropriate RST file. + +All code written within the docstring of the class or method will be correctly rendered there. + +> Note: Our RST theme expects code to be specified using double backticks instead of single. Eg: ``hidden_dim``. Single backticks will be rendered as italics instead of as "code". + +### Building docs + +All documentation is built for each PR and contains a preview on the PR. However, this takes awhile (~8 minutes) and you should first build docs from your local machine. + +From the [docs/](docs) directory: + +1. Install dependencies: + +``` +pip install -r requirements.txt +``` + +2. Run make command: + +``` +make html +# Now open build/html/index.html +``` + +To avoid building the examples (which execute python code and can take time) you +can use `make html-noplot`. To build a subset of specific examples instead of +all of them, you can use a regex like `EXAMPLES_PATTERN="plot_the_best_example*" +make html`. + +If the doc build starts failing for a weird reason, try `make clean`. + +#### Serving docs locally (if building from a GPU env) + +If you're developing locally, you can just open the generated `index.html` file in your browser. + +If instead you're using a remote machine, you can use a combination of a simple python HTTP server and port forwarding to serve the docs locally. This allows you to iterate on the documentation much more quickly than relying on PR previews. + +To do so, after following the above doc build steps, run the following from the `docs/build/html` folder: + +``` +python -m http.server 8000 # or any free port +``` + +This will open up a simple HTTP server serving the files in the build directory. If this is done on a remote machine, you can set up port forwarding from your local machine to access the server, for example: + +``` +ssh -L 9000:localhost:8000 $REMOTE_DEV_HOST +``` + +Now, you can navigate to `localhost:9000` on your local machine to view the rendered documentation. + +  + +## Coding Style +`torchtune` uses pre-commit hooks to ensure style consistency and prevent common mistakes. Enable it by: + +``` +pre-commit install +``` + +After this pre-commit hooks will be run before every commit. + +You can also run this manually on every file using: + +``` +pre-commit run --all-files +``` + +  + +## Best Practices + +This section captures some best practices for contributing code to torchtune. Following these will make PR reviews easier. + +- **Modular Blocks instead of Monolithic Classes**. Stuffing all of the logic into a single class limits readability and makes it hard to reuse logic. Think about breaking the implementation into self-contained blocks which can be used independently from a given model. For example, attention mechanisms, embedding classes, transformer layers etc. +- **Say no to Implementation Inheritance**. You really don’t need it AND it makes the code much harder to understand or refactor since the logic is spread across many files/classes. Where needed, consider using Protocols. +- **Clean Interfaces**. There’s nothing more challenging than reading through functions/constructors with ~100 parameters. Think carefully about what needs to be exposed to the user and don’t hesitate to hard-code parameters until there is a need to make them configurable. +- **Intrusive Configs**. Config objects should not intrude into the class implementation. Configs should interact with these classes through cleanly defined builder functions which convert the config into flat parameters needed to instantiate an object. +- **Limit Generalization**. Attempting to generalize code before this is needed unnecessarily complicates implementations - you are anticipating use cases you don’t know a lot about. When you actually need to generalize a component, think about whether it’s worth it to complicate a given interface to stuff in more functionality. Don’t be afraid of code duplication if it makes things easier to read. +- **Value Checks and Asserts**. Don’t check values in higher level modules - defer the checks to the modules where the values are actually used. This helps reduce the number of raise statements in code which generally hurts readability, but are critical for correctness. + +  + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. + +  + +## License +By contributing to torchtune, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. + +  + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +  diff --git a/training/LICENSE b/training/LICENSE new file mode 100644 index 0000000..bc559a9 --- /dev/null +++ b/training/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright 2024 Meta + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice,this list +of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may +be used to endorse or promote products derived from this software without specific +prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT +SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. diff --git a/training/MANIFEST.in b/training/MANIFEST.in new file mode 100644 index 0000000..71d1a8c --- /dev/null +++ b/training/MANIFEST.in @@ -0,0 +1 @@ +prune tests # Remove all testing files from final dist/ diff --git a/training/README.md b/training/README.md new file mode 100644 index 0000000..c51c1c1 --- /dev/null +++ b/training/README.md @@ -0,0 +1,337 @@ +# torchtune (Modified for OpenScholar) + +[![Unit Test](https://github.com/pytorch/torchtune/actions/workflows/unit_test.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtune/actions/workflows/unit_test.yaml) +![Recipe Integration Test](https://github.com/pytorch/torchtune/actions/workflows/recipe_test.yaml/badge.svg) +[![](https://dcbadge.vercel.app/api/server/4Xsdn8Rr9Q?style=flat)](https://discord.gg/4Xsdn8Rr9Q) + +[**Introduction**](#introduction) | [**Installation**](#installation) | [**Get Started**](#get-started) | [**Documentation**](https://pytorch.org/torchtune/main/index.html) | [**Design Principles**](#design-principles) | [**Community Contributions**](#community-contributions) | [**License**](#license) + +  + +> [!NOTE] +> **July 2024**: torchtune has updated model weights for Llama3.1 in source and nightly builds! Check out our configs for both the [8B and 70B versions](recipes/configs/llama3_1/) of the model. LoRA, QLoRA, and full finetune methods are supported. Support for QLoRA 405B will be added soon. + +## Introduction + +torchtune is a PyTorch-native library for easily authoring, fine-tuning and experimenting with LLMs. We're excited to announce our alpha release! + +torchtune provides: + +- Native-PyTorch implementations of popular LLMs using composable and modular building blocks +- Easy-to-use and hackable training recipes for popular fine-tuning techniques (LoRA, QLoRA) - no trainers, no frameworks, just PyTorch! +- YAML configs for easily configuring training, evaluation, quantization or inference recipes +- Built-in support for many popular dataset formats and prompt templates to help you quickly get started with training + +torchtune focuses on integrating with popular tools and libraries from the ecosystem. These are just a few examples, with more under development: + +- [Hugging Face Hub](https://huggingface.co/docs/hub/en/index) for [accessing model weights](torchtune/_cli/download.py) +- [EleutherAI's LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for [evaluating](recipes/eleuther_eval.py) trained models +- [Hugging Face Datasets](https://huggingface.co/docs/datasets/en/index) for [access](torchtune/datasets/_instruct.py) to training and evaluation datasets +- [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html) for distributed training +- [torchao](https://github.com/pytorch-labs/ao) for lower precision dtypes and [post-training quantization](recipes/quantize.py) techniques +- [Weights & Biases](https://wandb.ai/site) for [logging](https://pytorch.org/torchtune/main/deep_dives/wandb_logging.html) metrics and checkpoints, and tracking training progress +- [ExecuTorch](https://pytorch.org/executorch-overview) for [on-device inference](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#optional-finetuning) using fine-tuned models +- [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) for low memory optimizers for our [single-device recipes](recipes/configs/llama2/7B_full_low_memory.yaml) + +  + +### Models + +torchtune currently supports the following models. + +| Model | Sizes | +|-----------------------------------------------|-----------| +| [Llama3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) | 8B, 70B [[models](torchtune/models/llama3_1/_model_builders.py), [configs](recipes/configs/llama3_1/)] | +| [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] | +| [Llama2](https://llama.meta.com/llama2/) | 7B, 13B, 70B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] | +| [Code-Llama2](https://ai.meta.com/blog/code-llama-large-language-model-coding/) | 7B, 13B, 70B [[model](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | +| [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] | +| [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B, 7B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] | +| [Microsoft Phi3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) | Mini [[model](torchtune/models/phi3/), [configs](recipes/configs/phi3/)] +| [Qwen2](https://qwenlm.github.io/blog/qwen2/) | 0.5B, 1.5B, 7B [[model](torchtune/models/qwen2/), [configs](recipes/configs/qwen2/)] + +We're always adding new models, but feel free to [file an Issue](https://github.com/pytorch/torchtune/issues/new) if there's a new one you would love to see in torchtune! + +  + +### Fine-tuning recipes + +torchtune provides the following fine-tuning recipes. + +| Training | Fine-tuning Method | +|------------------------------------|------------------------------------| +| Distributed Training [1 to 8 GPUs] | Full [[code](recipes/full_finetune_distributed.py), [example](recipes/configs/llama3/8B_full.yaml)], LoRA [[code](recipes/lora_finetune_distributed.py), [example](recipes/configs/llama3/8B_lora.yaml)] | +| Single Device / Low Memory [1 GPU] | Full [[code](recipes/full_finetune_single_device.py), [example](recipes/configs/llama3/8B_full_single_device.yaml)], LoRA + QLoRA [[code](recipes/lora_finetune_single_device.py), [example](recipes/configs/llama3/8B_lora_single_device.yaml)] | +| Single Device [1 GPU] | DPO [[code](recipes/lora_dpo_single_device.py), [example](recipes/configs/llama2/7B_lora_dpo_single_device.yaml)], RLHF with PPO [[code](recipes/ppo_full_finetune_single_device.py), [example](recipes/configs/mistral/7B_full_ppo_low_memory.yaml)] + +  + + +Memory efficiency is important to us. All of our recipes are tested on a variety of setups including commodity GPUs with 24GB of VRAM as well as beefier options found in data centers. + +Single-GPU recipes expose a number of memory optimizations that aren't available in the distributed versions. These include support for low-precision optimizers from [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and fusing optimizer step with backward to reduce memory footprint from the gradients (see example [config](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_low_memory.yaml)). For memory-constrained setups, we recommend using the single-device configs as a starting point. + +This table captures the peak memory usage and training speed for recipes in torchtune. + +| Example HW Resources | Finetuning Method | Model | Setting | Peak Memory per GPU (GB) | Training Speed (tokens/sec) | +|:-:|:-:|:-:|:-:|:-:|:-:| +| 1 x RTX 4090 | QLoRA ** | Llama2-7B | Batch Size = 4, Seq Length = 2048 | 12.3 GB | 3155 | +| 1 x RTX 4090 | LoRA | Llama2-7B | Batch Size = 4, Seq Length = 2048 | 21.3 GB | 2582 | +| 2 x RTX 4090 | LoRA | Llama2-7B | Batch Size = 4, Seq Length = 2048 | 16.2 GB | 2768 | +| 1 x RTX 4090 | Full finetune * | Llama2-7B | Batch Size = 4, Seq Length = 2048 | 24.1 GB | 702 | +| 4 x RTX 4090 | Full finetune | Llama2-7B | Batch Size = 4, Seq Length = 2048 | 24.1 GB | 1388 | +| 8 x A100 | LoRA | Llama2-70B | Batch Size = 4, Seq Length = 4096 | 26.4 GB | 3384 | +| 8 x A100 | Full Finetune * | Llama2-70B | Batch Size = 4, Seq Length = 4096 | 70.4 GB | 2032 | + +*= Uses PagedAdamW from bitsandbytes + +**= Uses torch compile + + +  + +## Llama3 and Llama3.1 + +torchtune supports fine-tuning for the Llama3 8B and 70B size models. We currently support LoRA, QLoRA and full fine-tune on a single GPU as well as LoRA and full fine-tune on multiple devices for the 8B model, and LoRA on multiple devices for the 70B model. For all the details, take a look at our [tutorial](https://pytorch.org/torchtune/main/tutorials/llama3.html). + +> [!NOTE] +> Our Llama3 and Llama3.1 LoRA and QLoRA configs default to the instruct fine-tuned models. This is because not all special token embeddings are initialized in the base 8B and 70B models. + +In our initial experiments for Llama3-8B, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training. + +### Single GPU + +LoRA 8B + +```bash +tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device +``` + +QLoRA 8B + +```bash +tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device +``` + +Full 8B + +```bash +tune run full_finetune_single_device --config llama3_1/8B_full_single_device +``` + +### Multi GPU + +Full 8B + +```bash +tune run --nproc_per_node 4 full_finetune_distributed --config llama3_1/8B_full +``` + +LoRA 8B + +```bash +tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora +``` + +LoRA 70B + +Note that the download command for the Meta-Llama3 70B model slightly differs from download commands for the 8B models. This is because we use the HuggingFace [safetensor](https://huggingface.co/docs/safetensors/en/index) model format to load the model. To download the 70B model, run +```bash +tune download meta-llama/Meta-Llama-3.1-70b --hf-token <> --output-dir /tmp/Meta-Llama-3.1-70b --ignore-patterns "original/consolidated*" +``` + +Then, a finetune can be kicked off: + +```bash +tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_1/70B_lora.yaml +``` + +You can find a full list of all our Llama3 configs [here](recipes/configs/llama3) and Llama3.1 configs [here.](recipes/configs/llama3_1) + + +  + +--- + +## Installation + +**Step 1:** [Install PyTorch](https://pytorch.org/get-started/locally/). torchtune is tested with the latest stable PyTorch release as well as the preview nightly version. For fine-tuning the multimodal LLMs available in the repo, you'll need to install torchvision as well. + +``` +# Install stable version of PyTorch using pip +pip install torch torchvision + +# Nightly install for latest features +pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 +``` + +**Step 2:** The latest stable version of torchtune is hosted on PyPI and can be downloaded with the following command: + +```bash +pip install torchtune +``` + +To confirm that the package is installed correctly, you can run the following command: + +```bash +tune --help +``` + +And should see the following output: + +```bash +usage: tune [-h] {ls,cp,download,run,validate} ... + +Welcome to the torchtune CLI! + +options: + -h, --help show this help message and exit + +... +``` + +You can also install the latest and greatest torchtune has to offer by [installing a nightly build](https://pytorch.org/torchtune/main/install.html). + +  + +--- + +## Get Started + +To get started with fine-tuning your first LLM with torchtune, see our tutorial on [fine-tuning Llama2 7B](https://pytorch.org/torchtune/main/tutorials/first_finetune_tutorial.html). Our [end-to-end workflow](https://pytorch.org/torchtune/main/tutorials/e2e_flow.html) tutorial will show you how to evaluate, quantize and run inference with this model. The rest of this section will provide a quick overview of these steps with Llama2. + +  + +### Downloading a model + +Follow the instructions on the official [`meta-llama`](https://huggingface.co/meta-llama) repository to ensure you have access to the official Llama model weights. Once you have confirmed access, you can run the following command to download the weights to your local machine. This will also download the tokenizer model and a responsible use guide. + + +### Llama3 download +```bash +tune download meta-llama/Meta-Llama-3-8B \ +--output-dir /tmp/Meta-Llama-3-8B \ +--hf-token \ +``` + +> [!Tip] +> Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command in order to validate your access. You can find your token at https://huggingface.co/settings/tokens + +  + +### Running fine-tuning recipes + +Llama3 8B + LoRA on single GPU: + +```bash +tune run lora_finetune_single_device --config llama2/7B_lora_single_device +``` + +For distributed training, tune CLI integrates with [torchrun](https://pytorch.org/docs/stable/elastic/run.html). +Llama3 8B + LoRA on two GPUs: + +```bash +tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full +``` + +> [!Tip] +> Make sure to place any torchrun commands **before** the recipe specification. Any CLI args after this will override the config and not impact distributed training. + +  + +### Modify Configs + +There are two ways in which you can modify configs: + +**Config Overrides** + +You can easily overwrite config properties from the command-line: + +```bash +tune run lora_finetune_single_device \ +--config llama2/7B_lora_single_device \ +batch_size=8 \ +enable_activation_checkpointing=True \ +max_steps_per_epoch=128 +``` + +**Update a Local Copy** + +You can also copy the config to your local directory and modify the contents directly: + +```bash +tune cp llama2/7B_full ./my_custom_config.yaml +Copied to ./7B_full.yaml +``` + +Then, you can run your custom recipe by directing the `tune run` command to your local files: + +```bash +tune run full_finetune_distributed --config ./my_custom_config.yaml +``` + +  + +Check out `tune --help` for all possible CLI commands and options. For more information on using and updating configs, take a look at our [config deep-dive](https://pytorch.org/torchtune/main/deep_dives/configs.html). + +  + +## Design Principles + +torchtune embodies PyTorch’s design philosophy [[details](https://pytorch.org/docs/stable/community/design.html)], especially "usability over everything else". + +### Native PyTorch + +torchtune is a native-PyTorch library. While we provide integrations with the surrounding ecosystem (e.g. Hugging Face Datasets, EleutherAI Eval Harness), all of the core functionality is written in PyTorch. + +### Simplicity and Extensibility + +torchtune is designed to be easy to understand, use and extend. + +- Composition over implementation inheritance - layers of inheritance for code re-use makes the code hard to read and extend +- No training frameworks - explicitly outlining the training logic makes it easy to extend for custom use cases +- Code duplication is preferred over unnecessary abstractions +- Modular building blocks over monolithic components + +### Correctness + +torchtune provides well-tested components with a high-bar on correctness. The library will never be the first to provide a feature, but available features will be thoroughly tested. We provide + +- Extensive unit-tests to ensure component-level numerical parity with reference implementations +- Checkpoint-tests to ensure model-level numerical parity with reference implementations +- Integration tests to ensure recipe-level performance parity with reference implementations on standard benchmarks + +  + +## Community Contributions + +We really value our community and the contributions made by our wonderful users. We'll use this section to call out some of these contributions! If you'd like to help out as well, please see the [CONTRIBUTING](CONTRIBUTING.md) guide. + +- [@SalmanMohammadi](https://github.com/salmanmohammadi) for adding a comprehensive end-to-end recipe for [Reinforcement Learning from Human Feedback (RLHF)](recipes/ppo_full_finetune_single_device.py) finetuning with PPO to torchtune +- [@fyabc](https://github.com/fyabc) for adding Qwen2 models, tokenizer, and recipe integration to torchtune +- [@solitude-alive](https://github.com/solitude-alive) for adding the [Gemma 2B model](torchtune/models/gemma/) to torchtune, including recipe changes, numeric validations of the models and recipe correctness +- [@yechenzhi](https://github.com/yechenzhi) for adding [Direct Preference Optimization (DPO)](recipes/lora_dpo_single_device.py) to torchtune, including the recipe and config along with correctness checks + + +  + +## Acknowledgements + +The Llama2 code in this repository is inspired by the original [Llama2 code](https://github.com/meta-llama/llama/blob/main/llama/model.py). + +We want to give a huge shout-out to EleutherAI, Hugging Face and Weights & Biases for being wonderful collaborators and for working with us on some of these integrations within torchtune. + +We also want to acknowledge some awesome libraries and tools from the ecosystem: +- [gpt-fast](https://github.com/pytorch-labs/gpt-fast) for performant LLM inference techniques which we've adopted OOTB +- [llama recipes](https://github.com/meta-llama/llama-recipes) for spring-boarding the llama2 community +- [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) for bringing several memory and performance based techniques to the PyTorch ecosystem +- [@winglian](https://github.com/winglian/) and [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) for early feedback and brainstorming on torchtune's design and feature set. +- [lit-gpt](https://github.com/Lightning-AI/litgpt) for pushing the LLM fine-tuning community forward. +- [HF TRL](https://github.com/huggingface/trl) for making reward modeling more accessible to the PyTorch community. + +  + + +## License + +torchtune is released under the [BSD 3 license](./LICENSE). However you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models. diff --git a/training/docs/Makefile b/training/docs/Makefile new file mode 100644 index 0000000..7ba69c5 --- /dev/null +++ b/training/docs/Makefile @@ -0,0 +1,43 @@ +# Minimal makefile for Sphinx documentation +# + +ifneq ($(EXAMPLES_PATTERN),) + EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)" +endif + +# You can set these variables from the command line. +SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) +SPHINXBUILD = sphinx-build +SPHINXPROJ = torchtune +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +docset: html + doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/vision/ --force $(BUILDDIR)/html/ + + # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. + cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png + convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png + +html-noplot: # Avoids running the gallery examples, which may take time + $(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +clean: + rm -rf $(BUILDDIR)/* + rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery + rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery + rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery + rm -rf $(SOURCEDIR)/generated/ # autosummary + +.PHONY: help Makefile docset + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/training/docs/license_header.txt b/training/docs/license_header.txt new file mode 100644 index 0000000..dada37c --- /dev/null +++ b/training/docs/license_header.txt @@ -0,0 +1,5 @@ +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. diff --git a/training/docs/requirements.txt b/training/docs/requirements.txt new file mode 100644 index 0000000..99c4179 --- /dev/null +++ b/training/docs/requirements.txt @@ -0,0 +1,7 @@ +sphinx-gallery>0.11 +sphinx==5.0.0 +sphinx_design +sphinx_copybutton +sphinx-tabs +matplotlib +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/training/docs/source/_static/css/custom_torchtune.css b/training/docs/source/_static/css/custom_torchtune.css new file mode 100644 index 0000000..dc81626 --- /dev/null +++ b/training/docs/source/_static/css/custom_torchtune.css @@ -0,0 +1,187 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* sphinx-design styles for cards/tabs */ + + +:root { + --sd-color-info: #ee4c2c; + --sd-color-primary: #6c6c6d; + --sd-color-primary-highlight: #f3f4f7; + --sd-color-card-border-hover: #ee4c2c; + --sd-color-card-border: #f3f4f7; + --sd-color-card-background: #fff; + --sd-color-card-text: inherit; + --sd-color-card-header: transparent; + --sd-color-card-footer: transparent; + --sd-color-tabs-label-active: #ee4c2c; + --sd-color-tabs-label-hover: #ee4c2c; + --sd-color-tabs-label-inactive: #6c6c6d; + --sd-color-tabs-underline-active: #ee4c2c; + --sd-color-tabs-underline-hover: #fabdbd; + --sd-color-tabs-underline-inactive: transparent; + --sd-color-tabs-overline: rgb(222, 222, 222); + --sd-color-tabs-underline: rgb(222, 222, 222); +} + +.sd-text-info { + color: #ee4c2c; +} + +.sd-card-img-top { + background: #ee4c2c; + height: 5px !important; +} + +.sd-card { + position: relative; + background-color: #fff; + opacity: 1.0; + border-radius: 0px; + width: 30%; + border: none; + padding-bottom: 0px; +} + + +.sd-card-img:hover { + opacity: 1.0; + background-color: #f3f4f7; +} + + +.sd-card:after { + display: block; + opacity: 1; + content: ''; + border-bottom: solid 1px #ee4c2c; + background-color: #fff; + transform: scaleX(0); + transition: transform .250s ease-in-out; + transform-origin: 0% 50%; +} + +.sd-card:hover { + background-color: #fff; + opacity: 1; + border-top: 1px solid #f3f4f7; + border-left: 1px solid #f3f4f7; + border-right: 1px solid #f3f4f7; +} + +.sd-card:hover:after { + transform: scaleX(1); +} + +.card-prerequisites:hover { + transition: none; + border: none; +} + +.card-prerequisites:hover:after { + transition: none; + transform: none; +} + +.card-prerequisites:after { + display: block; + content: ''; + border-bottom: none; + background-color: #fff; + transform: none; + transition: none; + transform-origin: none; +} + + +details.sd-dropdown { + font-weight: 300; + width: auto; +} + +details.sd-dropdown:after { + border: none; + transition: none; +} + +details.sd-dropdown:hover { + border: none; + transition: none; +} + +details.sd-dropdown .sd-summary-content { + font-weight: 300; +} + +details.sd-dropdown .highlight .n { + font-weight: normal; +} + +.et-page-column1 { + float: left; + width: 70%; + font-size: 1rem; +} + +.et-page-column2 { + float: right; + padding-top: 40px; + padding-left: 60px; + padding-right: 60px; + padding-bottom: 60px; + width: 30%; +} + +.et-page-column-row:after { + content: ""; + display: table; + clear: both; +} + +/* For screens smaller than 768px (typical mobile devices) */ +@media screen and (max-width: 768px) { + .et-page-column1, .et-page-column2 { + float: none; /* Remove floats */ + width: 100%; /* Full width for both columns */ + padding: 0; + font-size: 1rem; + } + + .et-page-column2 img { + display: none; + } + .et-page-column-row:after { + content: ""; + display: table; + clear: both; + } +} + +article.pytorch-article .class .method dt { + border-top: none; +} + +article.pytorch-article .class .simple dt { + border-top: none; +} + +article.pytorch-article .function dt.sig { + border-top: none; +} + +/* Fix for Sphinx gallery thumbnails. +See https://github.com/sphinx-gallery/sphinx-gallery/issues/990 +*/ +article.pytorch-article .sphx-glr-thumbnails .sphx-glr-thumbcontainer { + width: unset; + margin-right: 0; + margin-left: 0; +} +article.pytorch-article div.section div.wy-table-responsive tbody td { + width: 50%; +} diff --git a/training/docs/source/_static/img/card-background.svg b/training/docs/source/_static/img/card-background.svg new file mode 100644 index 0000000..773515d --- /dev/null +++ b/training/docs/source/_static/img/card-background.svg @@ -0,0 +1,13 @@ + + + + + + + \ No newline at end of file diff --git a/training/docs/source/_static/img/generic-pytorch-logo.png b/training/docs/source/_static/img/generic-pytorch-logo.png new file mode 100644 index 0000000..426a14d Binary files /dev/null and b/training/docs/source/_static/img/generic-pytorch-logo.png differ diff --git a/training/docs/source/_static/img/lora_diagram.png b/training/docs/source/_static/img/lora_diagram.png new file mode 100644 index 0000000..605c33c Binary files /dev/null and b/training/docs/source/_static/img/lora_diagram.png differ diff --git a/training/docs/source/_static/img/lora_experiment_loss_curves.png b/training/docs/source/_static/img/lora_experiment_loss_curves.png new file mode 100644 index 0000000..82857ae Binary files /dev/null and b/training/docs/source/_static/img/lora_experiment_loss_curves.png differ diff --git a/training/docs/source/_static/img/pytorch-logo-dark.png b/training/docs/source/_static/img/pytorch-logo-dark.png new file mode 100644 index 0000000..0288a56 Binary files /dev/null and b/training/docs/source/_static/img/pytorch-logo-dark.png differ diff --git a/training/docs/source/_static/img/pytorch-logo-dark.svg b/training/docs/source/_static/img/pytorch-logo-dark.svg new file mode 100644 index 0000000..717a3ce --- /dev/null +++ b/training/docs/source/_static/img/pytorch-logo-dark.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + diff --git a/training/docs/source/_static/img/pytorch-logo-flame.png b/training/docs/source/_static/img/pytorch-logo-flame.png new file mode 100644 index 0000000..370633f Binary files /dev/null and b/training/docs/source/_static/img/pytorch-logo-flame.png differ diff --git a/training/docs/source/_static/img/pytorch-logo-flame.svg b/training/docs/source/_static/img/pytorch-logo-flame.svg new file mode 100644 index 0000000..5f2fb76 --- /dev/null +++ b/training/docs/source/_static/img/pytorch-logo-flame.svg @@ -0,0 +1,33 @@ + +image/svg+xml diff --git a/training/docs/source/_static/img/qat_diagram.png b/training/docs/source/_static/img/qat_diagram.png new file mode 100644 index 0000000..3d990e2 Binary files /dev/null and b/training/docs/source/_static/img/qat_diagram.png differ diff --git a/training/docs/source/_static/img/qlora_exp.png b/training/docs/source/_static/img/qlora_exp.png new file mode 100644 index 0000000..f75828f Binary files /dev/null and b/training/docs/source/_static/img/qlora_exp.png differ diff --git a/training/docs/source/_static/img/torchtune_workspace.png b/training/docs/source/_static/img/torchtune_workspace.png new file mode 100644 index 0000000..4a94a53 Binary files /dev/null and b/training/docs/source/_static/img/torchtune_workspace.png differ diff --git a/training/docs/source/_templates/autosummary/class.rst b/training/docs/source/_templates/autosummary/class.rst new file mode 100644 index 0000000..eeb823a --- /dev/null +++ b/training/docs/source/_templates/autosummary/class.rst @@ -0,0 +1,9 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: diff --git a/training/docs/source/_templates/autosummary/function.rst b/training/docs/source/_templates/autosummary/function.rst new file mode 100644 index 0000000..72abc4f --- /dev/null +++ b/training/docs/source/_templates/autosummary/function.rst @@ -0,0 +1,8 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autofunction:: {{ name }} diff --git a/training/docs/source/_templates/layout.html b/training/docs/source/_templates/layout.html new file mode 100644 index 0000000..6dbc189 --- /dev/null +++ b/training/docs/source/_templates/layout.html @@ -0,0 +1,21 @@ +{% extends "!layout.html" %} + +{% block sidebartitle %} + + {% include "searchbox.html" %} +{% endblock %} + + +{% block footer %} + + + + +{% endblock %} diff --git a/training/docs/source/api_ref_config.rst b/training/docs/source/api_ref_config.rst new file mode 100644 index 0000000..4a19461 --- /dev/null +++ b/training/docs/source/api_ref_config.rst @@ -0,0 +1,16 @@ +.. _config: + +================ +torchtune.config +================ + +.. currentmodule:: torchtune.config + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + instantiate + parse + validate + log_config diff --git a/training/docs/source/api_ref_data.rst b/training/docs/source/api_ref_data.rst new file mode 100644 index 0000000..31b56ff --- /dev/null +++ b/training/docs/source/api_ref_data.rst @@ -0,0 +1,80 @@ +.. _data: + +============== +torchtune.data +============== + +.. currentmodule:: torchtune.data + +.. _chat_formats: + +Text templates +-------------- + +Templates for instruct prompts and chat prompts. Includes some specific formatting for difference datasets +and models. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + InstructTemplate + AlpacaInstructTemplate + GrammarErrorCorrectionTemplate + SummarizeTemplate + StackExchangedPairedTemplate + PromptTemplate + PromptTemplateInterface + ChatMLTemplate + + ChatFormat + ChatMLFormat + Llama2ChatFormat + MistralChatFormat + +Types +----- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + Message + Role + +Converters +---------- + +Converts data from common JSON formats into a torchtune :class:`Message`. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + get_sharegpt_messages + get_openai_messages + +Message transforms +------------------ + +Converts data from common schema and conversation JSON formats into a list of torchtune :class:`Message`. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + InputOutputToMessages + ShareGPTToMessages + JSONToMessages + +Helper functions +---------------- + +Miscellaneous helper functions used in modifying data. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + validate_messages + truncate diff --git a/training/docs/source/api_ref_datasets.rst b/training/docs/source/api_ref_datasets.rst new file mode 100644 index 0000000..e682aa6 --- /dev/null +++ b/training/docs/source/api_ref_datasets.rst @@ -0,0 +1,59 @@ +.. _datasets: + +================== +torchtune.datasets +================== + +.. currentmodule:: torchtune.datasets + +For a detailed general usage guide, please see our :ref:`datasets tutorial `. + + +Example datasets +---------------- + +torchtune supports several widely used datasets to help quickly bootstrap your fine-tuning. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + alpaca_dataset + alpaca_cleaned_dataset + grammar_dataset + samsum_dataset + slimorca_dataset + stack_exchanged_paired_dataset + cnn_dailymail_articles_dataset + wikitext_dataset + +Generic dataset builders +------------------------ + +torchtune also supports generic dataset builders for common formats like chat models and instruct models. +These are especially useful for specifying from a YAML config. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + instruct_dataset + chat_dataset + text_completion_dataset + +Generic dataset classes +----------------------- + +Class representations for the above dataset builders. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + InstructDataset + ChatDataset + TextCompletionDataset + ConcatDataset + PackedDataset + PreferenceDataset + SFTDataset diff --git a/training/docs/source/api_ref_models.rst b/training/docs/source/api_ref_models.rst new file mode 100644 index 0000000..4d487ba --- /dev/null +++ b/training/docs/source/api_ref_models.rst @@ -0,0 +1,266 @@ +.. _models: + +================ +torchtune.models +================ + +.. currentmodule:: torchtune.models + +llama3 & llama3.1 +----------------- + +All models from the `Llama3 family `_. + +Request Access on `Hugging Face `__. + +To download the Llama3-8B-Instruct model: + +.. code-block:: bash + + tune download meta-llama/Meta-Llama-3-8B-Instruct --hf-token + +To download the Llama3-70B-Instruct model: + +.. code-block:: bash + + tune download meta-llama/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token + +To download the Llama3.1 weights of the above models, you can instead download from `Meta-Llama-3.1-8B-Instruct` +or `Meta-Llama-3.1-70B-Instruct`. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + llama3.llama3 + llama3.lora_llama3 + llama3.llama3_8b + llama3.lora_llama3_8b + llama3.qlora_llama3_8b + llama3.llama3_70b + llama3.lora_llama3_70b + llama3.qlora_llama3_70b + llama3.llama3_tokenizer + llama3.Llama3Tokenizer + + | + + llama3_1.llama3_1 + llama3_1.lora_llama3_1 + llama3_1.llama3_1_8b + llama3_1.lora_llama3_1_8b + llama3_1.qlora_llama3_1_8b + llama3_1.llama3_1_70b + llama3_1.lora_llama3_1_70b + llama3_1.qlora_llama3_1_70b + + +.. note:: + + The Llama3.1 tokenizer reuses the `llama3.llama3_tokenizer` builder class. + +llama2 +------ + +All models from the `Llama2 family `_. + +Request Access on `Hugging Face `__. + +To download the Llama2-7B model: + +.. code-block:: bash + + tune download meta-llama/Llama-2-7b-hf --hf-token + +To download the Llama2-13B model: + +.. code-block:: bash + + tune download meta-llama/Llama-2-13b-hf --hf-token + +To download the Llama2-70B model: + +.. code-block:: bash + + tune download meta-llama/Llama-2-70b-hf --hf-token + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + llama2.llama2 + llama2.lora_llama2 + llama2.llama2_7b + llama2.lora_llama2_7b + llama2.qlora_llama2_7b + llama2.llama2_13b + llama2.lora_llama2_13b + llama2.qlora_llama2_13b + llama2.llama2_70b + llama2.lora_llama2_70b + llama2.qlora_llama2_70b + llama2.llama2_tokenizer + llama2.Llama2Tokenizer + llama2.llama2_reward_7b + llama2.lora_llama2_reward_7b + llama2.qlora_llama2_reward_7b + llama2.Llama2ChatTemplate + + +code llama +---------- + +Models from the `Code Llama family `_. + +Request Access on `Hugging Face `__. + +To download the CodeLlama-7B model: + +.. code-block:: bash + + tune download codellama/CodeLlama-7b-hf --hf-token + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + code_llama2.code_llama2_7b + code_llama2.lora_code_llama2_7b + code_llama2.qlora_code_llama2_7b + code_llama2.code_llama2_13b + code_llama2.lora_code_llama2_13b + code_llama2.qlora_code_llama2_13b + code_llama2.code_llama2_70b + code_llama2.lora_code_llama2_70b + code_llama2.qlora_code_llama2_70b + +qwen-2 +------ + +Models of size 0.5B, 1.5B, and 7B from the `Qwen2 family `_. + +To download the Qwen2 1.5B model, for example: + +.. code-block:: bash + + tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + qwen2.qwen2 + qwen2.lora_qwen2 + qwen2.qwen2_7b + qwen2.qwen2_0_5b + qwen2.qwen2_1_5b + qwen2.lora_qwen2_7b + qwen2.lora_qwen2_0_5b + qwen2.lora_qwen2_1_5b + qwen2.qwen2_tokenizer + qwen2.Qwen2Tokenizer + +phi-3 +----- + +Models from the `Phi-3 mini family `_. + +To download the Phi-3 Mini 4k instruct model: + +.. code-block:: bash + + tune download microsoft/Phi-3-mini-4k-instruct --ignore-patterns None --hf-token + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + phi3.phi3 + phi3.lora_phi3 + phi3.phi3_mini + phi3.lora_phi3_mini + phi3.qlora_phi3_mini + phi3.phi3_mini_tokenizer + phi3.Phi3MiniTokenizer + + +mistral +------- + +All models from `Mistral AI family `_. + +Request Access on `Hugging Face `__. + +To download the Mistral 7B v0.1 model: + +.. code-block:: bash + + tune download mistralai/Mistral-7B-v0.1 --hf-token + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + mistral.mistral + mistral.lora_mistral + mistral.mistral_classifier + mistral.lora_mistral_classifier + mistral.mistral_7b + mistral.lora_mistral_7b + mistral.qlora_mistral_7b + mistral.mistral_reward_7b + mistral.lora_mistral_reward_7b + mistral.qlora_mistral_reward_7b + mistral.mistral_tokenizer + mistral.MistralTokenizer + mistral.MistralChatTemplate + + +gemma +----- + +Models of size 2B and 7B from the `Gemma family `_. + +Request Access on `Hugging Face `__. + +To download the Gemma 2B model: + +.. code-block:: bash + + tune download google/gemma-2b --ignore-patterns None --hf-token + +To download the Gemma 7B model: + +.. code-block:: bash + + tune download google/gemma-7b --ignore-patterns "gemma-7b.gguf" --hf-token + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + gemma.gemma + gemma.lora_gemma + gemma.gemma_2b + gemma.lora_gemma_2b + gemma.qlora_gemma_2b + gemma.gemma_7b + gemma.lora_gemma_7b + gemma.qlora_gemma_7b + gemma.gemma_tokenizer + gemma.GemmaTokenizer + + +clip +----- + +Vision components to support multimodality using `CLIP encoder `_. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + clip.clip_vision_encoder + clip.TokenPositionalEmbedding + clip.TiledTokenPositionalEmbedding + clip.TilePositionalEmbedding diff --git a/training/docs/source/api_ref_modules.rst b/training/docs/source/api_ref_modules.rst new file mode 100644 index 0000000..df7a531 --- /dev/null +++ b/training/docs/source/api_ref_modules.rst @@ -0,0 +1,119 @@ +================= +torchtune.modules +================= + +.. currentmodule:: torchtune.modules + +Modeling Components and Building Blocks +--------------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + CausalSelfAttention + FeedForward + KVCache + get_cosine_schedule_with_warmup + RotaryPositionalEmbeddings + RMSNorm + Fp32LayerNorm + TransformerDecoderLayer + TransformerDecoder + VisionTransformer + +Base Tokenizers +--------------- +Base tokenizers are tokenizer models that perform the direct encoding of text +into token IDs and decoding of token IDs into text. These are typically `byte pair +encodings `_ that underlie the +model specific tokenizers. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + tokenizers.SentencePieceBaseTokenizer + tokenizers.TikTokenBaseTokenizer + tokenizers.ModelTokenizer + tokenizers.BaseTokenizer + +Tokenizer Utilities +------------------- +These are helper methods that can be used by any tokenizer. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + tokenizers.tokenize_messages_no_special_tokens + tokenizers.parse_hf_tokenizer_json + + +PEFT Components +--------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + peft.LoRALinear + peft.AdapterModule + peft.get_adapter_params + peft.set_trainable_params + peft.validate_missing_and_unexpected_for_lora + peft.validate_state_dict_for_lora + peft.disable_adapter + +Module Utilities +------------------ +These are utilities that are common to and can be used by all modules. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + common_utils.reparametrize_as_dtype_state_dict_post_hook + +Loss +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + loss.PPOLoss + loss.DPOLoss + loss.RSOLoss + loss.IPOLoss + loss.SimPOLoss + + +Vision Transforms +------------------ +Functions used for preprocessing images. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + transforms.Transform + transforms.get_canvas_best_fit + transforms.resize_with_pad + transforms.tile_crop + transforms.find_supported_resolutions + transforms.VisionCrossAttentionMask + +Reinforcement Learning From Human Feedback (RLHF) +-------------------------------------------------- +Components for RLHF algorithms like PPO. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + rlhf.estimate_advantages + rlhf.get_rewards_ppo + rlhf.truncate_sequence_at_first_stop_token + rlhf.left_padded_collate + rlhf.padded_collate_dpo diff --git a/training/docs/source/api_ref_utilities.rst b/training/docs/source/api_ref_utilities.rst new file mode 100644 index 0000000..4d64b9b --- /dev/null +++ b/training/docs/source/api_ref_utilities.rst @@ -0,0 +1,134 @@ +================= +torchtune.utils +================= + +.. currentmodule:: torchtune.utils + + +.. _checkpointing_label: + +Checkpointing +------------- + +torchtune offers checkpointers to allow seamless transitioning between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of +checkpointing, please see the :ref:`checkpointing deep-dive `. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + FullModelHFCheckpointer + FullModelMetaCheckpointer + FullModelTorchTuneCheckpointer + ModelType + +.. _dist_label: + +Distributed +----------- + +Utilities for enabling and working with distributed training. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + FSDPPolicyType + init_distributed + is_distributed + get_world_size_and_rank + get_full_finetune_fsdp_wrap_policy + lora_fsdp_wrap_policy + +.. _mp_label: + +Reduced Precision +------------------ + +Utilities for working in a reduced precision setting. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + get_dtype + set_default_dtype + validate_expected_param_dtype + get_quantizer_mode + +.. _ac_label: + +Memory Management +----------------- + +Utilities to reduce memory consumption during training. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + set_activation_checkpointing + OptimizerInBackwardWrapper + create_optim_in_bwd_wrapper + register_optim_in_bwd_hooks + + +.. _perf_profiling_label: + +Performance and Profiling +------------------------- + +torchtune provides utilities to profile and debug the memory and performance +of your finetuning job. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + get_memory_stats + log_memory_stats + setup_torch_profiler + +.. _metric_logging_label: + +Metric Logging +-------------- + +Various logging utilities. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + metric_logging.WandBLogger + metric_logging.TensorBoardLogger + metric_logging.StdoutLogger + metric_logging.DiskLogger + +Data +---- + +Utilities for working with data and datasets. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + padded_collate + +.. _gen_label: + + +Miscellaneous +------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + get_logger + get_device + set_seed + generate + torch_version_ge + TuneRecipeArgumentParser diff --git a/training/docs/source/conf.py b/training/docs/source/conf.py new file mode 100644 index 0000000..f2c1b86 --- /dev/null +++ b/training/docs/source/conf.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# PyTorch documentation build configuration file, created by +# sphinx-quickstart on Fri Dec 23 13:31:47 2016. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + +import os +import sys + +import pytorch_sphinx_theme + +sys.path.append(os.path.abspath(".")) + +# -- General configuration ------------------------------------------------ + +# Required version of sphinx is set from docs/requirements.txt + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.duration", + "sphinx_tabs.tabs", + "sphinx_design", + "sphinx_gallery.gen_gallery", + "sphinx_copybutton", +] + +sphinx_gallery_conf = { + "examples_dirs": "tutorials/", # path to your sphinx-gallery examples + "gallery_dirs": "generated_examples", # path to where to save shpinx-gallery generated output + "filename_pattern": "./*.py", # any .py file in docs/source/examples will be built by sphinx-gallery + "backreferences_dir": "gen_modules/backreferences", # path to store the backreferences + "doc_module": ("torchtune",), + "remove_config_comments": True, +} + +napoleon_use_ivar = True +napoleon_numpy_docstring = False +napoleon_google_docstring = True + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = [".rst"] + +# Get TORCHTUNE_VERSION_DOCS during the build. +torchtune_version_docs = os.environ.get("TORCHTUNE_VERSION_DOCS", None) + +# Get TORCHTUNE_VERSION_DOCS during the build. +torchtune_version_docs = os.environ.get("TORCHTUNE_VERSION_DOCS", None) +print(f"torchtune_version_docs: {torchtune_version_docs}") +project = "torchtune" + +# The code below will cut version displayed in the dropdown like this: +# By default, set to "main". +# If it's a tag like refs/tags/v1.2.3-rc4 or refs/tags/v1.2.3, or a +# refs/heads/release/1.2 then +# cut to 1.2 +# the version varible is used in layout.html: https://github.com/pytorch/torchtune/blob/main/docs/source/_templates/layout.html#L29 +version = release = "main" +if torchtune_version_docs: + if torchtune_version_docs.startswith("refs/tags/v"): + version = ".".join( + torchtune_version_docs.split("/")[-1] + .split("-")[0] + .lstrip("v") + .split(".")[:2] + ) + elif torchtune_version_docs.startswith("refs/heads/release/"): + version = torchtune_version_docs.split("/")[-1] +print(f"Version: {version}") +html_title = " ".join((project, version, "documentation")) + + +# The master toctree document. +master_doc = "index" + +# General information about the project. +copyright = "2023-present, torchtune Contributors" +author = "Torch Contributors" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = "en" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "pytorch_sphinx_theme" +html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "collapse_navigation": False, + "display_version": True, + "logo_only": True, + "pytorch_project": "docs", + "navigation_with_keys": True, + "analytics_id": "GTM-T8XT4PS", +} + +html_logo = "_static/img/pytorch-logo-dark.svg" + +html_css_files = ["css/custom_torchtune.css"] + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = "PyTorchdoc" + + +autosummary_generate = True + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "PIL": ("https://pillow.readthedocs.io/en/stable/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), +} + +# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- +# See http://stackoverflow.com/a/41184353/3343043 + +from docutils import nodes +from sphinx import addnodes +from sphinx.util.docfields import TypedField + + +def patched_make_field(self, types, domain, items, **kw): + # `kw` catches `env=None` needed for newer sphinx while maintaining + # backwards compatibility when passed along further down! + + # type: (list, unicode, tuple) -> nodes.field # noqa: F821 + def handle_item(fieldarg, content): + par = nodes.paragraph() + par += addnodes.literal_strong("", fieldarg) # Patch: this line added + # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, + # addnodes.literal_strong)) + if fieldarg in types: + par += nodes.Text(" (") + # NOTE: using .pop() here to prevent a single type node to be + # inserted twice into the doctree, which leads to + # inconsistencies later when references are resolved + fieldtype = types.pop(fieldarg) + if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): + typename = "".join(n.astext() for n in fieldtype) + typename = typename.replace("int", "python:int") + typename = typename.replace("long", "python:long") + typename = typename.replace("float", "python:float") + typename = typename.replace("type", "python:type") + par.extend( + self.make_xrefs( + self.typerolename, + domain, + typename, + addnodes.literal_emphasis, + **kw, + ) + ) + else: + par += fieldtype + par += nodes.Text(")") + par += nodes.Text(" -- ") + par += content + return par + + fieldname = nodes.field_name("", self.label) + if len(items) == 1 and self.can_collapse: + fieldarg, content = items[0] + bodynode = handle_item(fieldarg, content) + else: + bodynode = self.list_type() + for fieldarg, content in items: + bodynode += nodes.list_item("", handle_item(fieldarg, content)) + fieldbody = nodes.field_body("", bodynode) + return nodes.field("", fieldname, fieldbody) + + +TypedField.make_field = patched_make_field + + +def inject_minigalleries(app, what, name, obj, options, lines): + """Inject a minigallery into a docstring. + + [This is 100% taken from torchvision] + + This avoids having to manually write the .. minigallery directive for every item we want a minigallery for, + as it would be easy to miss some. + + This callback is called after the .. auto directives (like ..autoclass) have been processed, + and modifies the lines parameter inplace to add the .. minigallery that will show which examples + are using which object. + + It's a bit hacky, but not *that* hacky when you consider that the recommended way is to do pretty much the same, + but instead with templates using autosummary (which we don't want to use): + (https://sphinx-gallery.github.io/stable/configuration.html#auto-documenting-your-api-with-links-to-examples) + + For docs on autodoc-process-docstring, see the autodoc docs: + https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html + """ + + if what in ("class", "function"): + lines.append(f".. minigallery:: {name}") + lines.append(f" :add-heading: Examples using ``{name.split('.')[-1]}``:") + # avoid heading entirely to avoid warning. As a bonud it actually renders better + lines.append(" :heading-level: 9") + lines.append("\n") + + +def setup(app): + + app.connect("autodoc-process-docstring", inject_minigalleries) + + +# Custom directives definitions to create cards on main torchtune page + +from custom_directives import CustomCardEnd, CustomCardItem, CustomCardStart +from docutils.parsers import rst + +rst.directives.register_directive("customcardstart", CustomCardStart) +rst.directives.register_directive("customcarditem", CustomCardItem) +rst.directives.register_directive("customcardend", CustomCardEnd) diff --git a/training/docs/source/custom_directives.py b/training/docs/source/custom_directives.py new file mode 100644 index 0000000..62053f2 --- /dev/null +++ b/training/docs/source/custom_directives.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import hashlib +import os +from pathlib import Path +from typing import List +from urllib.parse import quote, urlencode + +import requests +from docutils import nodes +from docutils.parsers.rst import Directive, directives +from docutils.parsers.rst.directives.images import Image +from docutils.statemachine import StringList +from sphinx.util.docutils import SphinxDirective + +_THIS_DIR = Path(__file__).parent + +# Color palette from PyTorch Developer Day 2021 Presentation Template +YELLOW = "F9DB78" +GREEN = "70AD47" +BLUE = "00B0F0" +PINK = "FF71DA" +ORANGE = "FF8300" +TEAL = "00E5D1" +GRAY = "7F7F7F" + + +def _get_cache_path(key, ext): + filename = f"{hashlib.sha256(key).hexdigest()}{ext}" + cache_dir = _THIS_DIR / "gen_images" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / filename + + +def _download(url, path): + response = requests.get(url) + response.raise_for_status() + with open(path, "wb") as file: + file.write(response.content) + + +def _fetch_image(url): + path = _get_cache_path(url.encode("utf-8"), ext=".svg") + if not path.exists(): + _download(url, path) + return os.sep + str(path.relative_to(_THIS_DIR)) + + +def _get_relpath(target, base): + target = os.sep + target + base = os.sep + base + target_path, filename = os.path.split(target) + rel_path = os.path.relpath(target_path, os.path.dirname(base)) + return os.path.normpath(os.path.join(rel_path, filename)) + + +class BaseShield(Image, SphinxDirective): + def run(self, params, alt, section) -> List[nodes.Node]: + url = f"https://img.shields.io/static/v1?{urlencode(params, quote_via=quote)}" + path = _fetch_image(url) + self.arguments = [path] + self.options["alt"] = alt + if "class" not in self.options: + self.options["class"] = [] + self.options["class"].append("shield-badge") + target = _get_relpath("supported_features.html", self.env.docname) + self.options["target"] = f"{target}#{section}" + return super().run() + + +_CARDLIST_START = """ +.. raw:: html + +
+ + +
+ +
+
+
+""" + +_CARD_TEMPLATE = """ +.. raw:: html + + +""" + +_CARDLIST_END = """ +.. raw:: html + +
+ +
+
+
+""" + + +class CustomCardStart(Directive): + def run(self): + para = nodes.paragraph() + self.state.nested_parse( + StringList(_CARDLIST_START.split("\n")), self.content_offset, para + ) + return [para] + + +class CustomCardItem(Directive): + option_spec = { + "header": directives.unchanged, + "image": directives.unchanged, + "link": directives.unchanged, + "card_description": directives.unchanged, + "tags": directives.unchanged, + } + + def run(self): + for key in ["header", "card_description", "link"]: + if key not in self.options: + raise ValueError(f"Key: `{key}` is missing") + + header = self.options["header"] + link = self.options["link"] + card_description = self.options["card_description"] + tags = self.options.get("tags", "") + + if "image" in self.options: + image = "" + else: + image = "_static/img/thumbnails/default.png" + + card_rst = _CARD_TEMPLATE.format( + header=header, + image=image, + link=link, + card_description=card_description, + tags=tags, + ) + card_list = StringList(card_rst.split("\n")) + card = nodes.paragraph() + self.state.nested_parse(card_list, self.content_offset, card) + return [card] + + +class CustomCardEnd(Directive): + def run(self): + para = nodes.paragraph() + self.state.nested_parse( + StringList(_CARDLIST_END.split("\n")), self.content_offset, para + ) + return [para] diff --git a/training/docs/source/deep_dives/README.txt b/training/docs/source/deep_dives/README.txt new file mode 100644 index 0000000..e69de29 diff --git a/training/docs/source/deep_dives/checkpointer.rst b/training/docs/source/deep_dives/checkpointer.rst new file mode 100644 index 0000000..012e9a5 --- /dev/null +++ b/training/docs/source/deep_dives/checkpointer.rst @@ -0,0 +1,476 @@ +.. _understand_checkpointer: + +========================== +Checkpointing in torchtune +========================== + +This deep-dive will walk you through the design and behavior of the checkpointer and +associated utilities. + +.. grid:: 1 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What this deep-dive will cover: + + * Checkpointer design for torchtune + * Checkpoint formats and how we handle them + * Checkpointing scenarios: Intermediate vs Final and LoRA vs Full-finetune + + +Overview +-------- + +torchtune checkpointers are designed to be composable components which can be plugged +into any recipe - training, evaluation or generation. Each checkpointer supports a +set of models and scenarios making these easy to understand, debug and extend. + +Before we dive into the checkpointer in torchtune, let's define some concepts. + +| + +Checkpoint Format +^^^^^^^^^^^^^^^^^ + +In this deep-dive, we'll talk about different checkpoint formats and how torchtune handles them. +Let's take a close look at these different formats. + +Very simply put, the format of a checkpoint is dictated by the state_dict and how this is stored +in files on disk. Each weight is associated with a string key that identifies it in the state dict. +If the string identifier of the keys in the stored checkpoints don't match up +exactly with those in the model definition, you'll either run into explicit errors (loading the +state dict will raise an exception) or worse - silent errors (loading will succeed but training or +inference will not work as expected). In addition to the keys lining up, you also need the shapes +of the weights (values in the state_dict) to match up exactly with those expected by the model +definition. + +Let's look at the two popular formats for Llama2. + +**Meta Format** + +This is the format supported by the official Llama2 implementation. When you download the Llama2 7B model +from the `meta-llama website `_, you'll get access to a single +``.pth`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load`` + +.. code-block:: python + + >>> import torch + >>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu') + >>> # inspect the keys and the shapes of the associated tensors + >>> for key, value in state_dict.items(): + >>> print(f'{key}: {value.shape}') + + tok_embeddings.weight: torch.Size([32000, 4096]) + ... + ... + >>> print(len(state_dict.keys())) + 292 + +The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. The +model definition for this state_dict expects an embedding layer with ``32000`` tokens each having a +embedding with dim of ``4096``. + + +**HF Format** + +This is the most popular format within the Hugging Face Model Hub and is +the default format in every torchtune config. This is also the format you get when you download the +llama2 model from the `Llama-2-7b-hf `_ repo. + +The first big difference is that the state_dict is split across two ``.bin`` files. To correctly +load the checkpoint, you'll need to piece these files together. Let's inspect one of the files. + +.. code-block:: python + + >>> import torch + >>> state_dict = torch.load('pytorch_model-00001-of-00002.bin', mmap=True, weights_only=True, map_location='cpu') + >>> # inspect the keys and the shapes of the associated tensors + >>> for key, value in state_dict.items(): + >>> print(f'{key}: {value.shape}') + + model.embed_tokens.weight: torch.Size([32000, 4096]) + ... + ... + >>> print(len(state_dict.keys())) + 241 + +Not only does the state_dict contain fewer keys (expected since this is one of two files), but +the embedding table is called ``model.embed_tokens`` instead of ``tok_embeddings``. This mismatch +in names will cause an exception when you try to load the state_dict. The size of this layer is the +same between the two, which is as expected. + +| + +As you can see, if you're not careful you'll likely end up making a number of errors just during +checkpoint load and save. The torchtune checkpointer makes this less error-prone by managing state dicts +for you. torchtune is designed to be "state-dict invariant". + +- When loading, torchtune accepts checkpoints from multiple sources in multiple formats. + You don't have to worry about explicitly converting checkpoints every time you run a recipe. + +- When saving, torchtune produces checkpoints in the same format as the source. This includes + converting the state_dict back into the original form and splitting the keys and weights + across the same number of files. + +One big advantage of being "state-dict invariant" is that you should be able to use +fine-tuned checkpoints from torchtune with any post-training tool (quantization, eval, inference) +which supports the source format, without any code changes OR conversion scripts. This is one of the +ways in which torchtune interoperates with the surrounding ecosystem. + +.. note:: + + To be state-dict "invariant" in this way, the ``load_checkpoint`` and ``save_checkpoint`` methods of each checkpointer + make use of weight converters which correctly map weights between checkpoint formats. For example, when loading weights + from Hugging Face, we apply a permutation to certain weights on load and save to ensure checkpoints behave exactly the same. + To further illustrate this, the Llama family of models uses a + `generic weight converter function `_ + whilst some other models like Phi3 have their own `conversion functions `_ + which can be found within their model folders. + +| + +Handling different Checkpoint Formats +------------------------------------- + +torchtune supports three different +:ref:`checkpointers`, +each of which supports a different checkpoint format. + + +:class:`HFCheckpointer ` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This checkpointer reads and writes checkpoints in a format which is compatible with the transformers +framework from Hugging Face. As mentioned above, this is the most popular format within the Hugging Face +Model Hub and is the default format in every torchtune config. + +For this checkpointer to work correctly, we assume that ``checkpoint_dir`` contains the necessary checkpoint +and json files. The easiest way to make sure everything works correctly is to use the following flow: + +- Download the model from the HF repo using tune download. By default, this will ignore the "safetensors" + files. + + | + + .. code-block:: bash + + tune download meta-llama/Llama-2-7b-hf \ + --output-dir \ + --hf-token + +- Use ``output_dir`` specified here as the ``checkpoint_dir`` argument for the checkpointer. + +| + +The following snippet explains how the HFCheckpointer is setup in torchtune config files. + +.. code-block:: yaml + + checkpointer: + + # checkpointer to use + _component_: torchtune.utils.FullModelHFCheckpointer + + # directory with the checkpoint files + # this should match the output_dir above + checkpoint_dir: + + # checkpoint files. For the llama2-7b-hf model we have + # 2 .bin files. The checkpointer takes care of sorting + # by id and so the order here does not matter + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin, + ] + + # if we're restarting a previous run, we need to specify + # the file with the checkpoint state. More on this in the + # next section + recipe_checkpoint: null + + # dir for saving the output checkpoints. Usually set + # to be the same as checkpoint_dir + output_dir: + + # model_type which specifies how to convert the state_dict + # into a format which torchtune understands + model_type: LLAMA2 + + # set to True if restarting training + resume_from_checkpoint: False + +.. note:: + Checkpoint conversion to and from HF's format requires access to model params which are + read directly from the ``config.json`` file. This helps ensure we either load the weights + correctly or error out in case of discrepancy between the HF checkpoint file and torchtune's + model implementations. This json file is downloaded from the hub along with the model checkpoints. + +| + +:class:`MetaCheckpointer ` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This checkpointer reads and writes checkpoints in a format which is compatible with the original meta-llama +github repository. + + +For this checkpointer to work correctly, we assume that ``checkpoint_dir`` contains the necessary checkpoint +and json files. The easiest way to make sure everything works correctly is to use the following flow: + +- Download the model from the HF repo using tune download. By default, this will ignore the "safetensors" + files. + + | + + .. code-block:: bash + + tune download meta-llama/Llama-2-7b \ + --output-dir \ + --hf-token + +- Use ``output_dir`` above as the ``checkpoint_dir`` for the checkpointer. + +| + +The following snippet explains how the MetaCheckpointer is setup in torchtune config files. + +.. code-block:: yaml + + checkpointer: + + # checkpointer to use + _component_: torchtune.utils.FullModelMetaCheckpointer + + # directory with the checkpoint files + # this should match the output_dir above + checkpoint_dir: + + # checkpoint files. For the llama2-7b model we have + # a single .pth file + checkpoint_files: [consolidated.00.pth] + + # if we're restarting a previous run, we need to specify + # the file with the checkpoint state. More on this in the + # next section + recipe_checkpoint: null + + # dir for saving the output checkpoints. Usually set + # to be the same as checkpoint_dir + output_dir: + + # model_type which specifies how to convert the state_dict + # into a format which torchtune understands + model_type: LLAMA2 + + # set to True if restarting training + resume_from_checkpoint: False + +| + +:class:`TorchTuneCheckpointer ` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This checkpointer reads and writes checkpoints in a format that is compatible with torchtune's +model definition. This does not perform any state_dict conversions and is currently used either +for testing or for loading quantized models for generation. + +| + + +Intermediate vs Final Checkpoints +--------------------------------- + +torchtune Checkpointers support two checkpointing scenarios: + +**End-of-training Checkpointing** + +The model weights at the end of a completed training +run are written out to file. The checkpointer ensures that the output checkpoint +files have the same keys as the input checkpoint file used to begin training. The +checkpointer also ensures that the keys are partitioned across the same number of +files as the original checkpoint. The output state dict has the following +standard format: + + .. code-block:: python + + { + "key_1": weight_1, + "key_2": weight_2, + ... + } + + +**Mid-training Chekpointing**. + +If checkpointing in the middle of training, the output checkpoint needs to store additional +information to ensure that subsequent training runs can be correctly restarted. In addition to +the model checkpoint files, we output a ``recipe_state.pt`` file for intermediate +checkpoints. These are currently output at the end of each epoch, and contain information +such as optimizer state, number of epochs completed etc. + +To prevent us from flooding ``output_dir`` with checkpoint files, the recipe state is +overwritten at the end of each epoch. + +The output state dicts have the following formats: + + .. code-block:: python + + Model: + { + "key_1": weight_1, + "key_2": weight_2, + ... + } + + Recipe State: + { + "optimizer": ..., + "epoch": ..., + ... + } + +To restart from a previous checkpoint file, you'll need to make the following changes +to the config file + +.. code-block:: yaml + + checkpointer: + + # checkpointer to use + _component_: torchtune.utils.FullModelHFCheckpointer + + checkpoint_dir: + + # checkpoint files. Note that you will need to update this + # section of the config with the intermediate checkpoint files + checkpoint_files: [ + hf_model_0001_0.pt, + hf_model_0002_0.pt, + ] + + # if we're restarting a previous run, we need to specify + # the file with the checkpoint state + recipe_checkpoint: recipe_state.pt + + # dir for saving the output checkpoints. Usually set + # to be the same as checkpoint_dir + output_dir: + + # model_type which specifies how to convert the state_dict + # into a format which torchtune understands + model_type: LLAMA2 + + # set to True if restarting training + resume_from_checkpoint: True + + +Checkpointing for LoRA +---------------------- + +In torchtune, we output both the adapter weights and the full model "merged" weights +for LoRA. The "merged" checkpoint can be used just like you would use the source +checkpoint with any post-training tools. For more details, take a look at our +:ref:`LoRA Finetuning Tutorial `.Additionally, by setting the option "save_adapter_weights_only" to True when saving a checkpoint, you can choose to save only the adapter weights. + +The primary difference between the two use cases is when you want to resume training +from a checkpoint. In this case, the checkpointer needs access to both the initial frozen +base model weights as well as the learnt adapter weights. The config for this scenario +looks something like this: + + +.. code-block:: yaml + + checkpointer: + + # checkpointer to use + _component_: torchtune.utils.FullModelHFCheckpointer + + # directory with the checkpoint files + # this should match the output_dir above + checkpoint_dir: + + # checkpoint files. This is the ORIGINAL frozen checkpoint + # and NOT the merged checkpoint output during training + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin, + ] + + # this refers to the adapter weights learnt during training + adapter_checkpoint: adapter_0.pt + + # the file with the checkpoint state + recipe_checkpoint: recipe_state.pt + + # dir for saving the output checkpoints. Usually set + # to be the same as checkpoint_dir + output_dir: + + # model_type which specifies how to convert the state_dict + # into a format which torchtune understands + model_type: LLAMA2 + + # set to True if restarting training + resume_from_checkpoint: True + + # Set to True to save only the adapter weights + save_adapter_weights_only: False + +| + +Putting this all together +------------------------- + +Let's now put all of this knowledge together! We'll load some checkpoints, +create some models and run a simple forward. + +For this section we'll use the Llama2 13B model in HF format. + +.. code-block:: python + + import torch + from torchtune.utils import FullModelHFCheckpointer, ModelType + from torchtune.models.llama2 import llama2_13b + + # Set the right directory and files + checkpoint_dir = 'Llama-2-13b-hf/' + pytorch_files = [ + 'pytorch_model-00001-of-00003.bin', + 'pytorch_model-00002-of-00003.bin', + 'pytorch_model-00003-of-00003.bin' + ] + + # Set up the checkpointer and load state dict + checkpointer = FullModelHFCheckpointer( + checkpoint_dir=checkpoint_dir, + checkpoint_files=pytorch_files, + output_dir=checkpoint_dir, + model_type=ModelType.LLAMA2 + ) + torchtune_sd = checkpointer.load_checkpoint() + + # Setup the model and the input + model = llama2_13b() + + # Model weights are stored with the key="model" + model.load_state_dict(torchtune_sd["model"]) + + + # We have 32000 vocab tokens; lets generate an input with 70 tokens + x = torch.randint(0, 32000, (1, 70)) + + with torch.no_grad(): + model(x) + + tensor([[[ -6.3989, -9.0531, 3.2375, ..., -5.2822, -4.4872, -5.7469], + [ -8.6737, -11.0023, 6.8235, ..., -2.6819, -4.2424, -4.0109], + [ -4.6915, -7.3618, 4.1628, ..., -2.8594, -2.5857, -3.1151], + ..., + [ -7.7808, -8.2322, 2.8850, ..., -1.9604, -4.7624, -1.6040], + [ -7.3159, -8.5849, 1.8039, ..., -0.9322, -5.2010, -1.6824], + [ -7.8929, -8.8465, 3.3794, ..., -1.3500, -4.6145, -2.5931]]]) + + +You can do this with any model supported by torchtune. You can find a full list +of models and model builders :ref:`here `. + +We hope this deep-dive provided a deeper insight into the checkpointer and +associated utilities in torchtune. Happy tuning! diff --git a/training/docs/source/deep_dives/configs.rst b/training/docs/source/deep_dives/configs.rst new file mode 100644 index 0000000..c6b9544 --- /dev/null +++ b/training/docs/source/deep_dives/configs.rst @@ -0,0 +1,271 @@ +.. _config_tutorial_label: + +================= +All About Configs +================= + +This deep-dive will guide you through writing configs for running recipes. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What this deep-dive will cover + + * How to write a YAML config and run a recipe with it + * How to use :code:`instantiate` and :code:`parse` APIs + * How to effectively use configs and CLI overrides for running recipes + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with the :ref:`overview of torchtune` + * Make sure to :ref:`install torchtune` + * Understand the :ref:`fundamentals of recipes` + + +Where do parameters live? +------------------------- + +There are two primary entry points for you to configure parameters: **configs** and +**CLI overrides**. Configs are YAML files that define all the +parameters needed to run a recipe within a single location. They are the single +source of truth for reproducing a run. The config parameters can be overridden on the +command-line using :code:`tune` for quick changes and experimentation without +modifying the config. + + +Writing configs +--------------- +Configs serve as the primary entry point for running recipes in torchtune. They are +expected to be YAML files and they simply list out values for parameters you want to define +for a particular run. + +.. code-block:: yaml + + seed: null + shuffle: True + device: cuda + dtype: fp32 + enable_fsdp: True + ... + +Configuring components using :func:`instantiate` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Many fields will require specifying torchtune objects with associated keyword +arguments as parameters. Models, datasets, optimizers, and loss functions are +common examples of this. You can easily do this using the :code:`_component_` +subfield. In :code:`_component_`, you need to specify the dotpath of the object +you wish to instantiate in the recipe. The dotpath is the exact path you would use +to import the object normally in a Python file. For example, to specify the +:class:`~torchtune.datasets.alpaca_dataset` in your config with custom +arguments: + +.. code-block:: yaml + + dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: False + +Here, we are changing the default value for :code:`train_on_input` from :code:`True` +to :code:`False`. + +Once you've specified the :code:`_component_` in your config, you can create an +instance of the specified object in your recipe's setup like so: + +.. code-block:: python + + from torchtune import config + + # Access the dataset field and create the object instance + dataset = config.instantiate(cfg.dataset) + +This will automatically use any keyword arguments specified in the fields under +:code:`dataset`. + +As written, the preceding example will actually throw an error. If you look at the method for :class:`~torchtune.datasets.alpaca_dataset`, +you'll notice that we're missing a required positional argument, the tokenizer. +Since this is another configurable torchtune object, let's understand how to handle +this by taking a look at the :func:`~torchtune.config.instantiate` API. + +.. code-block:: python + + def instantiate( + config: DictConfig, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) + +:func:`~torchtune.config.instantiate` also accepts positional arguments +and keyword arguments and automatically uses that with the config when creating +the object. This means we can not only pass in the tokenizer, but also add additional +keyword arguments not specified in the config if we'd like: + +.. code-block:: yaml + + # Tokenizer is needed for the dataset, configure it first + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/tokenizer.model + + dataset: + _component_: torchtune.datasets.alpaca_dataset + +.. code-block:: python + + # Note the API of the tokenizer we specified - we need to pass in a path + def llama2_tokenizer(path: str) -> Llama2Tokenizer: + + # Note the API of the dataset we specified - we need to pass in a model tokenizer + # and any optional keyword arguments + def alpaca_dataset( + tokenizer: ModelTokenizer, + train_on_input: bool = True, + max_seq_len: int = 512, + ) -> InstructDataset: + + from torchtune import config + + # Since we've already specified the path in the config, we don't need to pass + # it in + tokenizer = config.instantiate(cfg.tokenizer) + # We pass in the instantiated tokenizer as the first required argument, then + # we change an optional keyword argument + dataset = config.instantiate( + cfg.dataset, + tokenizer, + train_on_input=False, + ) + +Note that additional keyword arguments will overwrite any duplicated keys in the +config. + +Referencing other config fields with interpolations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Sometimes you need to use the same value more than once for multiple fields. You +can use *interpolations* to reference another field, and :func:`~torchtune.config.instantiate` +will automatically resolve it for you. + +.. code-block:: yaml + + output_dir: /tmp/alpaca-llama2-finetune + metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} + +Validating your config +^^^^^^^^^^^^^^^^^^^^^^ +We provide a convenient CLI utility, :ref:`tune validate`, to quickly verify that +your config is well-formed and all components can be instantiated properly. You +can also pass in overrides if you want to test out the exact commands you will run +your experiments with. If any parameters are not well-formed, :ref:`tune validate` +will list out all the locations where an error was found. + +.. code-block:: bash + + tune cp llama2/7B_lora_single_device ./my_config.yaml + tune validate ./my_config.yaml + +Best practices for writing configs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Let's discuss some guidelines for writing configs to get the most out of them. + +Airtight configs +"""""""""""""""" +While it may be tempting to put as much as you can in the config to give you +maximum flexibility in switching parameters for your experiments, we encourage +you to only include fields in the config that will be used or instantiated in the +recipe. This ensures full clarity on the options a recipe was run with and will +make it significantly easier to debug. + +.. code-block:: yaml + + # dont do this + alpaca_dataset: + _component_: torchtune.datasets.alpaca_dataset + slimorca_dataset: + ... + + # do this + dataset: + # change this in config or override when needed + _component_: torchtune.datasets.alpaca_dataset + +Use public APIs only +"""""""""""""""""""" +If a component you wish to specify in a config is located in a private file, use +the public dotpath in your config. These components are typically exposed in their +parent module's :code:`__init__.py` file. This way, you can guarantee the stability +of the API you are using in your config. There should be no underscores in your +component dotpath. + +.. code-block:: yaml + + # don't do this + dataset: + _component_: torchtune.datasets._alpaca.alpaca_dataset + + # do this + dataset: + _component_: torchtune.datasets.alpaca_dataset + +.. _cli_override: + +Command-line overrides +---------------------- +Configs are the primary location to collect all your parameters to run a recipe, +but sometimes you may want to quickly try different values without having to update +the config itself. To enable quick experimentation, you can specify override values +to parameters in your config via the :code:`tune` command. These should be specified +as key-value pairs :code:`k1=v1 k2=v2 ...` + +.. TODO (SalmanMohammadi) link this to the upcoming recipe docpage for the lora recipe + +For example, to run the :code:`lora_finetune_single_device` recipe with custom model and tokenizer directories, you can provide overrides: + +.. code-block:: bash + + tune run lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ + checkpointer.checkpoint_dir=/home/my_model_checkpoint \ + checkpointer.checkpoint_files=['file_1','file_2'] \ + tokenizer.path=/home/my_tokenizer_path + +Overriding components +^^^^^^^^^^^^^^^^^^^^^ +If you would like to override a class or function in the config that is instantiated +via the :code:`_component_` field, you can do so by assigning to the parameter +name directly. Any nested fields in the components can be overridden with dot notation. + +.. code-block:: yaml + + dataset: + _component_: torchtune.datasets.alpaca_dataset + +.. code-block:: bash + + # Change to slimorca_dataset and set train_on_input to True + tune run lora_finetune_single_device --config my_config.yaml \ + dataset=torchtune.datasets.slimorca_dataset dataset.train_on_input=True + +Removing config fields +^^^^^^^^^^^^^^^^^^^^^^ +You may need to remove certain parameters from the config when changing components +through overrides that require different keyword arguments. You can do so by using +the `~` flag and specify the dotpath of the config field you would like to remove. +For example, if you want to override a built-in config and use the +`bitsandbytes.optim.PagedAdamW8bit `_ +optimizer, you may need to delete parameters like ``foreach`` which are +specific to PyTorch optimizers. Note that this example requires that you have `bitsandbytes `_ +installed. + +.. code-block:: yaml + + # In configs/llama3/8B_full.yaml + optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + +.. code-block:: bash + + # Change to PagedAdamW8bit and remove fused, foreach + tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ + optimizer=bitsandbytes.optim.PagedAdamW8bit ~optimizer.foreach diff --git a/training/docs/source/deep_dives/recipe_deepdive.rst b/training/docs/source/deep_dives/recipe_deepdive.rst new file mode 100644 index 0000000..f68a4d6 --- /dev/null +++ b/training/docs/source/deep_dives/recipe_deepdive.rst @@ -0,0 +1,234 @@ +.. _recipe_deepdive: + +================= +What Are Recipes? +================= + +This deep-dive will walk you through the design of training-recipes in torchtune. + +.. grid:: 1 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What this deep-dive will cover + + * What are recipes? + * What are the core components that make up a recipe? + * How should I structure a new recipe? + + +What are Recipes? +----------------- +Recipes are the primary entry points for torchtune users. These can be thought of +as "targeted" end-to-end pipelines for training and optionally evaluating LLMs. +Each recipe implements a training method (eg: full fine-tuning) with a set of meaningful +features (eg: FSDP + Activation Checkpointing + Gradient Accumulation + Mixed Precision +training) applied to a given model family (eg: Llama2). + +As model training gets more and more complex, it becomes harder to anticipate new model +architectures and training methodologies while also reasoning about every possible trade-off +(eg: memory vs model quality). We believe a) users are best suited to make trade-offs +specific to their use cases and b) there's no one-size-fits-all solution. As a result, recipes +are meant to be easy to understand, extend and debug, *and not* generalized entry points for +all possible settings. + +Depending on your use case and level of expertise, you will routinely find yourself modifying +existing recipes (eg: adding new features) or writing new ones. torchtune makes writing recipes +easy by providing well-tested modular components/building-blocks and general utilities +(eg: :ref:`WandB Logging` and :ref:`Checkpointing `). + +| + +**Recipe Design** + +Recipes in torchtune are designed to be: + +- **Simple**. Written fully in native-PyTorch. +- **Correct**. Numerical parity verification for every component and extensive comparisons with + reference implementations and benchmarks. +- **Easy to Understand**. Each recipe provides a limited set of meaningful features, instead of + every possible feature hidden behind 100s of flags. Code duplication is preferred over unnecessary + abstractions. +- **Easy to Extend**. No dependency on training frameworks and no implementation inheritance. Users + don't need to go through layers-upon-layers of abstractions to figure out how to extend core + functionality. +- **Accessible to a spectrum of Users**. Users can decide how they want to interact with torchtune recipes: + - Start training models by modifying existing configs + - Modify existing recipes for custom cases + - Directly use available building blocks to write completely new recipes/training paradigms + +Each recipe consists of three components: + +- **Configurable parameters**, specified through yaml configs and command-line overrides +- **Recipe Script**, entry-point which puts everything together including parsing and validating + configs, setting up the environment, and correctly using the recipe class +- **Recipe Class**, core logic needed for training, exposed to users through a set of APIs + +In the following sections, we'll take a closer look at each of these components. +For a complete working example, refer to the +`full finetuning recipe `_ +in torchtune and the associated +`config `_. + +.. TODO (SalmanMohammadi) ref to full finetune recipe doc + +| + +What Recipes are not? +--------------------- + +- **Monolithic Trainers.** A recipe is **not** a monolithic trainer meant to support every + possible feature through 100s of flags. +- **Generalized entry-points.** A recipe is **not** meant to support every possible model + architecture or fine-tuning method. +- **Wrappers around external frameworks.** A recipe is **not** meant to be a wrapper around + external frameworks. These are fully written in native-PyTorch using torchtune building blocks. + Dependencies are primarily in the form of additional utilities or interoperability with the + surrounding ecosystem (eg: EleutherAI's evaluation harness). + +| + +Recipe Script +------------- + +This is the primary entry point for each recipe and provides the user with control over how +the recipe is set up, how models are trained and how the subsequent checkpoints are used. +This includes: + +- Setting up of the environment +- Parsing and validating configs +- Training the model +- Setting up multi-stage training (eg: Distillation) using multiple recipe classes + + +Scripts should generally structure operations in the following order: + +- Initialize the recipe class which in-turn initializes recipe state +- Load and Validate checkpoint to update recipe state if resuming training +- Initialize recipe components (model, tokenizer, optimizer, loss and dataloader) + from checkpoint (if applicable) +- Train the model +- Clean up recipe state after training is complete + + +An example script looks something like this: + +.. code-block:: python + + # Initialize the process group + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + # Setup the recipe and train the model + recipe = FullFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + # Other stuff to do after training is complete + ... + + +Recipe Class +------------ + +The recipe class carries the core logic for training a model. Each class implements a relevant +interface and exposes a set of APIs. For fine-tuning, the structure of this class is as follows: + +Initialize recipe state including seed, device, dtype, metric loggers, relevant flags etc: + +.. code-block:: python + + def __init__(...): + + self._device = utils.get_device(device=params.device) + self._dtype = utils.get_dtype(dtype=params.dtype, device=self._device) + ... + +Load checkpoint, update recipe state from checkpoint, initialize components and load state dicts from checkpoint + +.. code-block:: python + + def setup(self, cfg: DictConfig): + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # Setup the model, including FSDP wrapping, setting up activation checkpointing and + # loading the state dict + self._model = self._setup_model(...) + self._tokenizer = self._setup_tokenizer(...) + + # Setup Optimizer, including transforming for FSDP when resuming training + self._optimizer = self._setup_optimizer(...) + self._loss_fn = self._setup_loss(...) + self._sampler, self._dataloader = self._setup_data(...) + + +Run forward and backward across all epochs and save checkpoint at end of each epoch + +.. code-block:: python + + def train(...): + + self._optimizer.zero_grad() + for curr_epoch in range(self.epochs_run, self.total_epochs): + + for idx, batch in enumerate(self._dataloader): + ... + + with self._autocast: + logits = self._model(...) + ... + loss = self._loss_fn(logits, labels) + + if self.global_step % self._log_every_n_steps == 0: + self._metric_logger.log_dict(...) + + loss.backward() + self._optimizer.step() + self._optimizer.zero_grad() + + # Update the number of steps when the weights are updated + self.global_step += 1 + + self.save_checkpoint(epoch=curr_epoch) + + +Cleanup recipe state + +.. code-block:: python + + def cleanup(...) + + self.metric_loggers.close() + ... + + + +Running Recipes with Configs +---------------------------- + +To run a recipe with a set of user-defined parameters, you will need to write a config file. +You can learn all about configs in our :ref:`config deep-dive`. + +Config and CLI parsing using :code:`parse` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +We provide a convenient decorator :func:`~torchtune.config.parse` that wraps +your recipe to enable running from the command-line with :ref:`tune ` with config +and CLI override parsing. + +.. code-block:: python + + @config.parse + def recipe_main(cfg: DictConfig) -> None: + recipe = FullFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +Running your recipe +^^^^^^^^^^^^^^^^^^^ +You should be able to run your recipe by providing the direct paths to your custom +recipe and custom config using the :ref:`tune ` command with any CLI overrides: + +.. code-block:: bash + + tune run --config k1=v1 k2=v2 ... diff --git a/training/docs/source/deep_dives/wandb_logging.rst b/training/docs/source/deep_dives/wandb_logging.rst new file mode 100644 index 0000000..6af358a --- /dev/null +++ b/training/docs/source/deep_dives/wandb_logging.rst @@ -0,0 +1,97 @@ +.. _wandb_logging: + +=========================== +Logging to Weights & Biases +=========================== + +This deep-dive will guide you through how to set up logging to Weights & Biases +(W&B) in torchtune. + +.. grid:: 1 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What this deep-dive will cover + + * How to get started with W&B + * How to use the :class:`~torchtune.utils.metric_logging.WandBLogger` + * How to log configs, metrics, and model checkpoints to W&B + +Torchtune supports logging your training runs to `Weights & Biases `_. +An example W&B workspace from a torchtune fine-tuning run can be seen in the screenshot below. + +.. image:: ../_static/img/torchtune_workspace.png + :alt: torchtune workspace in W&B + :width: 100% + :align: center + +.. note:: + + You will need to install the :code:`wandb` package to use this feature. + You can install it via pip: + + .. code-block:: bash + + pip install wandb + + Then you need to login with your API key using the W&B CLI: + + .. code-block:: bash + + wandb login + + +Metric Logger +------------- + +The only change you need to make is to add the metric logger to your config. Weights & Biases will log the metrics and model checkpoints for you. + +.. code-block:: yaml + + # enable logging to the built-in WandBLogger + metric_logger: + _component_: torchtune.utils.metric_logging.WandBLogger + # the W&B project to log to + project: torchtune + + +We automatically grab the config from the recipe you are running and log it to W&B. You can find it in the W&B overview tab and the actual file in the :code:`Files` tab. + +As a tip, you may see straggler `wandb` processes running in the background if your job crashes or otherwise exits without cleaning up resources. To kill these straggler processes, a command like ``ps +-aux | grep wandb | awk '{ print $2 }' | xargs kill`` can be used. + +.. note:: + + Click on this sample `project to see the W&B workspace `_. + The config used to train the models can be found `here `_. + +Logging Model Checkpoints to W&B +-------------------------------- + +You can also log the model checkpoints to W&B by modifying the desired script :code:`save_checkpoint` method. + +A suggested approach would be something like this: + +.. code-block:: python + + def save_checkpoint(self, epoch: int) -> None: + ... + ## Let's save the checkpoint to W&B + ## depending on the Checkpointer Class the file will be named differently + ## Here is an example for the full_finetune case + checkpoint_file = Path.joinpath( + self._checkpointer._output_dir, f"torchtune_model_{epoch}" + ).with_suffix(".pt") + wandb_at = wandb.Artifact( + name=f"torchtune_model_{epoch}", + type="model", + # description of the model checkpoint + description="Model checkpoint", + # you can add whatever metadata you want as a dict + metadata={ + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + wandb_at.add_file(checkpoint_file) + wandb.log_artifact(wandb_at) diff --git a/training/docs/source/index.rst b/training/docs/source/index.rst new file mode 100644 index 0000000..1afc580 --- /dev/null +++ b/training/docs/source/index.rst @@ -0,0 +1,136 @@ +Welcome to the torchtune Documentation +======================================= + +**torchtune** is a Native-PyTorch library for LLM fine-tuning. + +Getting Started +~~~~~~~~~~~~~~~ + +Topics in this section will help you get started with torchtune. + +.. grid:: 3 + + .. grid-item-card:: :octicon:`file-code;1em` + What is torchtune? + :img-top: _static/img/card-background.svg + :link: overview.html + :link-type: url + + A gentle introduction to torchtune and how you can + use the library in your projects. + + .. grid-item-card:: :octicon:`file-code;1em` + Installation instructions + :img-top: _static/img/card-background.svg + :link: install.html + :link-type: url + + A step-by-step tutorial on how to install torchtune. + + .. grid-item-card:: :octicon:`file-code;1em` + Finetune your first model + :img-top: _static/img/card-background.svg + :link: tutorials/first_finetune_tutorial.html + :link-type: url + + Follow a simple tutorial to finetune Llama2 with torchtune. + +Tutorials +~~~~~~~~~ + +Ready to experiment? Check out some of the interactive +torchtune tutorials. + +.. customcardstart:: + +.. customcarditem:: + :header: Llama3 in torchtune + :card_description: + :image: _static/img/generic-pytorch-logo.png + :link: tutorials/llama3.html + :tags: finetuning,llama3 + +.. customcarditem:: + :header: Finetuning with LoRA in torchtune + :card_description: Parameter-efficient finetuning of Llama2 using LoRA + :image: _static/img/generic-pytorch-logo.png + :link: tutorials/lora_finetune.html + :tags: finetuning,llama2,lora + +.. customcarditem:: + :header: Understanding QLoRA in torchtune + :card_description: Using QLoRA to quantize base model weights and maximize memory savings + :image: _static/img/generic-pytorch-logo.png + :link: tutorials/qlora_finetune.html + :tags: finetuning,llama2,qlora + +.. customcarditem:: + :header: Finetuning with QAT in torchtune + :card_description: Finetuning of Llama3 using QAT + :image: _static/img/generic-pytorch-logo.png + :link: tutorials/qat_finetune.html + :tags: finetuning,llama3,qat,quantization,evals + +.. customcarditem:: + :header: End-to-End Workflow with torchtune + :card_description: Train, Evaluate, Quantize and then Generate with your LLM. + :image: _static/img/generic-pytorch-logo.png + :link: tutorials/e2e_flow.html + :tags: finetuning,quantization,inference,evals,llama2 + +.. customcardend:: + + +.. ---------------------------------------------------------------------- +.. Below is the toctree i.e. it defines the content of the left sidebar. +.. Each of the entry below corresponds to a file.rst in docs/source/. +.. ---------------------------------------------------------------------- + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Getting Started + :hidden: + + overview + install + tutorials/first_finetune_tutorial + tune_cli + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Tutorials + :hidden: + + tutorials/llama3 + tutorials/lora_finetune + tutorials/qlora_finetune + tutorials/qat_finetune + tutorials/e2e_flow + tutorials/datasets + tutorials/chat + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Deep-Dives + :hidden: + + deep_dives/checkpointer + deep_dives/configs + deep_dives/recipe_deepdive + deep_dives/wandb_logging + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: API Reference + :hidden: + + api_ref_config + api_ref_data + api_ref_datasets + api_ref_models + api_ref_modules + api_ref_utilities diff --git a/training/docs/source/install.rst b/training/docs/source/install.rst new file mode 100644 index 0000000..b53559d --- /dev/null +++ b/training/docs/source/install.rst @@ -0,0 +1,73 @@ +.. _install_label: + +==================== +Install Instructions +==================== + +**Pre-requisites**: torchtune requires PyTorch, so please install for your proper host and environment +using the `Start Locally `_ page. + +Install via PyPI +---------------- + +The latest stable version of torchtune is hosted on PyPI and can be downloaded +with the following command: + +.. code-block:: bash + + pip install torchtune + +To confirm that the package is installed correctly, you can run the following command: + +.. code-block:: bash + + tune + +And should see the following output: + +:: + + usage: tune [-h] {download,ls,cp,run,validate} ... + + Welcome to the torchtune CLI! + + options: + -h, --help show this help message and exit + + ... + +| + +Install via ``git clone`` +------------------------- + +If you want the latest and greatest features from torchtune or if you want to `become a contributor `_, +you can also install the package locally with the following command. + +.. code-block:: bash + + git clone https://github.com/pytorch/torchtune.git + cd torchtune + pip install -e . + +| + +Install nightly build +--------------------- + +torchtune gets built every evening with the latest commits to ``main`` branch. If you want the latest updates +to the package *without* installing via ``git clone``, you can install with the following command: + +.. code-block:: bash + + pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir + +.. note:: + + ``--no-cache-dir`` will direct ``pip`` to not look for a cached version of torchtune, thereby overwriting + your existing torchtune installation. + +If you already have PyTorch installed, torchtune will default to using that version. However, if you want to +use the nightly version of PyTorch, you can append the ``--force-reinstall`` option to the above command. If you +opt for this install method, you will likely need to change the "cpu" suffix in the index url to match your CUDA +version. For example, if you are running CUDA 12, your index url would be "https://download.pytorch.org/whl/nightly/cu121". diff --git a/training/docs/source/overview.rst b/training/docs/source/overview.rst new file mode 100644 index 0000000..4b2c6c7 --- /dev/null +++ b/training/docs/source/overview.rst @@ -0,0 +1,81 @@ +.. _overview_label: + +================== +torchtune Overview +================== + +On this page, we'll walk through an overview of torchtune, including features, key concepts and additional pointers. + +What is torchtune? +------------------ + +torchtune is a PyTorch library for easily authoring, fine-tuning and experimenting with LLMs. The library emphasizes 4 key aspects: + +- **Simplicity and Extensibility**. Native-PyTorch, componentized design and easy-to-reuse abstractions +- **Correctness**. High bar on proving the correctness of components and recipes +- **Stability**. PyTorch just works. So should torchtune +- **Democratizing LLM fine-tuning**. Works out-of-the-box on different hardware + + +torchtune provides: + +- Modular native-PyTorch implementations of popular LLMs +- Interoperability with popular model zoos through checkpoint-conversion utilities +- Training recipes for a variety of fine-tuning techniques +- Integration with `Hugging Face Datasets `_ for training and `EleutherAI's Eval Harness `_ for evaluation +- Support for distributed training using `FSDP `_ +- YAML configs for easily configuring training runs + +Excited? To get started, checkout some of our tutorials, including: + +- our :ref:`quickstart guide ` to finetune your first LLM using torchtune. +- our :ref:`LoRA tutorial ` to learn about parameter-efficient finetuning with torchtune. +- our :ref:`QLoRA tutorial ` to attain maximal memory efficiency with torchtune. + +Key Concepts +------------ + +As you go through the tutorials and code, there are two concepts which will help you better understand and use torchtune. + +**Configs.** YAML files which help you configure training settings (dataset, model, chekckpoint) and +hyperparameters (batch size, learning rate) without modifying code. +See the ":ref:`All About Configs`" deep-dive for more information. + +**Recipes.** Recipes can be thought of +as targeted end-to-end pipelines for training and optionally evaluating LLMs. +Each recipe implements a training method (eg: full fine-tuning) with a set of meaningful +features (eg: FSDP + Activation Checkpointing + Gradient Accumulation + Reduced Precision training) +applied to a given model family (eg: Llama2). See the ":ref:`What Are Recipes?`" deep-dive for more information. + +| + +.. _design_principles_label: + +Design Principles +----------------- + +torchtune embodies `PyTorch’s design philosophy `_, especially "usability over everything else". + +**Native PyTorch** + +torchtune is a native-PyTorch library. While we provide integrations with the surrounding ecosystem (eg: `Hugging Face Datasets `_, +`EleutherAI's Eval Harness `_), all of the core functionality is written in PyTorch. + + +**Simplicity and Extensibility** + +torchtune is designed to be easy to understand, use and extend. + +- Composition over implementation inheritance - layers of inheritance for code re-use makes the code hard to read and extend +- No training frameworks - explicitly outlining the training logic makes it easy to extend for custom use cases +- Code duplication is prefered over unecessary abstractions +- Modular building blocks over monolithic components + + +**Correctness** + +torchtune provides well-tested components with a high bar on correctness. The library will never be the first to provide a feature, but available features will be thoroughly tested. We provide + +- Extensive unit tests to ensure component-level numerical parity with reference implementations +- Checkpoint tests to ensure model-level numerical parity with reference implementations +- Integration tests to ensure recipe-level performance parity with reference implementations on standard benchmarks diff --git a/training/docs/source/tune_cli.rst b/training/docs/source/tune_cli.rst new file mode 100644 index 0000000..2c432b9 --- /dev/null +++ b/training/docs/source/tune_cli.rst @@ -0,0 +1,222 @@ +.. _cli_label: + +============= +torchtune CLI +============= + +This page is the documentation for using the torchtune CLI - a convenient way to +download models, find and copy relevant recipes/configs, and run recipes. It is automatically +available when you install torchtune. + +Getting started +--------------- + +The ``--help`` option will show all the possible commands available through the torchtune CLI, +with a short description of each. + +.. code-block:: bash + + $ tune --help + usage: tune [-h] {download,ls,cp,run,validate} ... + + Welcome to the torchtune CLI! + + options: + -h, --help show this help message and exit + + subcommands: + {download,ls,cp,run,validate} + download Download a model from the Hugging Face Hub. + ls List all built-in recipes and configs + ... + +The ``--help`` option is convenient for getting more details about any command. You can use it anytime to list all +available options and their details. For example, ``tune download --help`` provides more information on how +to download files using the CLI. + +.. _tune_download_label: + +Download a model +---------------- + +The ``tune download `` command downloads any model from the Hugging Face Hub. + +.. list-table:: + :widths: 30 60 + + * - \--output-dir + - Directory in which to save the model. + * - \--output-dir-use-symlinks + - To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be either duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if already exists) or downloaded from the Hub and not cached. + * - \--hf-token + - Hugging Face API token. Needed for gated models like Llama. + * - \--ignore-patterns + - If provided, files matching any of the patterns are not downloaded. Defaults to ignoring safetensors files to avoid downloading duplicate weights. + +.. code-block:: bash + + $ tune download meta-llama/Meta-Llama-3-8B-Instruct + Successfully downloaded model repo and wrote to the following locations: + ./model/config.json + ./model/README.md + ./model/model-00001-of-00002.bin + ... + + +**Download a gated model** + +A lot of recent large pretrained models released from organizations like Meta or MistralAI require you to agree +to the usage terms and conditions before you are allowed to download their model. If this is the case, you can specify +a Hugging Face access token. + +You can find the access token `here `_. + +.. code-block:: bash + + $ tune download meta-llama/Meta-Llama-3-8B-Instruct --hf-token + Successfully downloaded model repo and wrote to the following locations: + ./model/config.json + ./model/README.md + ./model/model-00001-of-00002.bin + ... + +.. note:: + If you'd prefer, you can also use ``huggingface-cli login`` to permanently login to the Hugging Face Hub on your machine. + The ``tune download`` command will pull the access token from your environment. + +**Specify model files you don't want to download** + +Some checkpoint directories can be very large and it can eat up a lot of bandwith and local storage to download the all of the files every time, even if you might +not need a lot of them. This is especially common when the same checkpoint exists in different formats. You can specify patterns to ignore to prevent downloading files +with matching names. By default we ignore safetensor files, but if you want to include all files you can pass in an empty string. + +.. code-block:: bash + + $ tune download meta-llama/Meta-Llama-3-8B-Instruct --hf-token --ignore-patterns None + Successfully downloaded model repo and wrote to the following locations: + ./model/config.json + ./model/README.md + ./model/model-00001-of-00030.safetensors + ... + +.. note:: + Just because a model can be downloaded does not mean that it will work OOTB with torchtune's + built-in recipes or configs. For a list of supported model families and architectures, see :ref:`models`. + + +.. _tune_ls_label: + +List built-in recipes and configs +--------------------------------- + +The ``tune ls`` command lists out all the built-in recipes and configs within torchtune. + + +.. code-block:: bash + + $ tune ls + RECIPE CONFIG + full_finetune_single_device llama2/7B_full_low_memory + code_llama2/7B_full_low_memory + llama3/8B_full_single_device + mistral/7B_full_low_memory + phi3/mini_full_low_memory + full_finetune_distributed llama2/7B_full + llama2/13B_full + llama3/8B_full + llama3/70B_full + ... + +.. _tune_cp_cli_label: + +Copy a built-in recipe or config +-------------------------------- + +The ``tune cp `` command copies built-in recipes and configs to a provided location. This allows you to make a local copy of a library +recipe or config to edit directly for yourself. See :ref:`here ` for an example of how to use this command. + +.. list-table:: + :widths: 30 60 + + * - \-n, \--no-clobber + - Do not overwrite destination if it already exists + * - \--make-parents + - Create parent directories for destination if they do not exist. If not set to True, will error if parent directories do not exist + +.. code-block:: bash + + $ tune cp lora_finetune_distributed . + Copied file to ./lora_finetune_distributed.py + +Run a recipe +------------ + +The ``tune run --config `` is a wrapper around `torchrun `_. ``tune run`` allows you to specify +a built-in recipe or config by name, or by path to use your local recipes/configs. + +To run a tune recipe + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama3/8B_lora_single_device + +**Specifying distributed (torchrun) arguments** + +``tune run`` supports launching distributed runs by passing through arguments preceding the recipe directly to torchrun. This follows the pattern used by torchrun +of specifying distributed and host machine flags before the script (recipe). For a full list of available flags for distributed setup, see the `torchrun docs `_. + +Some common flags: + +.. list-table:: + :widths: 30 60 + + * - \--nproc-per-node + - Number of workers per node; supported values: [auto, cpu, gpu, int]. + * - \--nnodes + - Number of nodes, or the range of nodes in form :. + * - \--max-restarts + - Maximum number of worker group restarts before failing. + * - \--rdzv-backend + - Rendezvous backend. + * - \--rdzv-endpoint + - Rendezvous backend endpoint; usually in form :. + +.. code-block:: bash + + tune run --nnodes=1 --nproc-per-node=4 lora_finetune_distributed --config llama3/8B_lora + +.. note:: + If no arguments are provided before the recipe, tune will bypass torchrun and launch directly with ``python``. This can simplify running and debugging recipes + when distributed isn't needed. If you want to launch with torchrun, but use only a single device, you can specify ``tune run --nnodes=1 --nproc-per-node=1 --config ``. + +**Running a custom (local) recipe and config** + +To use ``tune run`` with your own local recipes and configs, simply pass in a file path instead of a name to the run command. You can mix and match a custom recipe with a +torchtune config or vice versa or you can use both custom configs and recipes. + +.. code-block:: bash + + tune run my/fancy_lora.py --config my/configs/8B_fancy_lora.yaml + +**Overriding the config** + +You can override existing parameters from the command line using a key=value format. Let’s say you want to set the number of training epochs to 1. +Further information on config overrides can be found :ref:`here `. + +.. code-block:: bash + + tune run --config epochs=1 + +.. _validate_cli_label: + +Validate a config +----------------- + +The ``tune validate `` command will validate that your config is formatted properly. + + +.. code-block:: bash + + # If you've copied over a built-in config and want to validate custom changes + $ tune validate my_configs/llama3/8B_full.yaml + Config is well-formed! diff --git a/training/docs/source/tutorials/README.txt b/training/docs/source/tutorials/README.txt new file mode 100644 index 0000000..e69de29 diff --git a/training/docs/source/tutorials/chat.rst b/training/docs/source/tutorials/chat.rst new file mode 100644 index 0000000..95ca4a6 --- /dev/null +++ b/training/docs/source/tutorials/chat.rst @@ -0,0 +1,341 @@ +================================= +Fine-tuning Llama3 with Chat Data +================================= + +Llama3 Instruct introduced a new prompt template for fine-tuning with chat data. In this tutorial, +we'll cover what you need to know to get you quickly started on preparing your own +custom chat dataset for fine-tuning Llama3 Instruct. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` You will learn: + + * How the Llama3 Instruct format differs from Llama2 + * All about prompt templates and special tokens + * How to use your own chat dataset to fine-tune Llama3 Instruct + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with :ref:`configuring datasets` + * Know how to :ref:`download Llama3 Instruct weights ` + +.. note:: + This tutorial requires a version of torchtune > 0.1.1 + +Template changes from Llama2 to Llama3 +-------------------------------------- + +The Llama2 chat model requires a specific template when prompting the pre-trained +model. Since the chat model was pretrained with this prompt template, if you want to run +inference on the model, you'll need to use the same template for optimal performance +on chat data. Otherwise, the model will just perform standard text completion, which +may or may not align with your intended use case. + +From the `official Llama2 prompt +template guide `_ +for the Llama2 chat model, we can see that special tags are added: + +.. code-block:: text + + [INST] <> + You are a helpful, respectful, and honest assistant. + <> + + Hi! I am a human. [/INST] Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant + +Llama3 Instruct `overhauled `_ +the template from Llama2 to better support multiturn conversations. The same text +in the Llama3 Instruct format would look like this: + +.. code-block:: text + + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are a helpful, respectful, and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> + + Hi! I am a human.<|eot_id|><|start_header_id|>assistant<|end_header_id|> + + Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant<|eot_id|> + +The tags are entirely different, and they are actually encoded differently than in +Llama2. Let's walk through tokenizing an example with the Llama2 template and the +Llama3 template to understand how. + +.. note:: + The Llama3 Base model uses a `different prompt template + `_ than Llama3 Instruct + because it has not yet been instruct tuned and the extra special tokens are untrained. If you + are running inference on the Llama3 Base model without fine-tuning we recommend the base + template for optimal performance. Generally, for instruct and chat data, we recommend using + Llama3 Instruct with its prompt template. The rest of this tutorial assumes you are using + Llama3 Instruct. + + +Tokenizing prompt templates & special tokens +-------------------------------------------- + +Let's say I have a sample of a single user-assistant turn accompanied with a system +prompt: + +.. code-block:: python + + sample = [ + { + "role": "system", + "content": "You are a helpful, respectful, and honest assistant.", + }, + { + "role": "user", + "content": "Who are the most influential hip-hop artists of all time?", + }, + { + "role": "assistant", + "content": "Here is a list of some of the most influential hip-hop " + "artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.", + }, + ] + +Now, let's format this with the :class:`~torchtune.data.Llama2ChatFormat` class and +see how it gets tokenized. The Llama2ChatFormat is an example of a **prompt template**, +which simply structures a prompt with flavor text to indicate a certain task. + +.. code-block:: python + + from torchtune.data import Llama2ChatFormat, Message + + messages = [Message.from_dict(msg) for msg in sample] + formatted_messages = Llama2ChatFormat.format(messages) + print(formatted_messages) + # [ + # Message( + # role='user', + # content='[INST] <>\nYou are a helpful, respectful, and honest assistant.\n<>\n\nWho are the most influential hip-hop artists of all time? [/INST] ', + # ..., + # ), + # Message( + # role='assistant', + # content='Here is a list of some of the most influential hip-hop artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.', + # ..., + # ), + # ] + +There are also special tokens used by Llama2, which are not in the prompt template. +If you look at our :class:`~torchtune.data.Llama2ChatFormat` class, you'll notice that +we don't include the :code:`` and :code:`` tokens. These are the beginning-of-sequence +(BOS) and end-of-sequence (EOS) tokens that are represented differently in the tokenizer +than the rest of the prompt template. Let's tokenize this example with the +:func:`~torchtune.models.llama2.llama2_tokenizer` used by Llama2 to see +why. + +.. code-block:: python + + from torchtune.models.llama2 import llama2_tokenizer + + tokenizer = llama2_tokenizer("/tmp/Llama-2-7b-hf/tokenizer.model") + user_message = formatted_messages[0].text_content + tokens = tokenizer.encode(user_message, add_bos=True, add_eos=True) + print(tokens) + # [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, ..., 2] + +We've added the BOS and EOS tokens when encoding our example text. This shows up +as IDs 1 and 2. We can verify that these are our BOS and EOS tokens. + +.. code-block:: python + + print(tokenizer._spm_model.spm_model.piece_to_id("")) + # 1 + print(tokenizer._spm_model.spm_model.piece_to_id("")) + # 2 + +The BOS and EOS tokens are what we call special tokens, because they have their own +reserved token IDs. This means that they will index to their own individual vectors in +the model's learnt embedding table. The rest of the prompt template tags, :code:`[INST]` +and :code:`<>` are tokenized as normal text and not their own IDs. + +.. code-block:: python + + print(tokenizer.decode(518)) + # '[' + print(tokenizer.decode(25580)) + # 'INST' + print(tokenizer.decode(29962)) + # ']' + print(tokenizer.decode([3532, 14816, 29903, 6778])) + # '<>' + +It's important to note that you should not place the special reserved tokens in your +input prompts manually, as it will be treated as normal text and not as a special +token. + +.. code-block:: python + + print(tokenizer.encode("", add_bos=False, add_eos=False)) + # [529, 29879, 29958] + +Now let's take a look at Llama3's formatting to see how it's tokenized differently +than Llama2. + +.. code-block:: python + + from torchtune.models.llama3 import llama3_tokenizer + + tokenizer = llama3_tokenizer("/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model") + messages = [Message.from_dict(msg) for msg in sample] + tokens, mask = tokenizer.tokenize_messages(messages) + print(tokenizer.decode(tokens)) + # '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful, + # and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho + # are the most influential hip-hop artists of all time?<|eot_id|><|start_header_id|> + # assistant<|end_header_id|>\n\nHere is a list of some of the most influential hip-hop + # artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.<|eot_id|>' + +.. note:: + We used the ``tokenize_messages`` API for Llama3, which is different than + encode. It simply manages adding all the special tokens in the correct + places after encoding the individual messages. + +We can see that the tokenizer handled all the formatting without us specifying a prompt +template. It turns out that all of the additional tags are special tokens, and we don't require +a separate prompt template. We can verify this by checking if the tags get encoded +as their own token IDs. + +.. code-block:: python + + print(tokenizer.special_tokens["<|begin_of_text|>"]) + # 128000 + print(tokenizer.special_tokens["<|eot_id|>"]) + # 128009 + +The best part is - all these special tokens are handled purely by the tokenizer. +That means you won't have to worry about messing up any required prompt templates! + + +When should I use a prompt template? +------------------------------------ + +Whether or not to use a prompt template is governed by what your desired inference +behavior is. You should use a prompt template if you are running inference on the +base model and it was pre-trained with a prompt template, or you want to prime a +fine-tuned model to expect a certain prompt structure on inference for a specific task. + +It is not strictly necessary to fine-tune with a prompt template, but generally +specific tasks will require specific templates. For example, the :class:`~torchtune.data.SummarizeTemplate` +provides a lightweight structure to prime your fine-tuned model for prompts asking to summarize text. +This would wrap around the user message, with the assistant message untouched. + +.. code-block:: python + + f"Summarize this dialogue:\n{dialogue}\n---\nSummary:\n" + +You can fine-tune Llama2 with this template even though the model was originally pre-trained +with the :class:`~torchtune.data.Llama2ChatFormat`, as long as this is what the model +sees during inference. The model should be robust enough to adapt to a new template. + + +Fine-tuning on a custom chat dataset +------------------------------------ + +Let's test our understanding by trying to fine-tune the Llama3-8B instruct model with a custom +chat dataset. We'll walk through how to set up our data so that it can be tokenized +correctly and fed into our model. + +Let's say we have a local dataset saved as a CSV file that contains questions +and answers from an online forum. How can we get something like this into a format +Llama3 understands and tokenizes correctly? + +.. code-block:: python + + import pandas as pd + + df = pd.read_csv('your_file.csv', nrows=1) + print("Header:", df.columns.tolist()) + # ['input', 'output'] + print("First row:", df.iloc[0].tolist()) + # [ + # "How do GPS receivers communicate with satellites?", + # "The first thing to know is the communication is one-way...", + # ] + +The Llama3 tokenizer class, :class:`~torchtune.models.llama3._tokenizer.Llama3Tokenizer`, +expects the input to be in the :class:`~torchtune.data.Message` format. Let's +quickly write a function that can parse a single row from our csv file into +the Message dataclass. The function also needs to have a train_on_input parameter. + +.. code-block:: python + + def message_converter(sample: Mapping[str, Any], train_on_input: bool) -> List[Message]: + input_msg = sample["input"] + output_msg = sample["output"] + + user_message = Message( + role="user", + content=input_msg, + masked=not train_on_input, # Mask if not training on prompt + ) + assistant_message = Message( + role="assistant", + content=output_msg, + masked=False, + ) + # A single turn conversation + messages = [user_message, assistant_message] + + return messages + +Since we're fine-tuning Llama3, the tokenizer will handle formatting the prompt for +us. But if we were fine-tuning a model that requires a template, for example the +Mistral-7B model which uses the :class:`~torchtune.models.mistral._tokenizer.MistralTokenizer`, +we would need to use a chat format like :class:`~torchtune.data.MistralChatFormat` to format +all messages according to their `recommendations `_. + +Now let's create a builder function for our dataset that loads in our local file, +converts to a list of Messages using our function, and creates a :class:`~torchtune.datasets.ChatDataset` +object. + +.. code-block:: python + + def custom_dataset( + *, + tokenizer: ModelTokenizer, + max_seq_len: int = 2048, # You can expose this if you want to experiment + ) -> ChatDataset: + + return ChatDataset( + tokenizer=tokenizer, + # For local csv files, we specify "csv" as the source, just like in + # load_dataset + source="csv", + # Default split of "train" is required for local files + split="train", + convert_to_messages=message_converter, + # Llama3 does not need a chat format + chat_format=None, + max_seq_len=max_seq_len, + # To load a local file we specify it as data_files just like in + # load_dataset + data_files="your_file.csv", + ) + +.. note:: + You can pass in any keyword argument for `load_dataset `_ into all our + Dataset classes and they will honor them. This is useful for common parameters + such as specifying the data split with :code:`split` or configuration with + :code:`name` + +Now we're ready to start fine-tuning! We'll use the built-in LoRA single device recipe. +Use the :ref:`tune cp ` command to get a copy of the :code:`8B_lora_single_device.yaml` +config and update it to use your new dataset. Create a new folder for your project +and make sure the dataset builder and message converter are saved in that directory, +then specify it in the config. + +.. code-block:: yaml + + dataset: + _component_: path.to.my.custom_dataset + max_seq_len: 2048 + +Launch the fine-tune! + +.. code-block:: bash + + $ tune run lora_finetune_single_device --config custom_8B_lora_single_device.yaml epochs=15 diff --git a/training/docs/source/tutorials/datasets.rst b/training/docs/source/tutorials/datasets.rst new file mode 100644 index 0000000..a8de278 --- /dev/null +++ b/training/docs/source/tutorials/datasets.rst @@ -0,0 +1,567 @@ +.. _dataset_tutorial_label: + +==================================== +Configuring Datasets for Fine-Tuning +==================================== + +This tutorial will guide you through how to set up a dataset to fine-tune on. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + + * How to quickly get started with built-in datasets + * How to use any dataset from Hugging Face Hub + * How to use instruct, chat, or text completion datasets + * How to configure datasets from code, config, or command-line + * How to fully customize your own dataset + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Know how to :ref:`configure components from the config` + +Datasets are a core component of fine-tuning workflows that serve as a "steering +wheel" to guide LLM generation for a particular use case. Many publicly shared +open-source datasets have become popular for fine-tuning LLMs and serve as a great +starting point to train your model. torchtune gives you the tools to download external +community datasets, load in custom local datasets, or create your own datasets. + +Built-in datasets +----------------- + +To use one of the built-in datasets in the library, simply import and call the dataset builder +function. You can see a list of all supported datasets :ref:`here`. + +.. code-block:: python + + from torchtune.datasets import alpaca_dataset + + # Load in tokenizer + tokenizer = ... + dataset = alpaca_dataset(tokenizer) + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.alpaca_dataset + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset=torchtune.datasets.alpaca_dataset + +Hugging Face datasets +--------------------- + +We provide first class support for datasets on the Hugging Face hub. Under the hood, +all of our built-in datasets and dataset builders are using Hugging Face's `load_dataset() `_ +to load in your data, whether local or on the hub. + +You can pass in a Hugging Face dataset path to the ``source`` parameter in any of our builders +to specify which dataset on the hub to download or use from a local directory path (see `Local and remote datasets`_). Additionally, all builders accept +any keyword-arguments that ``load_dataset()`` supports. You can see a full list +on Hugging Face's `documentation. `_ + +.. code-block:: python + + from torchtune.datasets import text_completion_dataset + + # Load in tokenizer + tokenizer = ... + dataset = text_completion_dataset( + tokenizer, + source="allenai/c4", + # Keyword-arguments that are passed into load_dataset + split="train", + data_dir="realnewslike", + ) + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.text_completion_dataset + source: allenai/c4 + split: train + data_dir: realnewslike + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset=torchtune.datasets.text_completion_dataset dataset.source=allenai/c4 \ + dataset.split=train dataset.data_dir=realnewslike + +Setting max sequence length +--------------------------- + +The default collator, :func:`~torchtune.utils.padded_collate`, used in all +our training recipes will pad samples to the max sequence length within the batch, +not globally. If you wish to set an upper limit on the max sequence length globally, +you can specify it in the dataset builder with ``max_seq_len``. Any sample in the dataset +that is longer than ``max_seq_len`` will be truncated in :func:`~torchtune.data.truncate`. +The tokenizer's EOS ids are ensured to be the last token, except in :class:`~torchtune.datasets.TextCompletionDataset`. + +Generally, you want the max sequence length returned in each data sample to match the context window +size of your model. You can also decrease this value to reduce memory usage +depending on your hardware constraints. + +.. code-block:: python + + from torchtune.datasets import alpaca_dataset + + # Load in tokenizer + tokenizer = ... + dataset = alpaca_dataset( + tokenizer=tokenizer, + max_seq_len=4096, + ) + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.alpaca_dataset + max_seq_len: 4096 + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset.max_seq_len=4096 + +Sample packing +-------------- + +You can use sample packing with any of the single dataset builders by passing in +:code:`packed=True`. This requires some pre-processing of the dataset which may +slow down time-to-first-batch, but can introduce significant training speedups +depending on the dataset. + +.. code-block:: python + + from torchtune.datasets import alpaca_dataset, PackedDataset + + # Load in tokenizer + tokenizer = ... + dataset = alpaca_dataset( + tokenizer=tokenizer, + packed=True, + ) + print(isinstance(dataset, PackedDataset)) # True + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.alpaca_dataset + packed: True + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset.packed=True + + +Custom unstructured text corpus +------------------------------- + +For continued pre-training, typically a similar data setup to pre-training is used +for a simple text completion task. This means no instruct templates, chat formats, +and minimal special tokens (only BOS and, optionally, EOS). To specify an unstructured text corpus, +you can use the :func:`~torchtune.datasets.text_completion_dataset` builder with +a Hugging Face dataset or a custom local corpus. Here is how to specify it for local +files: + +.. code-block:: python + + from torchtune.datasets import text_completion_dataset + + # Load in tokenizer + tokenizer = ... + dataset = text_completion_dataset( + tokenizer, + source="text", + data_files="path/to/my_data.txt", + split="train", + ) + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.text_completion_dataset + source: text + data_files: path/to/my_data.txt + split: train + +.. code-block:: bash + + # Command line + tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ + dataset=torchtune.datasets.text_completion_dataset dataset.source=text \ + dataset.data_files=path/to/my_data.txt dataset.split=train + +Custom instruct dataset and instruct templates +---------------------------------------------- + +If you have a custom instruct dataset that's not already provided in the library, +you can use the :func:`~torchtune.datasets.instruct_dataset` builder and specify +the source path. Instruct datasets typically have multiple columns with text that +are formatted into a prompt template. + +To fine-tune an LLM on a particular task, a common approach is to create a fixed instruct +template that guides the model to generate output with a specific goal. Instruct templates +are simply flavor text that structures your inputs for the model. It is model agnostic +and is tokenized normally just like any other text, but it can help condition the model +to respond better to an expected format. For example, the :class:`~torchtune.data.AlpacaInstructTemplate` +structures the data in the following way: + +.. code-block:: python + + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + +Here is an example of a sample that is formatted with :class:`~torchtune.data.AlpacaInstructTemplate`: + +.. code-block:: python + + from torchtune.data import AlpacaInstructTemplate + + sample = { + "instruction": "Classify the following into animals, plants, and minerals", + "input": "Oak tree, copper ore, elephant", + } + prompt = AlpacaInstructTemplate.format(sample) + print(prompt) + # Below is an instruction that describes a task, paired with an input that provides further context. + # Write a response that appropriately completes the request. + # + # ### Instruction: + # Classify the following into animals, plants, and minerals + # + # ### Input: + # Oak tree, copper ore, elephant + # + # ### Response: + # + +We provide :ref:`other instruct templates ` +for common tasks such summarization and grammar correction. If you need to create your own +instruct template for a custom task, you can inherit from :class:`~torchtune.data.InstructTemplate` +and create your own class. + +.. code-block:: python + + from torchtune.datasets import instruct_dataset + from torchtune.data import InstructTemplate + + class CustomTemplate(InstructTemplate): + # Define the template as string with {} as placeholders for data columns + template = ... + + # Implement this method + @classmethod + def format( + cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None + ) -> str: + ... + + # Load in tokenizer + tokenizer = ... + dataset = instruct_dataset( + tokenizer=tokenizer, + source="my/dataset/path", + template="import.path.to.CustomTemplate", + ) + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.instruct_dataset + source: my/dataset/path + template: import.path.to.CustomTemplate + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset=torchtune.datasets.instruct_dataset dataset.source=my/dataset/path \ + dataset.template=import.path.to.CustomTemplate + + +torchtune uses :code:`importlib.import_module` (see ``importlib`` `docs `_ for more details) +to locate components from their dotpaths. You can place your custom template class +in any Python file as long as the file is accessible by Python's import mechanism. +This means the module should be in a directory that is included in Python's search +paths (:code:`sys.path`). This often includes: + +- The current directory from which your Python interpreter or script is run. +- Directories where Python packages are installed (like :code:`site-packages`). +- Any directories added to :code:`sys.path` at runtime using :code:`sys.path.append` or through the :code:`PYTHONPATH` environment variable. + + +Custom chat dataset and chat formats +------------------------------------ + +If you have a custom chat/conversational dataset that's not already provided in the library, +you can use the :func:`~torchtune.datasets.chat_dataset` builder and specify +the source path. Chat datasets typically have a single column with multiple back +and forth messages between the user and assistant. + +Chat formats are similar to instruct templates, except that they format system, +user, and assistant messages into a list of messages (see :class:`~torchtune.data.ChatFormat`) +for a conversational dataset. These can be configured quite similarly to instruct +datasets. + +Here is how messages would be formatted using the :class:`~torchtune.data.Llama2ChatFormat`: + +.. code-block:: python + + from torchtune.data import Llama2ChatFormat, Message + + messages = [ + Message( + role="system", + content="You are a helpful, respectful, and honest assistant.", + ), + Message( + role="user", + content="I am going to Paris, what should I see?", + ), + Message( + role="assistant", + content="Paris, the capital of France, is known for its stunning architecture..." + ), + ] + formatted_messages = Llama2ChatFormat.format(messages) + print(formatted_messages) + # [ + # Message( + # role="user", + # content="[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\n" + # "I am going to Paris, what should I see? [/INST] ", + # ), + # Message( + # role="assistant", + # content="Paris, the capital of France, is known for its stunning architecture..." + # ), + # ] + +Note that the system message is now incorporated in the user message. If you create custom ChatFormats +you can also add more advanced behavior. + +.. code-block:: python + + from torchtune.datasets import chat_dataset + from torchtune.data import ChatFormat + + class CustomChatFormat(ChatFormat): + # Define templates for system, user, assistant messages + # as strings with {} as placeholders for message content + system = ... + user = ... + assistant = ... + + # Implement this method + @classmethod + def format( + cls, + sample: List[Message], + ) -> List[Message]: + ... + + # Load in tokenizer + tokenizer = ... + dataset = chat_dataset( + tokenizer=tokenizer, + source="my/dataset/path", + split="train", + conversation_style="openai", + chat_format="import.path.to.CustomChatFormat", + ) + +.. code-block:: yaml + + # YAML config + dataset: + _component_: torchtune.datasets.chat_dataset + source: my/dataset/path + conversation_style: openai + chat_format: import.path.to.CustomChatFormat + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset=torchtune.datasets.chat_dataset dataset.source=my/dataset/path \ + dataset.conversation_style=openai dataset.chat_format=import.path.to.CustomChatFormat + + +Multiple in-memory datasets +--------------------------- + +It is also possible to train on multiple datasets and configure them individually using +our :class:`~torchtune.datasets.ConcatDataset` interface. You can even mix instruct and chat datasets +or other custom datasets. + +.. code-block:: yaml + + # YAML config + dataset: + - _component_: torchtune.datasets.instruct_dataset + source: vicgalle/alpaca-gpt4 + template: torchtune.data.AlpacaInstructTemplate + split: train + train_on_input: True + - _component_: torchtune.datasets.instruct_dataset + source: samsum + template: torchtune.data.SummarizeTemplate + column_map: + output: summary + split: train + train_on_input: False + - _component_: torchtune.datasets.chat_dataset + ... + + +Local and remote datasets +------------------------- + +To use a dataset saved on your local hard drive, simply specify the file type for +``source`` and pass in the ``data_files`` argument using any of the dataset +builder functions. We support all `file types `_ +supported by Hugging Face's ``load_dataset``, including csv, json, txt, and more. + +.. code-block:: python + + from torchtune.datasets import instruct_dataset + + # Load in tokenizer + tokenizer = ... + # Local files + dataset = instruct_dataset( + tokenizer=tokenizer, + source="csv", + split="train", + template="import.path.to.CustomTemplate" + data_files="path/to/my/data.csv", + ) + # Remote files + dataset = instruct_dataset( + tokenizer=tokenizer, + source="json", + split="train", + template="import.path.to.CustomTemplate" + data_files="https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", + # You can also pass in any kwarg that load_dataset accepts + field="data", + ) + +.. code-block:: yaml + + # YAML config - local files + dataset: + _component_: torchtune.datasets.instruct_dataset + source: csv + template: import.path.to.CustomTemplate + data_files: path/to/my/data.csv + + # YAML config - remote files + dataset: + _component_: torchtune.datasets.instruct_dataset + source: json + template: import.path.to.CustomTemplate + data_files: https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json + field: data + +.. code-block:: bash + + # Command line - local files + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset=torchtune.datasets.chat_dataset dataset.source=csv \ + dataset.template=import.path.to.CustomTemplate dataset.data_files=path/to/my/data.csv + +Fully customized datasets +------------------------- + +More advanced tasks and dataset formats that don't fit into the templating and processing +that :class:`~torchtune.datasets.InstructDataset`, :class:`~torchtune.datasets.ChatDataset`, +and :class:`~torchtune.datasets.TextCompletionDataset` provide may require you to create your own dataset +class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`, +which has custom functionality for RLHF preference data, as an example to understand what you'll need to do. + +If you take a look at the code for the :class:`~torchtune.datasets.PreferenceDataset` class, +you'll notice it's quite similar to :class:`~torchtune.datasets.InstructDataset` with a few +adjustments for chosen and rejected samples in preference data. + +.. code-block:: python + + chosen_message = [ + Message(role="user", content=prompt, masked=True), + Message(role="assistant", content=transformed_sample[key_chosen]), + ] + rejected_message = [ + Message(role="user", content=prompt, masked=True), + Message(role="assistant", content=transformed_sample[key_rejected]), + ] + + chosen_input_ids, c_masks = self._tokenizer.tokenize_messages( + chosen_message, self.max_seq_len + ) + chosen_labels = list( + np.where(c_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids) + ) + + rejected_input_ids, r_masks = self._tokenizer.tokenize_messages( + rejected_message, self.max_seq_len + ) + rejected_labels = list( + np.where(r_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids) + ) + +For a specific dataset that's easy to customize from the config, you can create +a builder function. This is the builder function for the :func:`~torchtune.datasets.stack_exchanged_paired_dataset`, +which creates a :class:`~torchtune.datasets.PreferenceDataset` configured to use +a paired dataset from Hugging Face. Notice that we've also had +to add a custom instruct template as well. + +.. code-block:: python + + def stack_exchanged_paired_dataset( + tokenizer: ModelTokenizer, + max_seq_len: int = 1024, + ) -> PreferenceDataset: + return PreferenceDataset( + tokenizer=tokenizer, + source="lvwerra/stack-exchange-paired", + template=StackExchangedPairedTemplate(), + column_map={ + "prompt": "question", + "chosen": "response_j", + "rejected": "response_k", + }, + max_seq_len=max_seq_len, + split="train", + data_dir="data/rl", + ) + +Now we can easily specify our custom dataset from the config, or from command-line. + +.. code-block:: yaml + + # This is how you would configure the Alpaca dataset using the builder + dataset: + _component_: torchtune.datasets.stack_exchanged_paired_dataset + max_seq_len: 512 + +.. code-block:: bash + + # Command line - local files + tune run full_finetune_single_device --config llama3/8B_full_single_device \ + dataset=torchtune.datasets.stack_exchanged_paired_dataset dataset.max_seq_len=512 diff --git a/training/docs/source/tutorials/e2e_flow.rst b/training/docs/source/tutorials/e2e_flow.rst new file mode 100644 index 0000000..41eec69 --- /dev/null +++ b/training/docs/source/tutorials/e2e_flow.rst @@ -0,0 +1,450 @@ +.. _e2e_flow: + +================================== +End-to-End Workflow with torchtune +================================== + +In this tutorial, we'll walk through an end-to-end example of how you can fine-tune, +evaluate, optionally quantize and then run generation with your favorite LLM using +torchtune. We'll also go over how you can use some popular tools and libraries +from the community seemlessly with torchtune. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What this tutorial will cover: + + * Different type of recipes available in torchtune beyond fine-tuning + * End-to-end example connecting all of these recipes + * Different tools and libraries you can use with torchtune + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with the :ref:`overview of torchtune` + * Make sure to :ref:`install torchtune` + * Concepts such as :ref:`configs ` and + :ref:`checkpoints ` + + +Overview +-------- + +Fine-tuning an LLM is usually only one step in a larger workflow. An example workflow that you +might have can look something like this: + +- Download a popular model from `HF Hub `_ +- Fine-tune the model using a relevant fine-tuning technique. The exact technique used + will depend on factors such as the model, amount and nature of training data, your hardware + setup and the end task for which the model will be used +- Evaluate the model on some benchmarks to validate model quality +- Run some generations to make sure the model output looks reasonable +- Quantize the model for efficient inference +- [Optional] Export the model for specific environments such as inference on a mobile phone + +In this tutorial, we'll cover how you can use torchtune for all of the above, leveraging +integrations with popular tools and libraries from the ecosystem. + +We'll use the Llama2 7B model for this tutorial. You can find a complete set of models supported +by torchtune `here `_. + +| + +Download Llama2 7B +------------------ + +In this tutorial, we'll use the Hugging Face model weights for the Llama2 7B mode. +For more information on checkpoint formats and how these are handled in torchtune, take a look at +this tutorial on :ref:`checkpoints `. + +To download the HF format Llama2 7B model, we'll use the tune CLI. + +.. code-block:: bash + + tune download \ + meta-llama/Llama-2-7b-hf \ + --output-dir \ + --hf-token + +Make a note of ````, we'll use this many times in this tutorial. + +| + +Finetune the model using LoRA +----------------------------- + +For this tutorial, we'll fine-tune the model using LoRA. LoRA is a parameter efficient fine-tuning +technique which is especially helpful when you don't have a lot of GPU memory to play with. LoRA +freezes the base LLM and adds a very small percentage of learnable parameters. This helps keep +memory associated with gradients and optimizer state low. Using torchtune, you should be able to +fine-tune a Llama2 7B model with LoRA in less than 16GB of GPU memory using bfloat16 on a +RTX 3090/4090. For more information on how to use LoRA, take a look at our +:ref:`LoRA Tutorial `. + +We'll fine-tune using our +`single device LoRA recipe `_ +and use the standard settings from the +`default config `_. + +This will fine-tune our model using a ``batch_size=2`` and ``dtype=bfloat16``. With these settings the model +should have a peak memory usage of ~16GB and total training time of around two hours for each epoch. +We'll need to make some changes to the config to make sure our recipe can access the +right checkpoints. + +Let's look for the right config for this use case by using the tune CLI. + +.. code-block:: bash + + tune ls + + RECIPE CONFIG + full_finetune_single_device llama2/7B_full_low_memory + mistral/7B_full_low_memory + full_finetune_distributed llama2/7B_full + llama2/13B_full + mistral/7B_full + lora_finetune_single_device llama2/7B_lora_single_device + llama2/7B_qlora_single_device + mistral/7B_lora_single_device + ... + + +For this tutorial we'll use the ``llama2/7B_lora_single_device`` config. + +The config already points to the HF Checkpointer and the right checkpoint files. +All we need to do is update the checkpoint directory for both the model and the +tokenizer. Let's do this using the overrides in the tune CLI while starting training! + + +.. code-block:: bash + + tune run lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ + checkpointer.checkpoint_dir= \ + tokenizer.path=/tokenizer.model \ + checkpointer.output_dir= + + +Once training is complete, you'll see the following in the logs. + +.. code-block:: bash + + [_checkpointer.py:473] Model checkpoint of size 9.98 GB saved to /hf_model_0001_0.pt + + [_checkpointer.py:473] Model checkpoint of size 3.50 GB saved to /hf_model_0002_0.pt + + [_checkpointer.py:484] Adapter checkpoint of size 0.01 GB saved to /adapter_0.pt + + +The final trained weights are merged with the original model and split across two checkpoint files +similar to the source checkpoints from the HF Hub +(see the :ref:`LoRA Tutorial ` for more details). +In fact the keys will be identical between these checkpoints. +We also have a third checkpoint file which is much smaller in size +and contains the learnt LoRA adapter weights. For this tutorial, we'll only use the model +checkpoints and not the adapter weights. + +| + +.. _eval_harness_label: + +Run Evaluation using EleutherAI's Eval Harness +---------------------------------------------- + +We've fine-tuned a model. But how well does this model really do? Let's run some Evaluations! + +.. TODO (SalmanMohammadi) ref eval recipe docs + +torchtune integrates with +`EleutherAI's evaluation harness `_. +An example of this is available through the +``eleuther_eval`` recipe. In this tutorial, we're going to directly use this recipe by +modifying its associated config ``eleuther_evaluation.yaml``. + +.. note:: + For this section of the tutorial, you should first run :code:`pip install lm_eval==0.4.*` + to install the EleutherAI evaluation harness. + +Since we plan to update all of the checkpoint files to point to our fine-tuned checkpoints, +let's first copy over the config to our local working directory so we can make changes. This +will be easier than overriding all of these elements through the CLI. + +.. code-block:: bash + + tune cp eleuther_evaluation ./custom_eval_config.yaml \ + +For this tutorial we'll use the `truthfulqa_mc2 `_ task from the harness. +This task measures a model's propensity to be truthful when answering questions and +measures the model's zero-shot accuracy on a question followed by one or more true +responses and one or more false responses. Let's first run a baseline without fine-tuning. + + +.. code-block:: bash + + tune run eleuther_eval --config ./custom_eval_config.yaml + checkpointer.checkpoint_dir= \ + tokenizer.path=/tokenizer.model + + [evaluator.py:324] Running loglikelihood requests + [eleuther_eval.py:195] Eval completed in 121.27 seconds. + [eleuther_eval.py:197] truthfulqa_mc2: {'acc,none': 0.388... + +The model has an accuracy around 38.8%. Let's compare this with the fine-tuned model. + + +First, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoints. + +.. code-block:: yaml + + checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + + # directory with the checkpoint files + # this should match the output_dir specified during + # finetuning + checkpoint_dir: + + # checkpoint files for the fine-tuned model. This should + # match what's shown in the logs above + checkpoint_files: [ + hf_model_0001_0.pt, + hf_model_0002_0.pt, + ] + + output_dir: + model_type: LLAMA2 + + # Make sure to update the tokenizer path to the right + # checkpoint directory as well + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tokenizer.model + + +Now, let's run the recipe. + +.. code-block:: bash + + tune run eleuther_eval --config ./custom_eval_config.yaml + + +The results should look something like this. + +.. code-block:: bash + + [evaluator.py:324] Running loglikelihood requests + [eleuther_eval.py:195] Eval completed in 121.27 seconds. + [eleuther_eval.py:197] truthfulqa_mc2: {'acc,none': 0.489 ... + +Our fine-tuned model gets ~48% on this task, which is ~10 points +better than the baseline. Great! Seems like our fine-tuning helped. + +| + +Generation +----------- + +We've run some evaluations and the model seems to be doing well. But does it really +generate meaningful text for the prompts you care about? Let's find out! + +For this, we'll use the +`generate recipe `_ +and the associated +`config `_. + + +Let's first copy over the config to our local working directory so we can make changes. + +.. code-block:: bash + + tune cp generation ./custom_generation_config.yaml + +Let's modify ``custom_generation_config.yaml`` to include the following changes. + +.. code-block:: yaml + + checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + + # directory with the checkpoint files + # this should match the output_dir specified during + # finetuning + checkpoint_dir: + + # checkpoint files for the fine-tuned model. This should + # match what's shown in the logs above + checkpoint_files: [ + hf_model_0001_0.pt, + hf_model_0002_0.pt, + ] + + output_dir: + model_type: LLAMA2 + + # Make sure to update the tokenizer path to the right + # checkpoint directory as well + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tokenizer.model + + +Once the config is updated, let's kick off generation! We'll use the +default settings for sampling with ``top_k=300`` and a +``temperature=0.8``. These parameters control how the probabilities for +sampling are computed. These are standard settings for Llama2 7B and +we recommend inspecting the model with these before playing around with +these parameters. + +We'll use a different prompt from the one in the config + +.. code-block:: bash + + tune run generate --config ./custom_generation_config.yaml \ + prompt="What are some interesting sites to visit in the Bay Area?" + + +Once generation is complete, you'll see the following in the logs. + + +.. code-block:: bash + + [generate.py:92] Exploratorium in San Francisco has made the cover of Time Magazine, + and its awesome. And the bridge is pretty cool... + + [generate.py:96] Time for inference: 11.61 sec total, 25.83 tokens/sec + [generate.py:99] Memory used: 15.72 GB + + +Indeed, the bridge is pretty cool! Seems like our LLM knows a little something about the +Bay Area! + +| + +Speeding up Generation using Quantization +----------------------------------------- + +We rely on `torchao `_ for `post-training quantization `_. +To quantize the fine-tuned model after installing torchao we can run the following command:: + + # we also support `int8_weight_only()` and `int8_dynamic_activation_int8_weight()`, see + # https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques + # for a full list of techniques that we support + from torchao.quantization.quant_api import quantize\_, int4_weight_only + quantize\_(model, int4_weight_only()) + +After quantization, we rely on torch.compile for speedups. For more details, please see `this example usage `_. + +torchao also provides `this table `_ listing performance and accuracy results for ``llama2`` and ``llama3``. + +For Llama models, you can run generation directly in torchao on the quantized model using their ``generate.py`` script as +discussed in `this readme `_. This way you can compare your own results +to those in the previously-linked table. + +| + +Using torchtune checkpoints with other libraries +------------------------------------------------ + +As we mentioned above, one of the benefits of handling of the checkpoint +conversion is that you can directly work with standard formats. This helps +with interoperability with other libraries since torchtune doesn't add yet +another format to the mix. + +Let's take a look at an example of how this would work with a popular codebase +used for running performant inference with LLMs - +`gpt-fast `_. This section +assumes that you've cloned that repository on your machine. + +``gpt-fast`` makes some assumptions about the checkpoint and the availability of +the key-to-file mapping i.e. a file mapping parameter names to the files containing them. +Let's satisfy these assumptions, by creating this mapping +file. Let's assume we'll be using ``/Llama-2-7B-hf`` as the directory +for this. ``gpt-fast`` assumes that the directory with checkpoints has the +same format at the HF repo-id. + +.. code-block:: python + + import json + import torch + + # create the output dictionary + output_dict = {"weight_map": {}} + + # Load the checkpoints + sd_1 = torch.load('/hf_model_0001_0.pt', mmap=True, map_location='cpu') + sd_2 = torch.load('/hf_model_0002_0.pt', mmap=True, map_location='cpu') + + # create the weight map + for key in sd_1.keys(): + output_dict['weight_map'][key] = "hf_model_0001_0.pt" + for key in sd_2.keys(): + output_dict['weight_map'][key] = "hf_model_0002_0.pt" + + with open('/Llama-2-7B-hf/pytorch_model.bin.index.json', 'w') as f: + json.dump(output_dict, f) + + +Now that we've created the weight_map, let's copy over our checkpoints. + +.. code-block:: bash + + cp /hf_model_0001_0.pt /Llama-2-7B-hf/ + cp /hf_model_0002_0.pt /Llama-2-7B-hf/ + cp /tokenizer.model /Llama-2-7B-hf/ + +Once the directory structure is setup, let's convert the checkpoints and run inference! + +.. code-block:: bash + + cd gpt-fast/ + + # convert the checkpoints into a format readable by gpt-fast + python scripts/convert_hf_checkpoint.py \ + --checkpoint_dir /Llama-2-7B-hf/ \ + --model 7B + + # run inference using the converted model + python generate.py \ + --compile \ + --checkpoint_path /Llama-2-7B-hf/model.pth \ + --device cuda + +The output should look something like this: + +.. code-block:: bash + + Hello, my name is Justin. I am a middle school math teacher + at WS Middle School ... + + Time for inference 5: 1.94 sec total, 103.28 tokens/sec + Bandwidth achieved: 1391.84 GB/sec + + +And thats it! Try your own prompt! + +Uploading your model to the Hugging Face Hub +-------------------------------------------- + +Your new model is working great and you want to share it with the world. The easiest way to do this +is utilizing the `huggingface-cli `_ command, which works seamlessly with torchtune. Simply point the CLI +to your finetuned model directory like so: + +.. code-block:: bash + + huggingface-cli upload + +The command should output a link to your repository on the Hub. If the repository doesn't exist yet, it will +be created automatically: + +.. code-block:: text + + https://huggingface.co//tree/main/. + +.. note:: + + Before uploading, make sure you are `authenticated with Hugging Face `_ by running ``huggingface-cli login``. + +For more details on the ``huggingface-cli upload`` feature check out the `Hugging Face docs `_. + +| + +Hopefully this tutorial gave you some insights into how you can use torchtune for +your own workflows. Happy Tuning! diff --git a/training/docs/source/tutorials/first_finetune_tutorial.rst b/training/docs/source/tutorials/first_finetune_tutorial.rst new file mode 100644 index 0000000..e476dbf --- /dev/null +++ b/training/docs/source/tutorials/first_finetune_tutorial.rst @@ -0,0 +1,169 @@ +.. _finetune_llama_label: + +======================== +Fine-Tune Your First LLM +======================== + +This guide will walk you through the process of launching your first finetuning +job using torchtune. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + + * How to download a model from the `Hugging Face Hub `_ + * How to modify a recipe's parameters to suit your needs + * How to run a finetune + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with the :ref:`overview of torchtune` + * Make sure to :ref:`install torchtune` + +.. _download_llama_label: + +Downloading a model +------------------- +The first step in any finetuning job is to download a pretrained base model. torchtune supports an integration +with the `Hugging Face Hub `_ - a collection of the latest and greatest model weights. + +For this tutorial, you're going to use the `Llama2 7B model from Meta `_. Llama2 is a "gated model", +meaning that you need to be granted access in order to download the weights. Follow `these instructions `_ on the official Meta page +hosted on Hugging Face to complete this process. This should take less than 5 minutes. To verify that you have the access, go to the `model page `_. +You should be able to see the model files. If not, you may need to accept the agreement to complete the process. + +.. note:: + + Alternatively, you can opt to download the model directly through the Llama2 repository. + See `this page `_ for more details. + +Once you have authorization, you will need to authenticate with Hugging Face Hub. The easiest way to do so is to provide an +access token to the download script. You can find your token `here `_. + +Then, it's as simple as: + +.. code-block:: bash + + tune download meta-llama/Llama-2-7b-hf \ + --output-dir /tmp/Llama-2-7b-hf \ + --hf-token + +This command will also download the model tokenizer and some other helpful files such as a Responsible Use guide. + +| + +Selecting a recipe +------------------ +Recipes are the primary entry points for torchtune users. +These can be thought of as **hackable, singularly-focused scripts for interacting with LLMs** including training, +inference, evaluation, and quantization. + +Each recipe consists of three components: + +* **Configurable parameters**, specified through yaml configs and command-line overrides +* **Recipe script**, entry-point which puts everything together including parsing and validating configs, setting up the environment, and correctly using the recipe class +* **Recipe class**, core logic needed for training, exposed through a set of APIs + +.. note:: + + To learn more about the concept of "recipes", check out our technical deep-dive: :ref:`recipe_deepdive`. + +torchtune provides built-in recipes for finetuning on single device, on multiple devices with `FSDP `_, +using memory efficient techniques like `LoRA `_, and more! You can view all built-in recipes `on GitHub `_. You can also utilize the +:ref:`tune ls ` command to print out all recipes and corresponding configs. + +.. TODO (SalmanMohammadi) point to recipe index page here. + +.. code-block:: bash + + $ tune ls + RECIPE CONFIG + full_finetune_single_device llama2/7B_full_low_memory + mistral/7B_full_low_memory + full_finetune_distributed llama2/7B_full + llama2/13B_full + mistral/7B_full + lora_finetune_single_device llama2/7B_lora_single_device + llama2/7B_qlora_single_device + mistral/7B_lora_single_device + ... + +For the purposes of this tutorial, you'll will be using the recipe for finetuning a Llama2 model using `LoRA `_ on +a single device. For a more in-depth discussion on LoRA in torchtune, you can see the complete :ref:`lora_finetune_label` tutorial. + +.. note:: + + **Why have a separate recipe for single device vs. distributed?** This is discussed in + ":ref:`recipe_deepdive`" but one of our :ref:`core principles ` in torchtune is minimal abstraction and boilerplate code. + If you only want to train on a single GPU, our single-device recipe ensures you don't have to worry about additional + features like FSDP that are only required for distributed training. + +| + +.. _tune_cp_label: + +Modifying a config +------------------ +YAML configs hold most of the important information needed for running your recipe. +You can set hyperparameters, specify metric loggers like `WandB `_, select a new dataset, and more. +For a list of all currently supported datasets, see :ref:`datasets`. + +There are two ways to modify an existing config: + +**Override existing parameters from the command line** + +You can override existing parameters from the command line using a :code:`key=value` format. Let's say +you want to set the number of training epochs to 1. + +.. code-block:: bash + + tune run --config epochs=1 + +**Copy the config through `tune cp` and modify directly** + +If you want to make more substantial changes to the config, you can use the :ref:`tune ` CLI to copy it to your local directory. + +.. code-block:: bash + + $ tune cp llama2/7B_lora_single_device custom_config.yaml + Copied file to custom_config.yaml + +Now you can update the custom YAML config any way you like. Try setting the random seed in order to make replication easier, +changing the LoRA rank, update batch size, etc. + +.. note:: + + Check out :ref:`config_tutorial_label` for a deeper dive on configs in torchtune. + +| + +Training a model +---------------- +Now that you have a model in the proper format and a config that suits your needs, let's get training! + +Just like all the other steps, you will be using the :ref:`tune ` CLI tool to launch your finetuning run. + +.. code-block:: bash + + $ tune run lora_finetune_single_device --config llama2/7B_lora_single_device epochs=1 + INFO:torchtune.utils.logging:Running LoRAFinetuneRecipeSingleDevice with resolved config: + Writing logs to /tmp/lora_finetune_output/log_1713194212.txt + INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16. + INFO:torchtune.utils.logging:Tokenizer is initialized from file. + INFO:torchtune.utils.logging:Optimizer and loss are initialized. + INFO:torchtune.utils.logging:Loss is initialized. + INFO:torchtune.utils.logging:Dataset and Sampler are initialized. + INFO:torchtune.utils.logging:Learning rate scheduler is initialized. + 1|52|Loss: 2.3697006702423096: 0%|▏ | 52/25880 [00:24<3:55:01, 1.83it/s] + +You can see that all the modules were successfully initialized and the model has started training. +You can monitor the loss and progress through the `tqdm `_ bar but torchtune +will also log some more metrics, such as GPU memory usage, at an interval defined in the config. + +| + +Next steps +---------- + +Now that you have trained your model and set up your environment, let's take a look at what we can do with our +new model by checking out the :ref:`E2E Workflow Tutorial`. diff --git a/training/docs/source/tutorials/llama3.rst b/training/docs/source/tutorials/llama3.rst new file mode 100644 index 0000000..38f28fa --- /dev/null +++ b/training/docs/source/tutorials/llama3.rst @@ -0,0 +1,263 @@ +.. _llama3_label: + +======================== +Meta Llama3 in torchtune +======================== + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` You will learn how to: + + * Download the Llama3-8B-Instruct weights and tokenizer + * Fine-tune Llama3-8B-Instruct with LoRA and QLoRA + * Evaluate your fine-tuned Llama3-8B-Instruct model + * Generate text with your fine-tuned model + * Quantize your model to speed up generation + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with :ref:`torchtune` + * Make sure to :ref:`install torchtune` + + +Llama3-8B +--------- + +`Meta Llama 3 `_ is a new family of models released by Meta AI that improves upon the performance of the Llama2 family +of models across a `range of different benchmarks `_. +Currently there are two different sizes of Meta Llama 3: 8B and 70B. In this tutorial we will focus on the 8B size model. +There are a few main changes between Llama2-7B and Llama3-8B models: + +- Llama3-8B uses `grouped-query attention `_ instead of the standard multi-head attention from Llama2-7B +- Llama3-8B has a larger vocab size (128,256 instead of 32,000 from Llama2 models) +- Llama3-8B uses a different tokenizer than Llama2 models (`tiktoken `_ instead of `sentencepiece `_) +- Llama3-8B uses a larger intermediate dimension in its MLP layers than Llama2-7B +- Llama3-8B uses a higher base value to calculate theta in its `rotary positional embeddings `_ + +| + +Getting access to Llama3-8B-Instruct +------------------------------------ + +For this tutorial, we will be using the instruction-tuned version of Llama3-8B. First, let's download the model from Hugging Face. You will need to follow the instructions +on the `official Meta page `_ to gain access to the model. +Next, make sure you grab your Hugging Face token from `here `_. + + +.. code-block:: bash + + tune download meta-llama/Meta-Llama-3-8B-Instruct \ + --output-dir \ + --hf-token + +| + +Fine-tuning Llama3-8B-Instruct in torchtune +------------------------------------------- + +torchtune provides `LoRA `_, `QLoRA `_, and full fine-tuning +recipes for fine-tuning Llama3-8B on one or more GPUs. For more on LoRA in torchtune, see our :ref:`LoRA Tutorial `. +For more on QLoRA in torchtune, see our :ref:`QLoRA Tutorial `. + +Let's take a look at how we can fine-tune Llama3-8B-Instruct with LoRA on a single device using torchtune. In this example, we will fine-tune +for one epoch on a common instruct dataset for illustrative purposes. The basic command for a single-device LoRA fine-tune is + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama3/8B_lora_single_device + +.. note:: + To see a full list of recipes and their corresponding configs, simply run ``tune ls`` from the command line. + +We can also add :ref:`command-line overrides ` as needed, e.g. + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama3/8B_lora_single_device \ + checkpointer.checkpoint_dir= \ + tokenizer.path=/tokenizer.model \ + checkpointer.output_dir= + +This will load the Llama3-8B-Instruct checkpoint and tokenizer from ```` used in the :ref:`tune download ` command above, +then save a final checkpoint in the same directory following the original format. For more details on the +checkpoint formats supported in torchtune, see our :ref:`checkpointing deep-dive `. + +.. note:: + To see the full set of configurable parameters for this (and other) configs we can use :ref:`tune cp ` to copy (and modify) + the default config. :ref:`tune cp ` can be used with recipe scripts too, in case you want to make more custom changes + that cannot be achieved by directly modifying existing configurable parameters. For more on :ref:`tune cp ` see the section on + :ref:`modifying configs ` in our ":ref:`finetune_llama_label`" tutorial. + +Once training is complete, the model checkpoints will be saved and their locations will be logged. For +LoRA fine-tuning, the final checkpoint will contain the merged weights, and a copy of just the (much smaller) LoRA weights +will be saved separately. + +In our experiments, we observed a peak memory usage of 18.5 GB. The default config can be trained on a consumer GPU with 24 GB VRAM. + +If you have multiple GPUs available, you can run the distributed version of the recipe. +torchtune makes use of the `FSDP `_ APIs from PyTorch Distributed +to shard the model, optimizer states, and gradients. This should enable you to increase your batch size, resulting in faster overall training. +For example, on two devices: + +.. code-block:: bash + + tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora + +Finally, if we want to use even less memory, we can leverage torchtune's QLoRA recipe via: + +.. TODO (SalmanMohammadi) ref qlora recipe page + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama3/8B_qlora_single_device + +Since our default configs enable full bfloat16 training, all of the above commands can be run with +devices having at least 24 GB of VRAM, and in fact the QLoRA recipe should have peak allocated memory +below 10 GB. You can also experiment with different configurations of LoRA and QLoRA, or even run a full fine-tune. +Try it out! + +| + +Evaluating fine-tuned Llama3-8B models with EleutherAI's Eval Harness +--------------------------------------------------------------------- + +Now that we've fine-tuned our model, what's next? Let's take our LoRA-finetuned model from the +preceding section and look at a couple different ways we can evaluate its performance on the tasks we care about. + +First, torchtune provides an integration with +`EleutherAI's evaluation harness `_ +for model evaluation on common benchmark tasks. + +.. note:: + Make sure you've first installed the evaluation harness via :code:`pip install "lm_eval==0.4.*"`. + +For this tutorial we'll use the `truthfulqa_mc2 `_ task from the harness. +This task measures a model's propensity to be truthful when answering questions and +measures the model's zero-shot accuracy on a question followed by one or more true +responses and one or more false responses. First, let's copy the config so we can point the YAML +file to our fine-tuned checkpoint files. + +.. code-block:: bash + + tune cp eleuther_evaluation ./custom_eval_config.yaml + +Next, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoints. + +.. code-block:: yaml + + model: + _component_: torchtune.models.llama3.llama3_8b + + checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + + # directory with the checkpoint files + # this should match the output_dir specified during + # fine-tuning + checkpoint_dir: + + # checkpoint files for the fine-tuned model. These will be logged + # at the end of your fine-tune + checkpoint_files: [ + meta_model_0.pt + ] + + output_dir: + model_type: LLAMA3 + + # Make sure to update the tokenizer path to the right + # checkpoint directory as well + tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tokenizer.model + +Finally, we can run evaluation using our modified config. + +.. code-block:: bash + + tune run eleuther_eval --config ./custom_eval_config.yaml + +Try it for yourself and see what accuracy your model gets! + +| + +Generating text with our fine-tuned Llama3 model +------------------------------------------------ + +.. TODO (SalmanMohammadi) ref generate recipe page + +Next, let's look at one other way we can evaluate our model: generating text! torchtune provides a +`recipe for generation `_ as well. + +Similar to what we did, let's copy and modify the default generation config. + +.. code-block:: bash + + tune cp generation ./custom_generation_config.yaml + +Now we modify ``custom_generation_config.yaml`` to point to our checkpoint and tokenizer. + +.. code-block:: yaml + + model: + _component_: torchtune.models.llama3.llama3_8b + + checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + + # directory with the checkpoint files + # this should match the output_dir specified during + # fine-tuning + checkpoint_dir: + + # checkpoint files for the fine-tuned model. These will be logged + # at the end of your fine-tune + checkpoint_files: [ + meta_model_0.pt + ] + + output_dir: + model_type: LLAMA3 + + # Make sure to update the tokenizer path to the right + # checkpoint directory as well + tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tokenizer.model + +Running generation with our LoRA-finetuned model, we see the following output: + +.. code-block:: bash + + tune run generate --config ./custom_generation_config.yaml \ + prompt="Hello, my name is" + + [generate.py:122] Hello, my name is Sarah and I am a busy working mum of two young children, living in the North East of England. + ... + [generate.py:135] Time for inference: 10.88 sec total, 18.94 tokens/sec + [generate.py:138] Bandwidth achieved: 346.09 GB/s + [generate.py:139] Memory used: 18.31 GB + +Faster generation via quantization +---------------------------------- + +We rely on `torchao `_ for `post-training quantization `_. +To quantize the fine-tuned model after installing torchao we can run the following command:: + + # we also support `int8_weight_only()` and `int8_dynamic_activation_int8_weight()`, see + # https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques + # for a full list of techniques that we support + from torchao.quantization.quant_api import quantize\_, int4_weight_only + quantize\_(model, int4_weight_only()) + +After quantization, we rely on torch.compile for speedups. For more details, please see `this example usage `_. + +torchao also provides `this table `_ listing performance and accuracy results for ``llama2`` and ``llama3``. + +For Llama models, you can run generation directly in torchao on the quantized model using their ``generate.py`` script as +discussed in `this readme `_. This way you can compare your own results +to those in the previously-linked table. + + +This is just the beginning of what you can do with Meta Llama3 using torchtune and the broader ecosystem. +We look forward to seeing what you build! diff --git a/training/docs/source/tutorials/lora_finetune.rst b/training/docs/source/tutorials/lora_finetune.rst new file mode 100644 index 0000000..4a18995 --- /dev/null +++ b/training/docs/source/tutorials/lora_finetune.rst @@ -0,0 +1,381 @@ +.. _lora_finetune_label: + +=========================== +Finetuning Llama2 with LoRA +=========================== + +This guide will teach you about `LoRA `_, a parameter-efficient finetuning technique, +and show you how you can use torchtune to finetune a Llama2 model with LoRA. +If you already know what LoRA is and want to get straight to running +your own LoRA finetune in torchtune, you can jump to :ref:`LoRA finetuning recipe in torchtune`. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + + * What LoRA is and how it saves memory during finetuning + * An overview of LoRA components in torchtune + * How to run a LoRA finetune using torchtune + * How to experiment with different LoRA configurations + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with :ref:`torchtune` + * Make sure to :ref:`install torchtune` + * Make sure you have downloaded the :ref:`Llama2-7B model weights` + +What is LoRA? +------------- + +`LoRA `_ is an adapter-based method for +parameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network, +then freezes the network's remaining parameters. LoRA is most commonly applied to +transformer models, in which case it is common to add the low-rank matrices +to some of the linear projections in each transformer layer's self-attention. + +.. note:: + + If you're unfamiliar, check out these references for the `definition of rank `_ + and discussion of `low-rank approximations `_. + +By finetuning with LoRA (as opposed to finetuning all model parameters), +you can expect to see memory savings due to a substantial reduction in the +number of parameters with gradients. When using an optimizer with momentum, +like `AdamW `_, +you can expect to see further memory savings from the optimizer state. + +.. note:: + + LoRA memory savings come primarily from gradient and optimizer states, + so if your model's peak memory comes in its :code:`forward()` method, then LoRA + may not reduce peak memory. + +How does LoRA work? +------------------- + +LoRA replaces weight update matrices with a low-rank approximation. In general, weight updates +for an arbitrary :code:`nn.Linear(in_dim,out_dim)` layer could have rank as high as +:code:`min(in_dim,out_dim)`. LoRA (and other related papers such as `Aghajanyan et al. `_) +hypothesize that the `intrinsic dimension `_ +of these updates during LLM fine-tuning can in fact be much lower. +To take advantage of this property, LoRA finetuning will freeze the original model, +then add a trainable weight update from a low-rank projection. More explicitly, LoRA trains two +matrices :code:`A` and :code:`B`. :code:`A` projects the inputs down to a much smaller rank (often four or eight in practice), and +:code:`B` projects back up to the dimension output by the original linear layer. + +The image below gives a simplified representation of a single weight update step from a full finetune +(on the left) compared to a weight update step with LoRA (on the right). The LoRA matrices :code:`A` and :code:`B` +serve as an approximation to the full rank weight update in blue. + +.. image:: /_static/img/lora_diagram.png + +Although LoRA introduces a few extra parameters in the model :code:`forward()`, only the :code:`A` and :code:`B` matrices are trainable. +This means that with a rank :code:`r` LoRA decomposition, the number of gradients we need to store reduces +from :code:`in_dim*out_dim` to :code:`r*(in_dim+out_dim)`. (Remember that in general :code:`r` +is much smaller than :code:`in_dim` and :code:`out_dim`.) + +For example, in the 7B Llama2's self-attention, :code:`in_dim=out_dim=4096` for the Q, K, +and V projections. This means a LoRA decomposition of rank :code:`r=8` will reduce the number of trainable +parameters for a given projection from :math:`4096 * 4096 \approx 15M` to :math:`8 * 8192 \approx 65K`, a +reduction of over 99%. + +Let's take a look at a minimal implementation of LoRA in native PyTorch. + + +.. code-block:: python + + from torch import nn, Tensor + + class LoRALinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float + ): + # These are the weights from the original pretrained model + self.linear = nn.Linear(in_dim, out_dim, bias=False) + + # These are the new LoRA params. In general rank << in_dim, out_dim + self.lora_a = nn.Linear(in_dim, rank, bias=False) + self.lora_b = nn.Linear(rank, out_dim, bias=False) + + # Rank and alpha are commonly-tuned hyperparameters + self.rank = rank + self.alpha = alpha + + # Most implementations also include some dropout + self.dropout = nn.Dropout(p=dropout) + + # The original params are frozen, and only LoRA params are trainable. + self.linear.weight.requires_grad = False + self.lora_a.weight.requires_grad = True + self.lora_b.weight.requires_grad = True + + def forward(self, x: Tensor) -> Tensor: + # This would be the output of the original model + frozen_out = self.linear(x) + + # lora_a projects inputs down to the much smaller self.rank, + # then lora_b projects back up to the output dimension + lora_out = self.lora_b(self.lora_a(self.dropout(x))) + + # Finally, scale by the alpha parameter (normalized by rank) + # and add to the original model's outputs + return frozen_out + (self.alpha / self.rank) * lora_out + +There are some other details around initialization which we omit here, but if you'd like to know more +you can see our implementation in :class:`~torchtune.modules.peft.LoRALinear`. +Now that we understand what LoRA is doing, let's look at how we can apply it to our favorite models. + +Applying LoRA to Llama2 models +------------------------------ + +With torchtune, we can easily apply LoRA to Llama2 with a variety of different configurations. +Let's take a look at how to construct Llama2 models in torchtune with and without LoRA. + +.. code-block:: python + + from torchtune.models.llama2 import llama2_7b, lora_llama2_7b + + # Build Llama2 without any LoRA layers + base_model = llama2_7b() + + # The default settings for lora_llama2_7b will match those for llama2_7b + # We just need to define which layers we want LoRA applied to. + # Within each self-attention, we can choose from ["q_proj", "k_proj", "v_proj", and "output_proj"]. + # We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear + # layers outside of the self-attention. + lora_model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"]) + +.. note:: + + Calling :func:`lora_llama_2_7b ` alone will not handle the definition of which parameters are trainable. + See :ref:`below` for how to do this. + +Let's inspect each of these models a bit more closely. + +.. code-block:: bash + + # Print the first layer's self-attention in the usual Llama2 model + >>> print(base_model.layers[0].attn) + CausalSelfAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=4096, bias=False) + (v_proj): Linear(in_features=4096, out_features=4096, bias=False) + (output_proj): Linear(in_features=4096, out_features=4096, bias=False) + (pos_embeddings): RotaryPositionalEmbeddings() + ) + + # Print the same for Llama2 with LoRA weights + >>> print(lora_model.layers[0].attn) + CausalSelfAttention( + (q_proj): LoRALinear( + (dropout): Dropout(p=0.0, inplace=False) + (lora_a): Linear(in_features=4096, out_features=8, bias=False) + (lora_b): Linear(in_features=8, out_features=4096, bias=False) + ) + (k_proj): Linear(in_features=4096, out_features=4096, bias=False) + (v_proj): LoRALinear( + (dropout): Dropout(p=0.0, inplace=False) + (lora_a): Linear(in_features=4096, out_features=8, bias=False) + (lora_b): Linear(in_features=8, out_features=4096, bias=False) + ) + (output_proj): Linear(in_features=4096, out_features=4096, bias=False) + (pos_embeddings): RotaryPositionalEmbeddings() + ) + + +Notice that our LoRA model's layer contains additional weights in the Q and V projections, +as expected. Additionally, inspecting the type of :code:`lora_model` and +:code:`base_model`, would show that they are both instances of the same :class:`~torchtune.modules.TransformerDecoder`. +(Feel free to verify this for yourself.) + +Why does this matter? torchtune makes it easy to load checkpoints for LoRA directly from our Llama2 +model without any wrappers or custom checkpoint conversion logic. + +.. code-block:: python + + # Assuming that base_model already has the pretrained Llama2 weights, + # this will directly load them into your LoRA model without any conversion necessary. + lora_model.load_state_dict(base_model.state_dict(), strict=False) + +.. note:: + Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in + the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via e.g. + :func:`validate_state_dict_for_lora() ` or + :func:`validate_missing_and_unexpected_for_lora() `. + +Once we've loaded the base model weights, we also want to set only LoRA parameters to trainable. + +.. _setting_trainable_params: + +.. code-block:: python + + from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params + + # Fetch all params from the model that are associated with LoRA. + lora_params = get_adapter_params(lora_model) + + # Set requires_grad=True on lora_params, and requires_grad=False on all others. + set_trainable_params(lora_model, lora_params) + + # Print the total number of parameters + total_params = sum([p.numel() for p in lora_model.parameters()]) + trainable_params = sum([p.numel() for p in lora_model.parameters() if p.requires_grad]) + print( + f""" + {total_params} total params, + {trainable_params}" trainable params, + {(100.0 * trainable_params / total_params):.2f}% of all params are trainable. + """ + ) + + 6742609920 total params, + 4194304 trainable params, + 0.06% of all params are trainable. + +.. note:: + If you are directly using the LoRA recipe (as detailed :ref:`here`), you need only pass the + relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care + of in the recipe. + + +.. _lora_recipe_label: + +LoRA finetuning recipe in torchtune +----------------------------------- + +Finally, we can put it all together and finetune a model using torchtune's `LoRA recipe `_. +Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions`. +You can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB): + +.. code-block:: bash + + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora + +.. note:: + Make sure to point to the location of your Llama2 weights and tokenizer. This can be done + either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path` + or by directly modifying the :code:`7B_lora.yaml` file. See our "":ref:`config_tutorial_label`" recipe + for more details on how you can easily clone and modify torchtune configs. + +.. note:: + You can modify the value of :code:`nproc_per_node` depending on (a) the number of GPUs you have available, + and (b) the memory constraints of your hardware. + +The preceding command will run a LoRA finetune with torchtune's factory settings, but we may want to experiment a bit. +Let's take a closer look at some of the :code:`lora_finetune_distributed` config. + +.. code-block:: yaml + + # Model Arguments + model: + _component_: lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + lora_rank: 8 + lora_alpha: 16 + ... + +We see that the default is to apply LoRA to Q and V projections with a rank of 8. +Some experiments with LoRA have found that it can be beneficial to apply LoRA to all linear layers in +the self-attention, and to increase the rank to 16 or 32. Note that this is likely to increase our max memory, +but as long as we keep :code:`rank<`_ +floating-point format. This can be done via the command: + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama2/7B_lora_single_device + +On a single device, we may need to be more cognizant of our peak memory. Let's run a few experiments +to see our peak memory during a finetune. We will experiment along two axes: +first, which model layers have LoRA applied, and second, the rank of each LoRA layer. (We will scale +alpha in parallel to LoRA rank, as discussed above.) + +To compare the results of our experiments, we can evaluate our models on `truthfulqa_mc2 `_, a task from +the `TruthfulQA `_ benchmark for language models. For more details on how to run this and other evaluation tasks +with torchtune's EleutherAI evaluation harness integration, see our :ref:`End-to-End Workflow Tutorial `. + +Previously, we only enabled LoRA for the linear layers in each self-attention module, but in fact there are other linear +layers we can apply LoRA to: MLP layers and our model's final output projection. Note that for Llama-2-7B the final output +projection maps to the vocabulary dimension (32000 instead of 4096 as in the other linear layers), so enabling LoRA for this layer will increase +our peak memory a bit more than the other layers. We can make the following changes to our config: + +.. code-block:: yaml + + # Model Arguments + model: + _component_: lora_llama2_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: True + ... + +.. note:: + All the finetuning runs below use the `llama2/7B_lora_single_device `_ + config, which has a default batch size of 2. Modifying the batch size (or other hyperparameters, e.g. the optimizer) will impact both peak memory + and final evaluation results. + +.. list-table:: + :widths: 25 25 25 25 25 + :header-rows: 1 + + * - LoRA Layers + - Rank + - Alpha + - Peak Memory + - Accuracy (truthfulqa_mc2) + * - Q and V only + - 8 + - 16 + - **15.57 GB** + - 0.475 + * - all layers + - 8 + - 16 + - 15.87 GB + - 0.508 + * - Q and V only + - 64 + - 128 + - 15.86 GB + - 0.504 + * - all layers + - 64 + - 128 + - 17.04 GB + - **0.514** + +We can see that our baseline settings give the lowest peak memory, but our evaluation performance is relatively lower. +By enabling LoRA for all linear layers and increasing the rank to 64, we see almost a 4% absolute improvement +in our accuracy on this task, but our peak memory also increases by about 1.4GB. These are just a couple simple +experiments; we encourage you to run your own finetunes to find the right tradeoff for your particular setup. + +Additionally, if you want to decrease your model's peak memory even further (and still potentially achieve similar +model quality results), you can check out our :ref:`QLoRA tutorial`. diff --git a/training/docs/source/tutorials/qat_finetune.rst b/training/docs/source/tutorials/qat_finetune.rst new file mode 100644 index 0000000..e465763 --- /dev/null +++ b/training/docs/source/tutorials/qat_finetune.rst @@ -0,0 +1,390 @@ +.. _qat_finetune_label: + +=========================== +Finetuning Llama3 with QAT +=========================== + +Quantization-Aware Training (QAT) is a common technique for users to quantize their +models without incurring significant degradations in accuracy or perplexity. In this +tutorial, we’ll walk through how to apply QAT during fine-tuning, quantize the +resulting model, and evaluate your quantized model using torchtune. + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + + * What QAT is and how it helps reduce quantization degradation + * How to run QAT during fine-tuning in torchtune + * End-to-end example of connecting QAT, quantization, and evaluation recipes + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with :ref:`torchtune` + * Make sure to :ref:`install torchtune` + * Make sure you have downloaded the :ref:`Llama3-8B model weights` + +What is QAT? +------------ + +`Quantization-Aware Training `_ (QAT) refers to simulating quantization numerics during +training or fine-tuning, with the end goal of ultimately producing a higher quality +quantized model compared to simple post-training quantization (PTQ). During QAT, +the weights and/or activations are “fake quantized”, meaning they are transformed +as if they were being quantized, but kept in the original data type (e.g. bfloat16) +without being actually cast to lower bit-widths. Thus, fake quantization allows the +model to adjust for quantization noise when updating the weights, hence the training +process is “aware” that the model will ultimately be quantized after training. + +.. code-block:: python + + # PTQ: x_q is quantized and cast to int8 + # scale and zero point (zp) refer to parameters used to quantize x_float + # qmin and qmax refer to the range of quantized values + x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8) + + # QAT: x_fq is still in float + # Fake quantize simulates the numerics of quantize + dequantize + x_fq = (x_float / scale + zp).round().clamp(qmin, qmax) + x_fq = (x_fq - zp) * scale + +QAT typically involves applying a transformation to your model before and after training. +For example, in the `torchao QAT implementation `_, +these are represented as the prepare and convert steps: (1) prepare inserts fake quantize +operations into linear layers, and (2) convert transforms the fake quantize operations +to actual quantize and dequantize operations after training, thereby producing a quantized +model (dequantize operations are typically fused with linear after lowering). +Between these two steps, training can proceed exactly as before. + +.. image:: /_static/img/qat_diagram.png + + +Applying QAT to Llama3 models +----------------------------- + +We can easily apply the above QAT transformations to Llama3 in torchtune for fine-tuning: + +.. code-block:: python + + from torchtune.utils.quantization import Int8DynActInt4WeightQATQuantizer + from torchtune.models.llama3 import llama3_8b + + model = llama3_8b() + + # Quantizer for int8 dynamic per token activations + + # int4 grouped per channel weights, only for linear layers + quantizer = Int8DynActInt4WeightQATQuantizer() + + # Insert "fake quantize" operations into linear layers. + # These operations simulate quantization numerics during + # fine-tuning without performing any dtype casting + prepared_model = quantizer.prepare(model) + +If we print the model we’ll see that all linear layers have been swapped with +:code:`Int8DynActInt4WeightQATLinear`, which simulates the numerics of int8 +dynamic per token activations + int4 grouped per channel weights. Now the model +is ready for fine-tuning. + +.. code-block:: bash + + >>> print(model.layers[0].attn) + CausalSelfAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=1024, bias=False) + (v_proj): Linear(in_features=4096, out_features=1024, bias=False) + (output_proj): Linear(in_features=4096, out_features=4096, bias=False) + (pos_embeddings): RotaryPositionalEmbeddings() + ) + + >>> print(prepared_model.layers[0].attn) + CausalSelfAttention( + (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False) + (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False) + (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False) + (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False) + (pos_embeddings): RotaryPositionalEmbeddings() + ) + +After fine-tuning, we can convert the model to get an actual quantized model. +If we print the converted model, we’ll see that the QAT linears have been +swapped with `Int8DynActInt4WeightLinear `_, which are the quantized versions +of the linear layers. This quantized model can then be saved to checkpoint and +used for inference or generation. + +.. code-block:: python + + # Fine-tune as before + train_loop(prepared_model) + + # Convert fake quantize to actual quantize operations + converted_model = quantizer.convert(prepared_model) + +.. code-block:: bash + + >>> print(converted_model.layers[0].attn) + CausalSelfAttention( + (q_proj): Int8DynActInt4WeightLinear() + (k_proj): Int8DynActInt4WeightLinear() + (v_proj): Int8DynActInt4WeightLinear() + (output_proj): Int8DynActInt4WeightLinear() + (pos_embeddings): RotaryPositionalEmbeddings() + ) + + +QAT finetuning recipe in torchtune +---------------------------------- + +Putting it all together, we can now fine-tune a model using torchtune’s `QAT recipe `_. +Make sure that you have first downloaded the Llama3 weights and tokenizer by +following :ref:`these instructions`. In this tutorial, +we use the following settings to demonstrate QAT’s effectiveness in recovering +quantization degradation compared to directly quantizing a model fine-tuned +without QAT. You can copy the default QAT config and make the following +modifications accordingly: + +.. code-block:: bash + + tune cp llama3/8B_qat_full custom_8B_qat_full.yaml + +.. code-block:: yaml + + # Dataset + dataset: + _component_: torchtune.datasets.text_completion_dataset + source: allenai/c4 + max_seq_len: 8192 + column: text + name: en + split: train + seed: null + shuffle: True + + ... + + epochs: 1 + max_steps_per_epoch: 2000 + fake_quant_after_n_steps: 1000 + memory_efficient_fsdp_wrap: False + +.. note:: + + QAT in torchtune is currently not compatible with `memory_efficient_fsdp_wrap `_. + This is a known issue and will be fixed in a future torchtune version. + +Empirically, we observed that disabling fake quantization for the first N steps +led to better results, presumably because doing so allows the weights to stabilize +before we start introducing quantization noise to the fine-tuning process. +For this reason, here we disable fake quantization for the first 1000 steps. + +You can then use the following command to run fine-tuning with QAT using the above +config. This workload requires at least 6 GPUs, each with VRAM of at least 80GB. +By default, this uses the int8 dynamic per token activations + int4 grouped per +channel weights quantization configuration as shown above: + +.. code-block:: bash + + tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config custom_8B_qat_full.yaml + +.. note:: + + Make sure to point to the location of your Llama3 weights and tokenizer. This can be done + either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path` + or by directly modifying the :code:`8B_qat_full.yaml` file. See our :ref:`config_tutorial_label` + for more details on how you can easily clone and modify torchtune configs. + +.. note:: + + QAT introduces memory and computation overheads compared to regular fine-tuning, + since fake quantization fundamentally involves extra ops and requires cloning + the weights to avoid mutating them when computing the fake quantized values. + In general, we expect around 30% decrease in fine-tuning speed for models like + Llama3-8B. With activation checkpointing, the increase in memory footprint per + GPU is minimal (< 5GB per GPU). + + +Quantizing the QAT model +------------------------ + +Note that the QAT recipe above produces an unquantized bfloat16 model. The model +structure is exactly the same as the model produced with regular full fine-tuning +without QAT, just with different weights. To actually get a quantized model, +copy and make the following modifications to the quantization config: + +.. code-block:: bash + + tune cp quantization custom_quantization.yaml + +.. code-block:: yaml + + # Model arguments + model: + _component_: torchtune.models.llama3.llama3_8b + + checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: + checkpoint_files: [meta_model_0.pt] + recipe_checkpoint: null + output_dir: + model_type: LLAMA3 + + ... + + quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 + +The following command performs the convert step in the QAT flow, which actually +quantizes the float model to a model with quantized weights: + +.. code-block:: bash + + tune run quantize --config custom_quantization.yaml + +.. note:: + + Make sure to use the same QAT quantizer you used to fine-tune your model, + otherwise the numerics will be off and the quantized model will perform poorly. + +Evaluating the quantized model +------------------------------ + +Now that we have a quantized model, we can run some evaluations on it and compare the +results against regular fine-tuning without QAT (i.e. post-training quantization). +To achieve this, we use `EleutherAI’s evaluation harness `_ +integrated in torchtune. First, copy the evaluation config and make the following changes: + +.. code-block:: bash + + tune cp eleuther_evaluation custom_eleuther_evaluation.yaml + +.. code-block:: yaml + + # Model arguments + model: + _component_: torchtune.models.llama3.llama3_8b + + checkpointer: + _component_: torchtune.utils.FullModelTorchTuneCheckpointer + checkpoint_dir: + checkpoint_files: [meta_model_0-8da4w.pt] + recipe_checkpoint: null + output_dir: + model_type: LLAMA3 + + ... + + # EleutherAI specific eval args + tasks: ["hellaswag", "wikitext"] + limit: null + max_seq_length: 8192 + batch_size: 8 + + quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer + groupsize: 256 + +.. note:: + + Since we are passing in a quantized model, be sure to use the corresponding + post-training quantizer instead of the QAT quantizer. For example, if you + used the :code:`Int8DynActInt4WeightQATQuantizer` during fine-tuning, you + should specify :code:`Int8DynActInt4WeightQuantizer` in this step. See the + `quantization recipe `_ + for a full list of supported quantizers. + +Now run the evaluation recipe: + +.. code-block:: bash + + tune run eleuther_eval --config my_eleuther_evaluation.yaml + +The results should look something like this: + +.. code-block:: bash + + # QAT quantized model evaluation results (int8 activations + int4 weights) + + | Tasks |Version|Filter|n-shot| Metric |Value | |Stderr| + |---------|------:|------|-----:|---------------|-----:|---|------| + |wikitext | 2|none | 0|word_perplexity|9.9148|± |N/A | + | | |none | 0|byte_perplexity|1.5357|± |N/A | + | | |none | 0|bits_per_byte |0.6189|± |N/A | + |hellaswag| 1|none | 0|acc |0.5687|± |0.0049| + | | |none | 0|acc_norm |0.7536|± |0.0043| + +Comparing these results to the model fine-tuned without QAT, we can see that +QAT was able to recover a significant portion of the quantization degradations +from the original unquantized model compared to PTQ. For example, normalized +accuracy in the hellaswag task dropped by 2.20% with PTQ but only 0.74% with +QAT when compared to the original unquantized model. Similarly, word perplexity +in the wikitext task increased by 2.048 with PTQ but only 1.190 with QAT (lower +is better). + +.. code-block:: bash + + # PTQ quantized model evaluation results (int8 activations + int4 weights) + + | Tasks |Version|Filter|n-shot| Metric | Value | |Stderr| + |---------|------:|------|-----:|---------------|------:|---|------| + |wikitext | 2|none | 0|word_perplexity|10.7735|± |N/A | + | | |none | 0|byte_perplexity| 1.5598|± |N/A | + | | |none | 0|bits_per_byte | 0.6413|± |N/A | + |hellaswag| 1|none | 0|acc | 0.5481|± |0.0050| + | | |none | 0|acc_norm | 0.7390|± |0.0044| + +.. code-block:: bash + + # Float model evaluation results (bfloat16) + + | Tasks |Version|Filter|n-shot| Metric |Value | |Stderr| + |---------|------:|------|-----:|---------------|-----:|---|------| + |wikitext | 2|none | 0|word_perplexity|8.7251|± |N/A | + | | |none | 0|byte_perplexity|1.4994|± |N/A | + | | |none | 0|bits_per_byte |0.5844|± |N/A | + |hellaswag| 1|none | 0|acc |0.5740|± |0.0049| + | | |none | 0|acc_norm |0.7610|± |0.0043| + +Thus, the QAT flow produced a quantized model that outperforms the post-training +quantized model. Importantly, the quantized model structure is identical in both +flows, and so the model size, memory usage, and all other performance +characteristics are also the same. + +Note that although the weights are quantized to int4, the quantized model size +for both the QAT and the PTQ flows are 8.187 GB, while the original float model +is 14.958 GB. This is because this quantizer uses int8 to represent the weights +as PyTorch does not have native int4 dtype support. A more efficient representation +is to pack the int4 weights, which will halve the quantized model size. This is +what the Int4WeightOnlyQuantizer does, and the corresponding QAT quantizer will +be added in the future. + +Lowering QAT model to device (optional) +--------------------------------------- + +One important motivation for quantizing a model is to be able to run it in resource +constrained environments. You can further lower your QAT Llama3 model to edge devices +such as smartphones using `executorch `_ by +following `these instructions `_. +For example, the following command lowers the model to the XNNPACK backend: + +.. code-block:: bash + + python -m examples.models.llama2.export_llama --checkpoint -p -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 256 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_id":128001}' --embedding-quantize 4,32 --output_name="llama3_8da4w.pte" + +This results in a much smaller quantized model of size 3.881 GB. When benchmarked on a OnePlus 12 smartphone, this model also achieved the same inference and generation speeds as the post-training quantized model. This is because the model structures are the same across the two flows: + +.. list-table:: + :widths: 25 25 25 + :header-rows: 1 + + * - + - QAT + - PTQ + * - Quantized model size + - 3.881 GB + - 3.881 GB + * - Inference speed + - 9.709 tok/s + - 9.815 tok/s + * - Generation speed + - 11.316 tok/s + - 11.364 tok/s diff --git a/training/docs/source/tutorials/qlora_finetune.rst b/training/docs/source/tutorials/qlora_finetune.rst new file mode 100644 index 0000000..789fc65 --- /dev/null +++ b/training/docs/source/tutorials/qlora_finetune.rst @@ -0,0 +1,281 @@ +.. _qlora_finetune_label: + +============================= +Finetuning Llama2 with QLoRA +============================= + +In this tutorial, we'll learn about `QLoRA `_, an enhancement on top of +`LoRA `_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll +walk through how QLoRA can be utilized within torchtune to finetune a Llama2-7b model in <10 GB of memory. +It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in torchtune`. + + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + + * How QLoRA saves memory over LoRA finetuning + * An overview of QLoRA in torchtune + * How to run a QLoRA finetune in torchtune + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * Be familiar with :ref:`torchtune` + * Make sure to :ref:`install torchtune` + * Make sure you have downloaded the :ref:`Llama2-7B model weights` + * Be familiar with :ref:`LoRA in torchtune` + +What is QLoRA? +--------------- + +QLoRA builds on top of LoRA to enable further +memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are +low-rank matrices added to different layers of a neural network, and base model parameters, which are parameters that are part of +the original model. In vanilla LoRA-style training, both these parameters are held in the same precision (typically fp32 or `bf16 `_.), and +therefore activations and intermediate gradients computed are in fp32/bf16. + +QLoRA further quantizes the base model parameters into a bespoke 4-bit NormalFloat (`NF4 `_) data type, resulting in 4-8x less parameter memory usage while +largely retaining model accuracy. As a result, the vast majority of parameters only take up 4 bits (as opposed to 16 or 32 bits by bf16/fp32 dtypes). This +quantization is done through the method highlighted in the original `QLoRA paper `_. Adapter +parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve +accuracy. + +The QLoRA authors introduce two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormatFloat +type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. torchtune uses +the `NF4Tensor `_ abstraction from the `torchao library `_ to build QLoRA components as specified in the paper. +torchao is a PyTorch-native library that allows you to quantize and prune your models. + + +.. _qlora_core_highlevel: + +Using QLoRA to save memory +---------------------------------------- + +In this section, we'll overview how to apply QLoRA to a :class:`~torchtune.modules.peft.LoRALinear` layer in torchtune. For a deep dive into details on QLoRA in torchtune and underlying abstractions, +please see the :ref:`QLoRA in torchtune deepdive ` section of this tutorial. + +A core idea of QLoRA is the distinction between compute and storage datatypes (dtypes). Specifically, QLoRA stores base model parameters in 4-bit precision (i.e. the storage dtype), and runs +computation in an original higher precision (the compute dtype), generally either fp32 or bf16. As a first step, QLoRA needs to quantize these base model parameters to 4-bit precision +and store them. + +To quantize a :class:`~torchtune.modules.peft.LoRALinear` layer in the QLoRA style, simply pass in the ``quantize_base`` flag as ``True`` into :class:`~torchtune.modules.peft.LoRALinear`. This flag +will result in base model weights being quantized and backed by the ``NF4Tensor`` dtype. Forward passes will also be automatically handled to work with the ``NF4Tensor`` dtype, +specifically, the ``NF4`` base weight will be de-quantized to the compute precision, activation will be computed, and only the 4-bit parameter will be stored for gradient computation +in the backward pass, avoiding extra memory usage that would be incurred by storing the higher precision compute dtype. + +Here's an example of creating a quantized ``LoRALinear`` layer in comparison to an unquantized ``LoRALinear`` layer. As we can see, the quantized layer consumes +~8x less memory than the unquantized counterpart. + +.. code-block:: python + + import torch + from torchtune.modules.peft import LoRALinear + + torch.set_default_device("cuda") + qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True) + print(torch.cuda.memory_allocated()) # 177,152 bytes + del qlora_linear + torch.cuda.empty_cache() + lora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=False) + print(torch.cuda.memory_allocated()) # 1,081,344 bytes + + +Using QLoRA in torchtune +---------------------------- + +We'll now cover how you can initialize a QLoRA-enabled Llama2-7b model as well as some details around +checkpointing with QLoRA. + +With torchtune, you can use a simple builder similar to the LoRA builder (:func:`lora_llama_2_7b `) to apply QLoRA to Llama2 models. Here's a simple example of +initializing a Llama2-7b model with QLoRA enabled: + +.. code-block:: python + + from torchtune.models.llama2 import qlora_llama2_7b + + qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"]) + +Under the hood, this will apply LoRA to the ``q_proj`` and ``v_proj`` matrices in all attention layers, and further quantize the base parameters +in these matrices to the ``NF4`` dtype. Note that quantization of base model parameters is only applied to layers that are configured to have +LoRA adapters added. For example, in this case, ``k_proj`` and ``output_proj`` in the attention layers don't have LoRA applied to them, so their +base model parameters are not quantized. We can see this by printing the base model parameter dtypes for a particular attention layer: + +.. code-block:: python + + attn = qlora_model.layers[0].attn + print(type(attn.q_proj.weight)) # + print(type(attn.k_proj.weight)) # + + +Next, there are a couple of details essential to checkpointing (i.e. ``state_dict``) of QLoRA-enabled models. +To integrate well with torchtune's :ref:`checkpointing `, we need to convert ``NF4Tensors`` back to their +original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within +torchtune and beyond (e.g. post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done +in a typical LoRA training flow. + +To achieve this, when using torchtune's :func:`lora_llama_2_7b ` builder, we automatically register a hook, +:func:`reparametrize_as_dtype_state_dict_post_hook `, +that runs after calling ``.state_dict()`` on the top level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading these +converted tensors to the CPU. This offloading is to avoid peaking memory; if we did not, we would have to maintain an entire bf16/fp32 copy of the ``state_dict`` +on GPU. + + + +Putting it all together: QLoRA finetune +----------------------------------------- + +.. TODO (SalmanMohammadi) ref lora recipe w qlora conf. + +Putting it all together, we can now finetune a model using torchtune's `LoRA recipe `_, +with a `QLoRA configuration `_. + +Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions`. +You can then run the following command to perform a QLoRA finetune of Llama2-7B on a single GPU. + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama2/7B_qlora_single_device + +.. note:: + Make sure to correctly point to the location of your Llama2 weights and tokenizer. This can be done + either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path` + or by directly modifying the :code:`7B_qlora_single_device.yaml` file. See our ":ref:`config_tutorial_label`" recipe + for more details on how you can easily clone and modify torchtune configs. + +By default, this run should log peak memory stats at model initialization time and every 100 +iterations during training. Let's understand the memory savings enabled by QLoRA on top of LoRA training. LoRA training +can be run as follows: + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama2/7B_lora_single_device + +You should see the memory usage printed out during model initialization and training. An example log for LoRA model initialization is as follows: + +.. code-block:: python + + Memory Stats after model init:: + GPU peak memory allocation: 13.96 GB + GPU peak memory reserved: 13.98 GB + GPU peak memory active: 13.96 GB + +The following table compares the QLoRA's memory reserved during model initialization and training against vanilla LoRA's. +We can see that QLoRA reduces peak memory by about 35% during model initialization, and about 40% during model training: + +================== ================================== ================================ +Finetuning method Peak memory reserved, model init Peak memory reserved, training +================== ================================== ================================ +LoRA 13.98 GB 15.57 GB +QLoRA 9.13 GB 9.29 GB +================== ================================== ================================ + +From the logs, one can see that the out-of-the-box training performance is quite slow, slower than 1 iteration per +second: + +.. code-block:: python + + 1|149|Loss: 0.9157477021217346: 1%| | 149/25880 [02:08<6:14:19, 1.15it/s + +To speed things up, we can leverage `torch.compile `_ to compile our model and run the compiled result. To work with +QLoRA training, a nightly build of PyTorch must be used. To update PyTorch to the latest nightly, +please see `the installation instructions `_. Once updated, +you can specify the compile flag as ``True`` via a config override: + +.. code-block:: bash + + tune run lora_finetune_single_device --config llama2/7B_qlora_single_device compile=True + +From the logs, we can see about a 200% speed up (after a few hundred iterations once the training has stabilized): + +.. code-block:: python + + 1|228|Loss: 0.8158286809921265: 1%| | 228/25880 [11:59<1:48:16, 3.95it/s + +A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below. + +.. image:: /_static/img/qlora_exp.png + +.. note:: + The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.utils.metric_logging.WandBLogger` + to generate similar loss curves, but you will need to install W&B and setup an account separately. For more details on + using W&B in torchtune, see our ":ref:`wandb_logging`" recipe. + +As an exercise, you can also try running some evaluation tasks or manually inspecting generations +output by your saved checkpoints (which can be found in :code:`output_dir`). + +In the final section, we'll go over a deep dive on how a QLoRA component can be built from a LoRA component. + +.. _qlora_deepdive_label: + +Deep-dive: Building QLoRA from LoRA +----------------------------------------- + +This deep-dive section resumes from the :ref:`Using QLoRA to save memory` portion of this tutorial and dives into how quantization is done with ``NF4Tensor`` and handled appropriately in the forward pass. + +First, we'll begin with +a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial ` and augmented to support quantization: + +.. code-block:: python + :emphasize-lines: 3, 13, 19, 20, 39, 40, 41 + + from torch import nn, Tensor + import torch.nn.functional as F + from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 + + class LoRALinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float, + quantize_base: bool + ): + # These are the weights from the original pretrained model + self.linear = nn.Linear(in_dim, out_dim, bias=False) + self.linear_weight = self.linear.weight + # Use torchao's to_nf4 API to quantize the base weight if needed. + if quantize_base: + self.linear_weight = to_nf4(self.linear_weight) + # These are the new LoRA params. In general rank << in_dim, out_dim + self.lora_a = nn.Linear(in_dim, rank, bias=False) + self.lora_b = nn.Linear(rank, out_dim, bias=False) + + # Rank and alpha are commonly-tuned hyperparameters + self.rank = rank + self.alpha = alpha + + # Most implementations also include some dropout + self.dropout = nn.Dropout(p=dropout) + + # The original params are frozen, and only LoRA params are trainable. + self.linear.weight.requires_grad = False + self.lora_a.weight.requires_grad = True + self.lora_b.weight.requires_grad = True + + def forward(self, x: Tensor) -> Tensor: + # frozen_out would be the output of the original model + if quantize_base: + # Call into torchao's linear_nf4 to run linear forward pass w/quantized weight. + frozen_out = linear_nf4(x, self.weight) + else: + frozen_out = F.linear(x, self.weight) + + # lora_a projects inputs down to the much smaller self.rank, + # then lora_b projects back up to the output dimension + lora_out = self.lora_b(self.lora_a(self.dropout(x))) + + # Finally, scale by the alpha parameter (normalized by rank) + # and add to the original model's outputs + return frozen_out + (self.alpha / self.rank) * lora_out + +As mentioned above, torchtune takes a dependency on torchao for some of the core components required for QLoRA. This includes the +``NF4Tensor``, as well as helpful utilities including ``to_nf4`` and ``linear_nf4``. + +The key changes on top of the LoRA layer are the usage of the ``to_nf4`` and ``linear_nf4`` APIs. + +``to_nf4`` accepts an unquantized (bf16 or fp32) tensor and produces an ``NF4`` representation of the weight. See the `implementation `_ of ``to_nf4`` for more details. +``linear_nf4`` handles the forward pass and autograd when running with quantized base model weights. It computes the forward pass as a regular +``F.linear`` with the incoming activation and unquantized weight. The quantized weight is saved for backward, as opposed to the unquantized version of the weight, to avoid extra +memory usage due to storing higher precision variables to compute gradients in the backward pass. See `linear_nf4 `_ for more details. diff --git a/training/pyproject.toml b/training/pyproject.toml new file mode 100644 index 0000000..8f840d1 --- /dev/null +++ b/training/pyproject.toml @@ -0,0 +1,89 @@ +# ---- All project specifications ---- # +[project] +name = "torchtune" +description = "A native-PyTorch library for LLM fine-tuning" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +authors = [ + { name = "PyTorch Team", email = "packages@pytorch.org" }, +] +keywords = ["pytorch", "finetuning", "llm"] +dependencies = [ + + # Hugging Face integrations + "datasets", + "huggingface_hub", + "safetensors", + + # Tokenization + "sentencepiece", + "tiktoken", + "blobfile>=2", + + # Miscellaneous + "numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed + "tqdm", + "omegaconf", + + # Quantization + "torchao==0.3.1", +] +dynamic = ["version"] + +[project.urls] +GitHub = "https://github.com/pytorch/torchtune" +Documentation = "https://pytorch.org/torchtune/main/index.html" +Issues = "https://github.com/pytorch/torchtune/issues" + +[project.scripts] +tune = "torchtune._cli.tune:main" + +[project.optional-dependencies] +dev = [ + "bitsandbytes>=0.43.0", + "pre-commit", + "pytest==7.4.0", + "pytest-cov", + "pytest-mock", + "pytest-integration", + "tensorboard", + "wandb", + "expecttest", +] + +[tool.setuptools.dynamic] +version = {attr = "torchtune.__version__"} + + +# ---- Explicit project build information ---- # +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = [""] +include = ["torchtune*", "recipes*"] + +[tool.setuptools.package-data] +recipes = ["configs/*.yaml", "configs/*/*.yaml", "configs/*/*/*.yaml"] + + +# ---- Tooling specifications ---- # +[tool.usort] +first_party_detection = false + +[tool.black] +target-version = ["py38"] + +[tool.pydoclint] +style = 'google' +check-return-types = 'False' +exclude = 'tests/torchtune/models/(\w+)/scripts/' + +[tool.pytest.ini_options] +addopts = ["--showlocals", "--import-mode=prepend", "--without-integration", "--without-slow-integration"] +# --showlocals will show local variables in tracebacks +# --import-mode=prepend will add the root (the parent dir of torchtune/, tests/, recipes/) +# to `sys.path` when invoking pytest, allowing us to treat `tests` as a package within the tests. +# --without-integration and --without-slow-integration: default to running unit tests only diff --git a/training/recipes/__init__.py b/training/recipes/__init__.py new file mode 100644 index 0000000..8ed6c84 --- /dev/null +++ b/training/recipes/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file mainly exists because we want to ensure that `recipes` aren't +# importable *from the tests*. +# We're using the `prepend` pytest import mode which adds the root dir (i.e. the +# parent of torchtune/, tests/, recipes/) to the pythonpath during pytest +# sessions +# (https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html#import-modes). +# This has the positive effect that the `tests` folder becomes importable when +# testing (we need that, considering how tests are currently set up) but ALSO +# has the negative effect of making the `recipes/` importable when testing. +# Since we don't want the tests to to incorrectly assume that recipes are +# importable, we have to explicitly raise an error here. + +raise ModuleNotFoundError( + "The torchtune recipes directory isn't a package and you should not import anything from here. " + "Refer to our docs for detailed instructions on how to use recipes: " + "https://pytorch.org/torchtune/main/deep_dives/recipe_deepdive.html" +) diff --git a/training/recipes/configs/code_llama2/7B_full_low_memory.yaml b/training/recipes/configs/code_llama2/7B_full_low_memory.yaml new file mode 100644 index 0000000..c78b635 --- /dev/null +++ b/training/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -0,0 +1,79 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.code_llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 2e-5 +optimizer_in_bwd: True +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/code_llama2_finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/code_llama2/7B_lora_single_device.yaml b/training/recipes/configs/code_llama2/7B_lora_single_device.yaml new file mode 100644 index 0000000..274d926 --- /dev/null +++ b/training/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -0,0 +1,109 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.lora_code_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + + +# Training environment +device: cuda +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/lora_code_llama2_finetune_output +log_every_n_steps: 1 +log_peak_memory_stats: False + +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/training/recipes/configs/code_llama2/7B_qlora_single_device.yaml new file mode 100644 index 0000000..c2a1491 --- /dev/null +++ b/training/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -0,0 +1,111 @@ +# Config for single device QLoRA finetuning in lora_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.qlora_code_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Fine-tuning arguments and training +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + + +# Training environment +device: cuda +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/qlora_code_llama2_finetune_output +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/dev/8B_full_experimental.yaml b/training/recipes/configs/dev/8B_full_experimental.yaml new file mode 100644 index 0000000..ea73b85 --- /dev/null +++ b/training/recipes/configs/dev/8B_full_experimental.yaml @@ -0,0 +1,80 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama3 8B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B --output-dir /tmp/Meta-Llama-3-8B --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 8B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3.llama3_8b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: False +ac_mode: 'selective' # ['selective', 'full'] +ac_option: 2 # [int] = ac every positive int layer +memory_efficient_fsdp_wrap: False + + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama3-finetune +log_every_n_steps: null diff --git a/training/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml b/training/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml new file mode 100644 index 0000000..f6cbf97 --- /dev/null +++ b/training/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml @@ -0,0 +1,90 @@ +# Config for multi-device LoRA with FSDP2 in lora_finetune_fsdp2.py +# using a Llama2 13B model +# +# This config requires PyTorch nightlies to run. +# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md +# for setup instructions. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-13b-hf --output-dir /tmp/Llama-2-13b-hf --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_fsdp2 --config llama2/13B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_fsdp2 --config llama2/13B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml and update the model and checkpoints to +# the 13B model. + + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_13b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: True + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-13b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-13b-hf/ + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-13b-hf/tokenizer.model + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 2e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/training/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml b/training/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml new file mode 100644 index 0000000..c1dd065 --- /dev/null +++ b/training/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml @@ -0,0 +1,90 @@ +# Config for multi-device LoRA with FSDP2 lora_finetune_fsdp2.py +# using a Llama2 70B model +# +# This config requires PyTorch nightlies to run. +# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md +# for setup instructions. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token +# +# This config needs 8 GPUs to run +# # tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/70B_lora +# + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_70b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 16 + lora_alpha: 32 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-70b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-70b-hf + checkpoint_files: [ + pytorch_model-00001-of-00015.bin, + pytorch_model-00002-of-00015.bin, + pytorch_model-00003-of-00015.bin, + pytorch_model-00004-of-00015.bin, + pytorch_model-00005-of-00015.bin, + pytorch_model-00006-of-00015.bin, + pytorch_model-00007-of-00015.bin, + pytorch_model-00008-of-00015.bin, + pytorch_model-00009-of-00015.bin, + pytorch_model-00010-of-00015.bin, + pytorch_model-00011-of-00015.bin, + pytorch_model-00012-of-00015.bin, + pytorch_model-00013-of-00015.bin, + pytorch_model-00014-of-00015.bin, + pytorch_model-00015-of-00015.bin, + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-70b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml b/training/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml new file mode 100644 index 0000000..8fee1e7 --- /dev/null +++ b/training/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml @@ -0,0 +1,95 @@ +# Config for multi-device QLoRA in lora_finetune_fsdp2.py +# using a Llama2 70B model +# +# This config requires PyTorch nightlies to run. +# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md +# for setup instructions. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token +# +# This config needs 8 GPUs to run +# # tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/70B_qlora +# + +# Model Arguments +model: + _component_: torchtune.models.llama2.qlora_llama2_70b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 16 + lora_alpha: 32 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-70b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-70b-hf + checkpoint_files: [ + pytorch_model-00001-of-00015.bin, + pytorch_model-00002-of-00015.bin, + pytorch_model-00003-of-00015.bin, + pytorch_model-00004-of-00015.bin, + pytorch_model-00005-of-00015.bin, + pytorch_model-00006-of-00015.bin, + pytorch_model-00007-of-00015.bin, + pytorch_model-00008-of-00015.bin, + pytorch_model-00009-of-00015.bin, + pytorch_model-00010-of-00015.bin, + pytorch_model-00011-of-00015.bin, + pytorch_model-00012-of-00015.bin, + pytorch_model-00013-of-00015.bin, + pytorch_model-00014-of-00015.bin, + pytorch_model-00015-of-00015.bin, + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-70b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: True +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +fsdp: + cpu_offload: False + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Logging +output_dir: /tmp/qlora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml b/training/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml new file mode 100644 index 0000000..cb28236 --- /dev/null +++ b/training/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml @@ -0,0 +1,87 @@ +# Config for multi-device LoRA finetuning with FSDP2 in lora_finetune_fsdp2.py +# using a Llama2 7B model +# +# This config requires PyTorch nightlies to run. +# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md +# for setup instructions. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml + + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/training/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml b/training/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml new file mode 100644 index 0000000..4e23f07 --- /dev/null +++ b/training/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml @@ -0,0 +1,91 @@ +# Config for single device QLoRA with lora_finetune_fsdp2.py +# using a Llama2 7B model +# +# This config requires PyTorch nightlies to run. +# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md +# for setup instructions. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml + +# Model Arguments +model: + _component_: torchtune.models.llama2.qlora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + train_on_input: True +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +fsdp: + cpu_offload: False + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +compile: False + +# Logging +output_dir: /tmp/qlora_finetune_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/eleuther_evaluation.yaml b/training/recipes/configs/eleuther_evaluation.yaml new file mode 100644 index 0000000..0cebe5e --- /dev/null +++ b/training/recipes/configs/eleuther_evaluation.yaml @@ -0,0 +1,37 @@ +# Config for EleutherEvalRecipe in eleuther_eval.py +# +# To launch, run the following command from root torchtune directory: +# tune run eleuther_eval --config eleuther_evaluation tasks=["truthfulqa_mc2","hellaswag"] + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin, + ] + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +# Environment +device: cuda +dtype: bf16 +seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed + +# EleutherAI specific eval args +tasks: ["truthfulqa_mc2"] +limit: null +max_seq_length: 4096 +batch_size: 8 + +# Quantization specific args +quantizer: null diff --git a/training/recipes/configs/gemma/2B_full.yaml b/training/recipes/configs/gemma/2B_full.yaml new file mode 100644 index 0000000..d98b13a --- /dev/null +++ b/training/recipes/configs/gemma/2B_full.yaml @@ -0,0 +1,73 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2b --ignore-patterns None --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma/2B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma/2B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.gemma_2b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2b/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2b + model_type: GEMMA +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/gemma/2B_lora.yaml b/training/recipes/configs/gemma/2B_lora.yaml new file mode 100644 index 0000000..11ce335 --- /dev/null +++ b/training/recipes/configs/gemma/2B_lora.yaml @@ -0,0 +1,83 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2b --ignore-patterns None --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/2B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/2B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.lora_gemma_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2b/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2b + model_type: GEMMA +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-lora +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/gemma/2B_lora_single_device.yaml b/training/recipes/configs/gemma/2B_lora_single_device.yaml new file mode 100644 index 0000000..67f577e --- /dev/null +++ b/training/recipes/configs/gemma/2B_lora_single_device.yaml @@ -0,0 +1,109 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2b --ignore-patterns None --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma/2B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma/2B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.lora_gemma_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2b/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2b + model_type: GEMMA +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/gemma/2B_qlora_single_device.yaml b/training/recipes/configs/gemma/2B_qlora_single_device.yaml new file mode 100644 index 0000000..3a5edfd --- /dev/null +++ b/training/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -0,0 +1,109 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2b --ignore-patterns None --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma/2B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma/2B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.qlora_gemma_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2b/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2b + model_type: GEMMA +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/gemma/7B_full.yaml b/training/recipes/configs/gemma/7B_full.yaml new file mode 100644 index 0000000..87b6821 --- /dev/null +++ b/training/recipes/configs/gemma/7B_full.yaml @@ -0,0 +1,75 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-7b --hf-token --ignore-patterns "gemma-7b.gguf" +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma/7B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma/7B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-7b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.gemma_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-7b/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-7b + model_type: GEMMA +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/gemma/7B_lora.yaml b/training/recipes/configs/gemma/7B_lora.yaml new file mode 100644 index 0000000..fb78a3e --- /dev/null +++ b/training/recipes/configs/gemma/7B_lora.yaml @@ -0,0 +1,85 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-7b --hf-token --ignore-patterns "gemma-7b.gguf" +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/7B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/7B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-7b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.lora_gemma_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-7b/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-7b/ + model_type: GEMMA +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-lora +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/gemma/7B_lora_single_device.yaml b/training/recipes/configs/gemma/7B_lora_single_device.yaml new file mode 100644 index 0000000..7f9fd7e --- /dev/null +++ b/training/recipes/configs/gemma/7B_lora_single_device.yaml @@ -0,0 +1,111 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma 7B model +# +# This config assumes that you've run the following command before launching +# this run (torchtune does not use gguf so you can ignore it to save time and space): +# tune download google/gemma-7b --hf-token --ignore-patterns "gemma-7b.gguf" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma/7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-7b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.lora_gemma_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-7b/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-7b/ + model_type: GEMMA +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 5e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 8 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/gemma/7B_qlora_single_device.yaml b/training/recipes/configs/gemma/7B_qlora_single_device.yaml new file mode 100644 index 0000000..08f05bc --- /dev/null +++ b/training/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -0,0 +1,111 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-7b --hf-token --ignore-patterns "gemma-7b.gguf" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma/7B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma/7B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-7b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma.qlora_gemma_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-7b/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-7b/ + model_type: GEMMA +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/generation.yaml b/training/recipes/configs/generation.yaml new file mode 100644 index 0000000..da8d7b4 --- /dev/null +++ b/training/recipes/configs/generation.yaml @@ -0,0 +1,40 @@ +# Config for running the InferenceRecipe in generate.py to generate output from an LLM +# +# To launch, run the following command from root torchtune directory: +# tune run generate --config generation + +# Model arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin, + ] + output_dir: /tmp/Llama-2-7b-hf/ + model_type: LLAMA2 + +device: cuda +dtype: bf16 + +seed: 1234 + +# Tokenizer arguments +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +# Generation arguments; defaults taken from gpt-fast +prompt: "Tell me a joke?" +instruct_template: null +chat_format: null +max_new_tokens: 300 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 +# It is recommended to set enable_kv_cache=False for long-context models like Llama3.1 +enable_kv_cache: True + +quantizer: null diff --git a/training/recipes/configs/llama2/13B_full.yaml b/training/recipes/configs/llama2/13B_full.yaml new file mode 100644 index 0000000..6ad7108 --- /dev/null +++ b/training/recipes/configs/llama2/13B_full.yaml @@ -0,0 +1,75 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama2 13B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-13b-hf --output-dir /tmp/Llama-2-13b-hf --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/13B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/13B_full checkpointer.checkpoint_dir= +# +# This config should be used with 2+ GPUs. Single device full fine-tuning +# requires several memory optimizations which are exposed through +# 7B_full_single_device.yaml. Please update the model and checkpoints to 13B +# in that config. + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_13b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-13b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-13b-hf/ + model_type: LLAMA2 +resume_from_checkpoint: False + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-13b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama2/13B_lora.yaml b/training/recipes/configs/llama2/13B_lora.yaml new file mode 100644 index 0000000..f31229f --- /dev/null +++ b/training/recipes/configs/llama2/13B_lora.yaml @@ -0,0 +1,86 @@ +# Config for multi-device LoRA in lora_finetune_distributed.py +# using a Llama2 13B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-13b-hf --output-dir /tmp/Llama-2-13b-hf --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml and update the model and checkpoints to +# the 13B model. + + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_13b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: True + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-13b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-13b-hf/ + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-13b-hf/tokenizer.model + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 2e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/training/recipes/configs/llama2/13B_qlora_single_device.yaml b/training/recipes/configs/llama2/13B_qlora_single_device.yaml new file mode 100644 index 0000000..ac0bcab --- /dev/null +++ b/training/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -0,0 +1,108 @@ +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Llama2 13B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-13b-hf --output-dir /tmp/Llama-2-13b-hf --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama2/13B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config 13_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.llama2.qlora_llama2_13b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-13b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-13b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-13b-hf/ + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/qlora_finetune_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama2/70B_lora.yaml b/training/recipes/configs/llama2/70B_lora.yaml new file mode 100644 index 0000000..ced6edb --- /dev/null +++ b/training/recipes/configs/llama2/70B_lora.yaml @@ -0,0 +1,86 @@ +# Config for multi-device LoRA in lora_finetune_distributed.py +# using a Llama2 70B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token +# +# This config needs 8 GPUs to run +# # tune run --nproc_per_node 8 lora_finetune_distributed --config llama2/70B_lora +# + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_70b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 16 + lora_alpha: 32 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-70b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-70b-hf + checkpoint_files: [ + pytorch_model-00001-of-00015.bin, + pytorch_model-00002-of-00015.bin, + pytorch_model-00003-of-00015.bin, + pytorch_model-00004-of-00015.bin, + pytorch_model-00005-of-00015.bin, + pytorch_model-00006-of-00015.bin, + pytorch_model-00007-of-00015.bin, + pytorch_model-00008-of-00015.bin, + pytorch_model-00009-of-00015.bin, + pytorch_model-00010-of-00015.bin, + pytorch_model-00011-of-00015.bin, + pytorch_model-00012-of-00015.bin, + pytorch_model-00013-of-00015.bin, + pytorch_model-00014-of-00015.bin, + pytorch_model-00015-of-00015.bin, + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-70b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/llama2/7B_full.yaml b/training/recipes/configs/llama2/7B_full.yaml new file mode 100644 index 0000000..e98d7a6 --- /dev/null +++ b/training/recipes/configs/llama2/7B_full.yaml @@ -0,0 +1,76 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 7B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama2/7B_full_low_memory.yaml b/training/recipes/configs/llama2/7B_full_low_memory.yaml new file mode 100644 index 0000000..b327e73 --- /dev/null +++ b/training/recipes/configs/llama2/7B_full_low_memory.yaml @@ -0,0 +1,78 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config llama2/7B_full_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama2/7B_full_low_memory checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 2e-5 +optimizer_in_bwd: True +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama2/7B_lora.yaml b/training/recipes/configs/llama2/7B_lora.yaml new file mode 100644 index 0000000..0cfcef2 --- /dev/null +++ b/training/recipes/configs/llama2/7B_lora.yaml @@ -0,0 +1,110 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml + + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama2/7B_lora_dpo.yaml b/training/recipes/configs/llama2/7B_lora_dpo.yaml new file mode 100644 index 0000000..59435c7 --- /dev/null +++ b/training/recipes/configs/llama2/7B_lora_dpo.yaml @@ -0,0 +1,86 @@ +# Config for multi-device LoRA DPO alignment in lora_dpo_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama2/7B_lora_dpo +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama2/7B_lora_dpo checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA DPO alignment please use 7B_lora_dpo_single_device.yaml + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.stack_exchanged_paired_dataset + max_seq_len: 1024 +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.05 + lr: 5e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.DPOLoss + beta: 0.1 + label_smoothing: 0 + +# Training +epochs: 1 +max_steps_per_epoch: 1000 +gradient_accumulation_steps: 8 + +# Logging +output_dir: /tmp/lora_dpo_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/training/recipes/configs/llama2/7B_lora_dpo_single_device.yaml new file mode 100644 index 0000000..5478c23 --- /dev/null +++ b/training/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -0,0 +1,86 @@ +# Config for single device LoRA DPO alignment in lora_dpo_single_device.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.stack_exchanged_paired_dataset + max_seq_len: 1024 +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.05 + lr: 5e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.DPOLoss + beta: 0.1 + label_smoothing: 0 + +# Training +epochs: 1 +max_steps_per_epoch: 1000 +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/lora_dpo_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/llama2/7B_lora_single_device.yaml b/training/recipes/configs/llama2/7B_lora_single_device.yaml new file mode 100644 index 0000000..b68f851 --- /dev/null +++ b/training/recipes/configs/llama2/7B_lora_single_device.yaml @@ -0,0 +1,108 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama2/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config 7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama2/7B_qat_full.yaml b/training/recipes/configs/llama2/7B_qat_full.yaml new file mode 100644 index 0000000..3256318 --- /dev/null +++ b/training/recipes/configs/llama2/7B_qat_full.yaml @@ -0,0 +1,76 @@ +# Config for multi-device QAT finetuning in qat_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama2/7B_qat_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama2/7B_qat_full checkpointer.checkpoint_dir= + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# QAT arguments +quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama2/7B_qlora_single_device.yaml b/training/recipes/configs/llama2/7B_qlora_single_device.yaml new file mode 100644 index 0000000..496d02a --- /dev/null +++ b/training/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -0,0 +1,107 @@ +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama2/7B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config 7B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.llama2.qlora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/qlora_finetune_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama3/70B_full.yaml b/training/recipes/configs/llama3/70B_full.yaml new file mode 100644 index 0000000..c35a0a4 --- /dev/null +++ b/training/recipes/configs/llama3/70B_full.yaml @@ -0,0 +1,109 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama3 70B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --hf-token --ignore-patterns "original/consolidated*" +# +# To launch on 8 devices, run the following command from root: +# tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full checkpointer.checkpoint_dir= +# +# This config is only tested on an 8xA100 machine. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3.llama3_70b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct + checkpoint_files: [ + model-00001-of-00030.safetensors, + model-00002-of-00030.safetensors, + model-00003-of-00030.safetensors, + model-00004-of-00030.safetensors, + model-00005-of-00030.safetensors, + model-00006-of-00030.safetensors, + model-00007-of-00030.safetensors, + model-00008-of-00030.safetensors, + model-00009-of-00030.safetensors, + model-00010-of-00030.safetensors, + model-00011-of-00030.safetensors, + model-00012-of-00030.safetensors, + model-00013-of-00030.safetensors, + model-00014-of-00030.safetensors, + model-00015-of-00030.safetensors, + model-00016-of-00030.safetensors, + model-00017-of-00030.safetensors, + model-00018-of-00030.safetensors, + model-00019-of-00030.safetensors, + model-00020-of-00030.safetensors, + model-00021-of-00030.safetensors, + model-00022-of-00030.safetensors, + model-00023-of-00030.safetensors, + model-00024-of-00030.safetensors, + model-00025-of-00030.safetensors, + model-00026-of-00030.safetensors, + model-00027-of-00030.safetensors, + model-00028-of-00030.safetensors, + model-00029-of-00030.safetensors, + model-00030-of-00030.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-70b + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + # Note: highly recommended to use fused=True optimizer flag + # with CPU offload for faster optimizer step. + fused: True + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: True +fsdp_cpu_offload: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/full-llama3-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama3/70B_lora.yaml b/training/recipes/configs/llama3/70B_lora.yaml new file mode 100644 index 0000000..0dcd9e8 --- /dev/null +++ b/training/recipes/configs/llama3/70B_lora.yaml @@ -0,0 +1,101 @@ +# Config for multi-device LoRA in lora_finetune_distributed.py +# using a Llama3 70B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-70B-Instruct --hf-token --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" +# +# This config needs 8 GPUs to run +# # tune run --nproc_per_node 8 lora_finetune_distributed --config llama3/70B_lora +# + +# Model Arguments +model: + _component_: torchtune.models.llama3.lora_llama3_70b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 16 + lora_alpha: 32 + +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct + checkpoint_files: [ + model-00001-of-00030.safetensors, + model-00002-of-00030.safetensors, + model-00003-of-00030.safetensors, + model-00004-of-00030.safetensors, + model-00005-of-00030.safetensors, + model-00006-of-00030.safetensors, + model-00007-of-00030.safetensors, + model-00008-of-00030.safetensors, + model-00009-of-00030.safetensors, + model-00010-of-00030.safetensors, + model-00011-of-00030.safetensors, + model-00012-of-00030.safetensors, + model-00013-of-00030.safetensors, + model-00014-of-00030.safetensors, + model-00015-of-00030.safetensors, + model-00016-of-00030.safetensors, + model-00017-of-00030.safetensors, + model-00018-of-00030.safetensors, + model-00019-of-00030.safetensors, + model-00020-of-00030.safetensors, + model-00021-of-00030.safetensors, + model-00022-of-00030.safetensors, + model-00023-of-00030.safetensors, + model-00024-of-00030.safetensors, + model-00025-of-00030.safetensors, + model-00026-of-00030.safetensors, + model-00027-of-00030.safetensors, + model-00028-of-00030.safetensors, + model-00029-of-00030.safetensors, + model-00030-of-00030.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-70B-Instruct + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/llama3/8B_full.yaml b/training/recipes/configs/llama3/8B_full.yaml new file mode 100644 index 0000000..47d3dc0 --- /dev/null +++ b/training/recipes/configs/llama3/8B_full.yaml @@ -0,0 +1,78 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 8B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3.llama3_8b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/full-llama3-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama3/8B_full_single_device.yaml b/training/recipes/configs/llama3/8B_full_single_device.yaml new file mode 100644 index 0000000..8e78f06 --- /dev/null +++ b/training/recipes/configs/llama3/8B_full_single_device.yaml @@ -0,0 +1,77 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config llama3/8B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama3/8B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3.llama3_8b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW8bit + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +optimizer_in_bwd: True +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/full-llama3-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama3/8B_lora.yaml b/training/recipes/configs/llama3/8B_lora.yaml new file mode 100644 index 0000000..bd4a621 --- /dev/null +++ b/training/recipes/configs/llama3/8B_lora.yaml @@ -0,0 +1,81 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 8B_lora_single_device.yaml +# or 8B_qlora_single_device.yaml + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + +# Model Arguments +model: + _component_: torchtune.models.llama3.lora_llama3_8b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/training/recipes/configs/llama3/8B_lora_single_device.yaml b/training/recipes/configs/llama3/8B_lora_single_device.yaml new file mode 100644 index 0000000..650b2ae --- /dev/null +++ b/training/recipes/configs/llama3/8B_lora_single_device.yaml @@ -0,0 +1,106 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3/8B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3/8B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.llama3.lora_llama3_8b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Profiler (disabled) +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama3/8B_qat_full.yaml b/training/recipes/configs/llama3/8B_qat_full.yaml new file mode 100644 index 0000000..876d0f7 --- /dev/null +++ b/training/recipes/configs/llama3/8B_qat_full.yaml @@ -0,0 +1,77 @@ +# Config for multi-device QAT finetuning in qat_distributed.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full checkpointer.checkpoint_dir= + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3.llama3_8b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 + +# QAT arguments +quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama3-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama3/8B_qlora_single_device.yaml b/training/recipes/configs/llama3/8B_qlora_single_device.yaml new file mode 100644 index 0000000..d806927 --- /dev/null +++ b/training/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -0,0 +1,105 @@ +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3/8B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.llama3.qlora_llama3_8b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/qlora_finetune_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Profiler (disabled) +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama3_1/70B_full.yaml b/training/recipes/configs/llama3_1/70B_full.yaml new file mode 100644 index 0000000..f75b5a7 --- /dev/null +++ b/training/recipes/configs/llama3_1/70B_full.yaml @@ -0,0 +1,109 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama3.1 70B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" +# +# To launch on 8 devices, run the following command from root: +# tune run --nproc_per_node 8 full_finetune_distributed --config llama3_1/70B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 8 full_finetune_distributed --config llama3_1/70B_full checkpointer.checkpoint_dir= +# +# This config is only tested on an 8xA100 machine. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_70b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ + checkpoint_files: [ + model-00001-of-00030.safetensors, + model-00002-of-00030.safetensors, + model-00003-of-00030.safetensors, + model-00004-of-00030.safetensors, + model-00005-of-00030.safetensors, + model-00006-of-00030.safetensors, + model-00007-of-00030.safetensors, + model-00008-of-00030.safetensors, + model-00009-of-00030.safetensors, + model-00010-of-00030.safetensors, + model-00011-of-00030.safetensors, + model-00012-of-00030.safetensors, + model-00013-of-00030.safetensors, + model-00014-of-00030.safetensors, + model-00015-of-00030.safetensors, + model-00016-of-00030.safetensors, + model-00017-of-00030.safetensors, + model-00018-of-00030.safetensors, + model-00019-of-00030.safetensors, + model-00020-of-00030.safetensors, + model-00021-of-00030.safetensors, + model-00022-of-00030.safetensors, + model-00023-of-00030.safetensors, + model-00024-of-00030.safetensors, + model-00025-of-00030.safetensors, + model-00026-of-00030.safetensors, + model-00027-of-00030.safetensors, + model-00028-of-00030.safetensors, + model-00029-of-00030.safetensors, + model-00030-of-00030.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + # Note: highly recommended to use fused=True optimizer flag + # with CPU offload for faster optimizer step. + fused: True + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: True +fsdp_cpu_offload: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/full-llama3_1-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama3_1/70B_lora.yaml b/training/recipes/configs/llama3_1/70B_lora.yaml new file mode 100644 index 0000000..943f715 --- /dev/null +++ b/training/recipes/configs/llama3_1/70B_lora.yaml @@ -0,0 +1,100 @@ +# Config for multi-device LoRA in lora_finetune_distributed.py +# using a Llama3.1 70B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" +# +# This config needs 8 GPUs to run +# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_1/70B_lora + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_70b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 16 + lora_alpha: 32 + +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ + checkpoint_files: [ + model-00001-of-00030.safetensors, + model-00002-of-00030.safetensors, + model-00003-of-00030.safetensors, + model-00004-of-00030.safetensors, + model-00005-of-00030.safetensors, + model-00006-of-00030.safetensors, + model-00007-of-00030.safetensors, + model-00008-of-00030.safetensors, + model-00009-of-00030.safetensors, + model-00010-of-00030.safetensors, + model-00011-of-00030.safetensors, + model-00012-of-00030.safetensors, + model-00013-of-00030.safetensors, + model-00014-of-00030.safetensors, + model-00015-of-00030.safetensors, + model-00016-of-00030.safetensors, + model-00017-of-00030.safetensors, + model-00018-of-00030.safetensors, + model-00019-of-00030.safetensors, + model-00020-of-00030.safetensors, + model-00021-of-00030.safetensors, + model-00022-of-00030.safetensors, + model-00023-of-00030.safetensors, + model-00024-of-00030.safetensors, + model-00025-of-00030.safetensors, + model-00026-of-00030.safetensors, + model-00027-of-00030.safetensors, + model-00028-of-00030.safetensors, + model-00029-of-00030.safetensors, + model-00030-of-00030.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Logging +output_dir: /tmp/lora-llama3_1-finetune-output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True diff --git a/training/recipes/configs/llama3_1/8B_full.yaml b/training/recipes/configs/llama3_1/8B_full.yaml new file mode 100644 index 0000000..8d84728 --- /dev/null +++ b/training/recipes/configs/llama3_1/8B_full.yaml @@ -0,0 +1,92 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3_1/8B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3_1/8B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 8B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /fsx-onellm/akariasai/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + +# Dataset +# Open-Orca/SlimOrca-Dedup +dataset: + _component_: torchtune.datasets.chat_dataset + train_on_input: False + source: /fsx-onellm/swj0419/science_rag/torchtune/swj_code/data/home/akariasai/expert_lm + chat_format: torchtune.data.Llama3ChatFormat + #source: json + #chat_format: Llama2ChatFormat + #source: Open-Orca/SlimOrca-Dedup + max_seq_len: 8096 + conversation_style: openai + split: train +seed: null +shuffle: true + + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /fsx-onellm/akariasai/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /fsx-onellm/swj0419/science_rag/torchtune/swj_code/output/Meta-Llama-3.1-8B-Instruct_tmp/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: llama3.1-finetune_tmp +log_every_n_steps: 1 +log_peak_memory_stats: False \ No newline at end of file diff --git a/training/recipes/configs/llama3_1/8B_full_single_device.yaml b/training/recipes/configs/llama3_1/8B_full_single_device.yaml new file mode 100644 index 0000000..5da4038 --- /dev/null +++ b/training/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -0,0 +1,80 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW8bit + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +optimizer_in_bwd: True +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/full-llama3.1-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/llama3_1/8B_lora.yaml b/training/recipes/configs/llama3_1/8B_lora.yaml new file mode 100644 index 0000000..e2ff9fb --- /dev/null +++ b/training/recipes/configs/llama3_1/8B_lora.yaml @@ -0,0 +1,84 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 8B_lora_single_device.yaml +# or 8B_qlora_single_device.yaml + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/training/recipes/configs/llama3_1/8B_lora_single_device.yaml b/training/recipes/configs/llama3_1/8B_lora_single_device.yaml new file mode 100644 index 0000000..57a7b70 --- /dev/null +++ b/training/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -0,0 +1,109 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Profiler (disabled) +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/training/recipes/configs/llama3_1/8B_qlora_single_device.yaml new file mode 100644 index 0000000..3475fce --- /dev/null +++ b/training/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -0,0 +1,108 @@ +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.qlora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/qlora_finetune_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Profiler (disabled) +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/mistral/7B_full.yaml b/training/recipes/configs/mistral/7B_full.yaml new file mode 100644 index 0000000..6643221 --- /dev/null +++ b/training/recipes/configs/mistral/7B_full.yaml @@ -0,0 +1,78 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Mistral 7B model +# +# This config uses hyperparameters based on small set of experiments and information +# available on various forums. These are not meant to replicate the numbers +# from the paper +# +# This config assumes that you've run the following command before launching +# this run: +# tune download mistralai/Mistral-7B-v0.1 --hf-token --output-dir /tmp/Mistral-7B-v0.1 +# +# Run this config on 4 GPUs using the following: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config mistral/7B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config mistral/7B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 7B_full_single_device.yaml for those cases + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-v0.1/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.mistral.mistral_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-v0.1 + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Mistral-7B-v0.1/ + model_type: MISTRAL +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Mistral-7B-v0.1/ +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/mistral/7B_full_low_memory.yaml b/training/recipes/configs/mistral/7B_full_low_memory.yaml new file mode 100644 index 0000000..023dc89 --- /dev/null +++ b/training/recipes/configs/mistral/7B_full_low_memory.yaml @@ -0,0 +1,83 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Mistral 7B model +# +# This config uses hyperparameters based on small set of experiments and information +# available on various forums. These are not meant to replicate the numbers +# from the paper +# +# This config assumes that you've run the following command before launching +# this run: +# tune download mistralai/Mistral-7B-v0.1 --hf-token --output-dir /tmp/Mistral-7B-v0.1 +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config mistral/7B_full_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config mistral/7B_full_low_memory checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-v0.1/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.mistral.mistral_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-v0.1 + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Mistral-7B-v0.1/ + model_type: MISTRAL +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 5e-6 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +optimizer_in_bwd: True + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Model compilation +compile: False + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Mistral-7B-v0.1/ +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/training/recipes/configs/mistral/7B_full_ppo_low_memory.yaml new file mode 100644 index 0000000..7ccf510 --- /dev/null +++ b/training/recipes/configs/mistral/7B_full_ppo_low_memory.yaml @@ -0,0 +1,180 @@ +# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py +# using a Mistral 7B model. +# +# This config has been tested on an A100 80GB. +# This config uses hyperparameters based on small set of experiments and information +# available from existing implementations. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/ --ignore-patterns="" +# tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --hf-token HF_TOKEN +# +# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders. +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run ppo_full_finetune_single_device --config mistral/7B_full_low_memory checkpointer.checkpoint_dir= +# + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.text_completion_dataset + source: trl-internal-testing/sentiment-trl-style + max_seq_len: null + split: train + column: prompt + add_eos: False + +policy_model: + _component_: torchtune.models.mistral.mistral_7b + +# we need to manually build the mistral classifier model +# because our reward model checkpoint has a larger vocabulary size (due to an added padding token) +reward_and_value_model: + _component_: torchtune.models.mistral._component_builders.mistral_classifier + attn_dropout: 0.0 + embed_dim: 4096 + intermediate_dim: 14336 + max_seq_len: 32768 + norm_eps: 1.0e-05 + num_classes: 1 + num_heads: 32 + num_kv_heads: 8 + num_layers: 32 + vocab_size: 32001 + +# checkpointer for the policy model - update this if resuming from checkpoint +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/ + checkpoint_files: [ + "pytorch_model-00001-of-00003.bin", + "pytorch_model-00002-of-00003.bin", + "pytorch_model-00003-of-00003.bin" + ] + # this is the only place where you should update `recipe_checkpoint` if resuming training + recipe_checkpoint: null + output_dir: ${output_dir}/policy + model_type: MISTRAL + +# this should be setup identically to the policy model checkpointer at the start of training +# ensure `checkpoint_files` always points to the original policy weights, even if resuming training +ref_policy_checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/ + checkpoint_files: [ + "pytorch_model-00001-of-00003.bin", + "pytorch_model-00002-of-00003.bin", + "pytorch_model-00003-of-00003.bin" + ] + output_dir: ${output_dir}/policy + model_type: MISTRAL + +# checkpointer for the value model - update `checkpoint_files` if resuming from checkpoint +# since this model will be identical to the reward model it's helpful to initialise this +# from the trained reward model weights +value_checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/RM-Mistral-7B/ + checkpoint_files: [ + "model-00001-of-00003.safetensors", + "model-00002-of-00003.safetensors", + "model-00003-of-00003.safetensors" + ] + output_dir: ${output_dir}/value + model_type: REWARD + +# checkpointer for the reward model, ensure `checkpoint_files` +# always points to the original reward model weights, even if resuming training +reward_checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/RM-Mistral-7B/ + checkpoint_files: [ + "model-00001-of-00003.safetensors", + "model-00002-of-00003.safetensors", + "model-00003-of-00003.safetensors" + ] + output_dir: ${output_dir}/value + model_type: REWARD + + +resume_from_checkpoint: False +output_dir: /tmp/mistral7b-ppo-finetune +seed: null +shuffle: True + +# Training env +device: cuda + +# Training arguments +batch_size: 64 +num_steps: 10000 +ppo_epochs: 2 +ppo_batch_size: 32 +gradient_accumulation_steps: 1 + +# Memory management and performance +compile: True +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 3e-6 +optimizer_in_bwd: True +log_peak_memory_stats: False +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + + +# batch size for forward pass during generation +forward_batch_size: 16 +max_generated_tokens: 58 +temperature: 0.7 +top_k: null + +# parameter for penalising generations shorter than `min_response_length` +min_response_length: 18 +# parameter for penalising generations without a stop token +penalise_no_eos: True +# scalar penalty to apply when penalising +reward_penalty: -3 + +# tokens to consider as "end of sequence" tokens +stop_token_ids: [ + 2, # eos_id + 28723 # mistral "." token +] +whiten_rewards: False + +# GAE hyperparameters +gamma: 1 +lmbda: 0.95 + +# PPO hyperparameters +loss: + _component_: torchtune.modules.loss.PPOLoss + epsilon: 0.2 + value_coeff: 0.1 + value_clip_range: 0.2 +kl_coeff: 0.01 + + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} + +log_every_n_steps: 1 diff --git a/training/recipes/configs/mistral/7B_lora.yaml b/training/recipes/configs/mistral/7B_lora.yaml new file mode 100644 index 0000000..a84b714 --- /dev/null +++ b/training/recipes/configs/mistral/7B_lora.yaml @@ -0,0 +1,90 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Mistral 7B model +# +# This config uses hyperparameters based on small set of experiments and information +# available on various forums. These are not meant to replicate the numbers +# from the paper +# +# This config assumes that you've run the following command before launching +# this run: +# tune download mistralai/Mistral-7B-v0.1 --hf-token --output-dir /tmp/Mistral-7B-v0.1 +# +# Run this config on 2 GPUs using the following: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config mistral/7B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config mistral/7B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-v0.1/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.mistral.lora_mistral_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-v0.1 + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Mistral-7B-v0.1 + model_type: MISTRAL +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Mistral-7B-v0.1 +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/mistral/7B_lora_single_device.yaml b/training/recipes/configs/mistral/7B_lora_single_device.yaml new file mode 100644 index 0000000..23524b7 --- /dev/null +++ b/training/recipes/configs/mistral/7B_lora_single_device.yaml @@ -0,0 +1,114 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Mistral 7B model +# +# This config uses hyperparameters based on small set of experiments and information +# available on various forums. These are not meant to replicate the numbers +# from the paper +# +# This config assumes that you've run the following command before launching +# this run: +# tune download mistralai/Mistral-7B-v0.1 --hf-token --output-dir /tmp/Mistral-7B-v0.1 +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config mistral/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config mistral/7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-v0.1/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.mistral.lora_mistral_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: True + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-v0.1 + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Mistral-7B-v0.1 + model_type: MISTRAL +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Mistral-7B-v0.1 +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/mistral/7B_qlora_single_device.yaml b/training/recipes/configs/mistral/7B_qlora_single_device.yaml new file mode 100644 index 0000000..f571ce6 --- /dev/null +++ b/training/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -0,0 +1,115 @@ +# Config for single device QLoRA finetuning in lora_finetune_single_device.py +# using a Mistral 7B model +# +# This config uses hyperparameters based on small set of experiments and information +# available on various forums. These are not meant to replicate the numbers +# from the paper +# +# This config assumes that you've run the following command before launching +# this run: +# tune download mistralai/Mistral-7B-v0.1 --hf-token --output-dir /tmp/Mistral-7B-v0.1 +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config mistral/7B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config mistral/7B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-v0.1/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.mistral.qlora_mistral_7b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 64 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-v0.1 + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Mistral-7B-v0.1 + model_type: MISTRAL +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Mistral-7B-v0.1 +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/phi3/mini_full.yaml b/training/recipes/configs/phi3/mini_full.yaml new file mode 100644 index 0000000..82775ad --- /dev/null +++ b/training/recipes/configs/phi3/mini_full.yaml @@ -0,0 +1,73 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Phi3 Mini 4K Instruct +# +# This config assumes that you've run the following command before launching +# this run: +# tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --ignore-patterns None --hf-token +# +# Run this config on 4 GPUs using the following: +# tune run --nproc_per_node 4 full_finetune_distributed --config phi3/mini_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 full_finetune_distributed --config phi3/mini_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use mini_low_memory.yaml for those cases + +# Model arguments +model: + _component_: torchtune.models.phi3.phi3_mini + +# Tokenizer +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + +# Checkpointer +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct + model_type: PHI3_MINI +resume_from_checkpoint: False + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +batch_size: 2 +gradient_accumulation_steps: 16 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False +dtype: bf16 + +# Logging +output_dir: /tmp/phi3_full_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: /tmp/Phi-3-mini-4k-instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/phi3/mini_full_low_memory.yaml b/training/recipes/configs/phi3/mini_full_low_memory.yaml new file mode 100644 index 0000000..73ea2c4 --- /dev/null +++ b/training/recipes/configs/phi3/mini_full_low_memory.yaml @@ -0,0 +1,76 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Phi3 Mini 4K Instruct +# +# This config assumes that you've run the following command before launching +# this run: +# tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --ignore-patterns None --hf-token +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config phi3/mini_full_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config phi3/mini_full_low_memory checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model arguments +model: + _component_: torchtune.models.phi3.phi3_mini + +# Tokenizer +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + +# Checkpointer +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct + model_type: PHI3_MINI +resume_from_checkpoint: False + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +batch_size: 2 +gradient_accumulation_steps: 1 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 5e-6 +optimizer_in_bwd: True +loss: + _component_: torch.nn.CrossEntropyLoss +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +output_dir: /tmp/phi3_lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: /tmp/Phi-3-mini-4k-instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/phi3/mini_lora.yaml b/training/recipes/configs/phi3/mini_lora.yaml new file mode 100644 index 0000000..83de413 --- /dev/null +++ b/training/recipes/configs/phi3/mini_lora.yaml @@ -0,0 +1,82 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Phi3 mini (3.8B) model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --ignore-patterns None --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use mini_lora_single_device.yaml +# or mini_qlora_single_device.yaml + +# Model arguments +model: + _component_: torchtune.models.phi3.lora_phi3_mini + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + +# Checkpointer +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct + model_type: PHI3_MINI +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +batch_size: 2 +gradient_accumulation_steps: 32 +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: False +dtype: bf16 + +# Logging +output_dir: /tmp/phi3_lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: /tmp/Phi-3-mini-4k-instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/phi3/mini_lora_single_device.yaml b/training/recipes/configs/phi3/mini_lora_single_device.yaml new file mode 100644 index 0000000..030361b --- /dev/null +++ b/training/recipes/configs/phi3/mini_lora_single_device.yaml @@ -0,0 +1,107 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Phi3 mini (3.8B) model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --ignore-patterns None --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config phi3/mini_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config phi3/mini_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model arguments +model: + _component_: torchtune.models.phi3.lora_phi3_mini + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + +# Checkpointer +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct + model_type: PHI3_MINI +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +batch_size: 2 +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +loss: + _component_: torch.nn.CrossEntropyLoss +compile: False + +# Training env +device: cuda + +# Memory management +dtype: bf16 +enable_activation_checkpointing: True + +# Logging +output_dir: /tmp/phi3_lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: /tmp/Phi-3-mini-4k-instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Showcase the usage of PyTorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: /tmp/Phi-3-mini-4k-instruct/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/phi3/mini_qlora_single_device.yaml b/training/recipes/configs/phi3/mini_qlora_single_device.yaml new file mode 100644 index 0000000..4de6576 --- /dev/null +++ b/training/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -0,0 +1,107 @@ +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Phi3 mini (3.8B) model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --ignore-patterns None --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config phi3/mini_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config phi3/mini_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model arguments +model: + _component_: torchtune.models.phi3.qlora_phi3_mini + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + +# Checkpointer +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct + model_type: PHI3_MINI +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Fine-tuning arguments +epochs: 1 +max_steps_per_epoch: null +batch_size: 2 +gradient_accumulation_steps: 16 +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +loss: + _component_: torch.nn.CrossEntropyLoss +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +output_dir: /tmp/phi3_qlora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: /tmp/Phi-3-mini-4k-instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Showcase the usage of PyTorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + # Output directory of trace artifacts + output_dir: /tmp/Phi-3-mini-4k-instruct/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/quantization.yaml b/training/recipes/configs/quantization.yaml new file mode 100644 index 0000000..10feac2 --- /dev/null +++ b/training/recipes/configs/quantization.yaml @@ -0,0 +1,28 @@ +# Config for QuantizationRecipe in quantize.py +# +# To launch, run the following command from root torchtune directory: +# tune run quantize --config quantization + +# +# Model arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin, + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 + +device: cuda +dtype: bf16 +seed: 1234 + +quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer + groupsize: 256 diff --git a/training/recipes/configs/qwen2/0.5B_full.yaml b/training/recipes/configs/qwen2/0.5B_full.yaml new file mode 100644 index 0000000..159d43e --- /dev/null +++ b/training/recipes/configs/qwen2/0.5B_full.yaml @@ -0,0 +1,75 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Qwen2 0.5B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2/0.5B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2/0.5B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 0.5B_full.yaml for those cases + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_0_5b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 16 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: False +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Qwen2-0.5B-Instruct-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/qwen2/0.5B_full_single_device.yaml b/training/recipes/configs/qwen2/0.5B_full_single_device.yaml new file mode 100644 index 0000000..5e2766c --- /dev/null +++ b/training/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -0,0 +1,75 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2 0.5B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_0_5b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +loss: + _component_: torch.nn.CrossEntropyLoss +optimizer_in_bwd: False + +max_steps_per_epoch: null +gradient_accumulation_steps: 8 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Qwen2-0.5B-Instruct-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/qwen2/0.5B_lora.yaml b/training/recipes/configs/qwen2/0.5B_lora.yaml new file mode 100644 index 0000000..60e37e1 --- /dev/null +++ b/training/recipes/configs/qwen2/0.5B_lora.yaml @@ -0,0 +1,110 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Qwen2 0.5B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/0.5B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/0.5B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 0.5B_lora_single_device.yaml +# or 0.5B_qlora_single_device.yaml + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_0_5b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 2e-3 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 + +# Logging +output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} + +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/training/recipes/configs/qwen2/0.5B_lora_single_device.yaml new file mode 100644 index 0000000..b8ff128 --- /dev/null +++ b/training/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -0,0 +1,107 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Qwen2 0.5B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config qwen2/0.5B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config qwen2/0.5B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_0_5b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune + model_type: QWEN2 + +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 2e-3 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Logging +output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/qwen2/1.5B_full.yaml b/training/recipes/configs/qwen2/1.5B_full.yaml new file mode 100644 index 0000000..deec3fb --- /dev/null +++ b/training/recipes/configs/qwen2/1.5B_full.yaml @@ -0,0 +1,75 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Qwen2 1.5B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2/1.5B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2/1.5B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 1.5B_full.yaml for those cases + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-1.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-1.5B-Instruct/merges.txt + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_1_5b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-1.5B-Instruct-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: False +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Qwen2-1.5B-Instruct-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/qwen2/1.5B_full_single_device.yaml b/training/recipes/configs/qwen2/1.5B_full_single_device.yaml new file mode 100644 index 0000000..489e670 --- /dev/null +++ b/training/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -0,0 +1,81 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2 1.5B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2/1.5B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2/1.5B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-1.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-1.5B-Instruct/merges.txt + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_1_5b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-1.5B-Instruct-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 2e-5 + +optimizer_in_bwd: True + +loss: + _component_: torch.nn.CrossEntropyLoss + +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Qwen2-1.5B-Instruct-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/qwen2/1.5B_lora.yaml b/training/recipes/configs/qwen2/1.5B_lora.yaml new file mode 100644 index 0000000..9655005 --- /dev/null +++ b/training/recipes/configs/qwen2/1.5B_lora.yaml @@ -0,0 +1,105 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Qwen2 0.5B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_1_5b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 + +# Logging +output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/training/recipes/configs/qwen2/1.5B_lora_single_device.yaml new file mode 100644 index 0000000..a663f25 --- /dev/null +++ b/training/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -0,0 +1,105 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Qwen2 1.5B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config qwen2/1.5B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config qwen2/1.5B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_1_5b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-1.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-1.5B-Instruct/merges.txt + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + lr: 2e-3 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 +compile: False + +# Logging +output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/qwen2/7B_full.yaml b/training/recipes/configs/qwen2/7B_full.yaml new file mode 100644 index 0000000..e4c67f1 --- /dev/null +++ b/training/recipes/configs/qwen2/7B_full.yaml @@ -0,0 +1,78 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Qwen2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct --ignore-patterns None +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2/7B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2/7B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 7B_full.yaml for those cases + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-7B-Instruct/vocab.json + merges_file: /tmp/Qwen2-7B-Instruct/merges.txt + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-7B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-7B-Instruct-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 16 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Qwen2-7B-Instruct-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/qwen2/7B_full_single_device.yaml b/training/recipes/configs/qwen2/7B_full_single_device.yaml new file mode 100644 index 0000000..88ea138 --- /dev/null +++ b/training/recipes/configs/qwen2/7B_full_single_device.yaml @@ -0,0 +1,80 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2 7B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct --ignore-patterns None +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2/7B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2/7B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-7B-Instruct/vocab.json + merges_file: /tmp/Qwen2-7B-Instruct/merges.txt + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-7B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-7B-Instruct-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 5e-6 +optimizer_in_bwd: True +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/Qwen2-7B-Instruct-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/training/recipes/configs/qwen2/7B_lora.yaml b/training/recipes/configs/qwen2/7B_lora.yaml new file mode 100644 index 0000000..8f27082 --- /dev/null +++ b/training/recipes/configs/qwen2/7B_lora.yaml @@ -0,0 +1,111 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Qwen2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct --ignore-patterns None +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/7B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/7B_lora checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-7B-Instruct/vocab.json + merges_file: /tmp/Qwen2-7B-Instruct/merges.txt + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-7B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/configs/qwen2/7B_lora_single_device.yaml b/training/recipes/configs/qwen2/7B_lora_single_device.yaml new file mode 100644 index 0000000..0357273 --- /dev/null +++ b/training/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -0,0 +1,109 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Qwen2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct --ignore-patterns None +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config qwen2/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config qwen2/7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-7B-Instruct/vocab.json + merges_file: /tmp/Qwen2-7B-Instruct/merges.txt + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-7B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune + model_type: QWEN2 +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +# Logging +output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/training/recipes/dev/fsdp2_recipes.md b/training/recipes/dev/fsdp2_recipes.md new file mode 100644 index 0000000..0df1571 --- /dev/null +++ b/training/recipes/dev/fsdp2_recipes.md @@ -0,0 +1,12 @@ +## FSDP2 Recipes + +This directory contains distributed training recipes for LoRA and QLoRA using [FSDP2](https://github.com/pytorch/pytorch/issues/114299). +Currently FSDP2 is only available in PyTorch nightly releases. + +To set up your environment to run these recipes, you should first install torchtune dependencies, +then install PyTorch nightlies. E.g. + +``` +pip install torchtune +pip3 install --upgrade --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 +``` diff --git a/training/recipes/dev/lora_finetune_fsdp2.py b/training/recipes/dev/lora_finetune_fsdp2.py new file mode 100644 index 0000000..73b4a46 --- /dev/null +++ b/training/recipes/dev/lora_finetune_fsdp2.py @@ -0,0 +1,671 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, +) + +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft import LoRALinear +from torchtune.modules.peft.peft_utils import ( + get_adapter_params, + get_lora_module_names, + get_merged_lora_ckpt, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRAFinetuneRecipeDistributed(FTRecipeInterface): + """ + Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Traning on CPU is not + supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + ValueError: If world_size is 1 + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + """ + + def __init__(self, cfg: DictConfig) -> None: + + if not utils.torch_version_ge("2.4.0"): + raise RuntimeError("FSDP2 recipe is only available on PyTorch nightlies") + + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + _, rank = utils.get_world_size_and_rank() + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # training attributes + self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + + # These attributes constitute the recipe state and are updated by ``load_checkpoint`` + # when ``resume_from_checkpoint`` is ``True`` + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + # When resuming from checkpoint for LoRA, the recipe expects the adapter weights + # and recipe state to be present. The keys should match up with what ``save_checkpoint`` + # used to create these intermediate checkpoints + if self._resume_from_checkpoint: + if utils.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + if not ( + utils.SEED_KEY in ckpt_dict + and utils.TOTAL_EPOCHS_KEY in ckpt_dict + and utils.MAX_STEPS_KEY in ckpt_dict + ): + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state." + "Are you sure you passed in the right recipe checkpoint?" + ) + # If seed, total_epoch or max_steps_per_epoch don't match, + # warn the user and overwrite + if ( + self.seed != ckpt_dict[utils.SEED_KEY] + or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY] + ): + warn( + message="""Configured value for seed, epochs or max_steps_per_epoch + does not match the value stored in checkpoint.""" + ) + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[utils.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + cfg_fsdp=cfg.fsdp if hasattr(cfg, "fsdp") else None, + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=checkpoint_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None, + ) + + self._loss_fn = config.instantiate(cfg.loss) + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + cfg_fsdp: Optional[Union[DictConfig, None]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` + """ + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + fsdp_kwargs = {} + if cfg_fsdp and cfg_fsdp.cpu_offload: + from torch.distributed._composable.fsdp import CPUOffloadPolicy + + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + # iterating from lowerer modules to higher + # eg grouping lora adapters before transformer block + for m in reversed(list(model.modules())): + if isinstance(m, nn.Linear) and m.weight.requires_grad: + fully_shard(m, **fsdp_kwargs) + # TransformerDecoderLayer is wrapped by CheckpointWrapper + # when enable_activation_checkpointing + if enable_activation_checkpointing: + if isinstance(m, CheckpointWrapper): + fully_shard(m, **fsdp_kwargs) + else: + if isinstance(m, modules.TransformerDecoderLayer): + fully_shard(m, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) + + if lora_weights_state_dict: + lora_missing, lora_unexpected = utils.load_from_full_model_state_dict( + model, lora_weights_state_dict, self._device, self._is_rank_zero + ) + else: + lora_missing, lora_unexpected = None, None + + with utils.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if cfg_fsdp and cfg_fsdp.cpu_offload else self._device + for m in model.modules(): + if isinstance(m, LoRALinear) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.lora_a.to_empty(device=lora_device) + m.lora_b.to_empty(device=lora_device) + m.initialize_parameters() + # RoPE is not covered in state dict + if isinstance(m, modules.RotaryPositionalEmbeddings): + m.reset_parameters() + + base_missing, base_unexpected = utils.load_from_full_model_state_dict( + model, base_model_state_dict, self._device, self._is_rank_zero + ) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Ensure no params and buffers are on meta device + utils.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + utils.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = utils.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None, + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the merged weights, adapter weights and recipe state in + different checkpoint files. To correctly resume from training, the adapter weights + and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = utils.get_full_model_state_dict( + self._model, + self._is_rank_zero, + ) + + if intermediate_checkpoint: + opt_state_dict = utils.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) + } + checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + checkpoint_dict.update({utils.ADAPTER_CONFIG: adapter_config}) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + utils.cleanup_before_training() + + _, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not utils.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) + + recipe = LoRAFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/eleuther_eval.py b/training/recipes/eleuther_eval.py new file mode 100644 index 0000000..105e624 --- /dev/null +++ b/training/recipes/eleuther_eval.py @@ -0,0 +1,283 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from typing import Any, Dict, List, Tuple, Union + +import torch +from omegaconf import DictConfig + +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from torchtune import config, utils +from torchtune.modules import TransformerDecoder +from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.recipe_interfaces import EvalRecipeInterface + + +logger = utils.get_logger("DEBUG") + +try: + import lm_eval + from lm_eval.evaluator import evaluate + from lm_eval.models.huggingface import HFLM + from lm_eval.tasks import get_task_dict, TaskManager + from lm_eval.utils import make_table +except ImportError: + logger.error( + "Recipe requires EleutherAI Eval Harness v0.4. Please install with `pip install lm_eval==0.4.*`" + ) + sys.exit(1) + + +class _EvalWrapper(HFLM): + """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's + EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py. + + Args: + model (TransformerDecoder): The model to evaluate. + tokenizer (ModelTokenizer): Tokenizer associated with the model being evaluated. + This should be the same tokenizer used when fine-tuning the model. + device (torch.device): The device to use. + max_seq_length (int): The maximum sequence length to use. + batch_size (int): The batch size per GPU to use. + dtype (torch.dtype): dtype for the model caches during generation. + """ + + def __init__( + self, + model: TransformerDecoder, + tokenizer: ModelTokenizer, + *, + device: torch.device, + max_seq_length: int = 4096, + batch_size: int = 8, + dtype: torch.dtype = torch.float32, + ): + super().__init__(pretrained="gpt2", device=str(device)) + self._model = model + self._tokenizer = tokenizer + self._max_seq_length = max_seq_length + self._batch_size = batch_size + self._dtype = dtype + + @property + def model(self): + return self._model + + @property + def eot_token_id(self): + return self._tokenizer.eos_id + + @property + def max_length(self): + return self._max_seq_length + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self._batch_size + + @property + def device(self): + return self._device + + def tok_encode(self, text: str, **kwargs) -> List[int]: + # Note on add_bos flag: setting to False as this gives better results, for example + # +1% on truthfulqa_mc2 with a LoRA finetune. lit-gpt also sets this to False, + # see https://github.com/Lightning-AI/lit-gpt/blob/main/eval/lm_eval_harness.py#L66, + # though notably fast-gpt does the opposite + # https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py#L123. + return self._tokenizer.encode(text=text, add_bos=False, add_eos=False) + + def tok_batch_encode( + self, text: List[str], **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + tokenized_text = [self.tok_encode(x) for x in text] + + # pad left + x = pad_sequence( + [ + torch.tensor(x[::-1]) for x in tokenized_text + ], # first flip each sequence and pad + batch_first=True, + padding_value=self._tokenizer.pad_id, + ).flip( + dims=[1] + ) # flip back to correct order + + return x, torch.ones_like(x) # return 'mask' b/c it's expected by the harness + + def tok_decode(self, tokens: Union[List[int], int], **kwargs) -> str: + if isinstance(tokens, int): + tokens = [tokens] + return self._tokenizer.decode(tokens) + + def _model_call(self, inps: torch.Tensor, **kwargs) -> torch.Tensor: + return self._model(inps) + + def _model_generate( + self, context: torch.Tensor, **generation_kwargs + ) -> torch.Tensor: + curr_batch_size = context.size(0) + + if curr_batch_size > 1: + raise ValueError( + f"Got a batch size of '{curr_batch_size}'. Batch size > 1 is not supported for " + "generation. See https://github.com/pytorch/torchtune/issues/1250 for more info." + ) + + # Setup caches for a given batch size + # Technically this is not necessary, but it's a good way to ensure that + # the caches won't error on a different batch size. In addition, caches + # are not needed for a regular model call, so we just setup here + with context.device: + self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype) + + temperature = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", False) + if do_sample: + # do_sample signifies more complicated sampling logic, like top_k or + # top_p. We don't support this yet, so if it's requested, we raise an error. + raise RuntimeError( + "``do_sample`` for generation tasks is not supported yet in torchtune." + ) + + toks = utils.generate( + self._model, + context, + max_generated_tokens=self.max_gen_toks, + temperature=temperature, + top_k=None, # do_sample is not supported currently + stop_tokens=self._tokenizer.stop_tokens, + ) + return torch.tensor(toks, dtype=torch.int32) + + +class EleutherEvalRecipe(EvalRecipeInterface): + """ + This recipe runs evaluation on a trained model using EleutherAI's eval harness. + This assumes the user has the EleutherAI eval harness installed. See + https://github.com/EleutherAI/lm-evaluation-harness for more details. + + Features: + - Single GPU evaluation. Multi-GPU evaluation is currently not supported. + - Loading model in fp32 or bf16. Fp16 is currently not supported. + - Any task from the EleutherAI eval harness that is *not* free generation + + We recommend launching evaluation using the tune CLI: + + tune run eleuther_eval --config llama2_eleuther_eval \ + tasks=["truthfulqa_mc2","hellaswag"] + + Args: + cfg (DictConfig): OmegaConf object parsed from YAML file + """ + + def __init__(self, cfg: DictConfig) -> None: + self._cfg = cfg + + def setup(self) -> None: + self._device = utils.get_device(device=self._cfg.device) + self._dtype = utils.get_dtype(dtype=self._cfg.dtype, device=self._device) + self._limit = self._cfg.limit + self._tasks = list(self._cfg.tasks) + self._quantizer = config.instantiate(self._cfg.quantizer) + self._quantization_mode = utils.get_quantizer_mode(self._quantizer) + + utils.set_seed(seed=self._cfg.seed) + + checkpointer = config.instantiate(self._cfg.checkpointer) + if self._quantization_mode is None: + ckpt_dict = checkpointer.load_checkpoint() + else: + # weights_only needs to be False when loading a quantized model + # currently loading a quantized model is only supported with the + # FullModelTorchTuneCheckpointer + ckpt_dict = checkpointer.load_checkpoint(weights_only=False) + + self._model = self._setup_model( + model_cfg=self._cfg.model, + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ) + self._tokenizer = config.instantiate(self._cfg.tokenizer) + logger.info("Tokenizer is initialized from file.") + + def _setup_model( + self, + model_cfg: DictConfig, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(model_cfg) + if self._quantization_mode is not None: + model = self._quantizer.quantize(model) + model = model.to(device=self._device, dtype=self._dtype) + + model.load_state_dict(model_state_dict) + + # Put model in eval mode. + # Note: This will not disable the dropout applied in SDPA, + # see https://github.com/pytorch/pytorch/issues/124464 + model.eval() + + # Validate model was loaded in with the expected dtype. + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + logger.info(f"Model is initialized with precision {self._dtype}.") + return model + + @torch.no_grad() + def evaluate(self) -> None: + t1 = time.time() + + model_eval_wrapper = _EvalWrapper( + self._model, + self._tokenizer, + device=self._device, + max_seq_length=self._cfg.max_seq_length, + batch_size=self._cfg.batch_size, + dtype=self._dtype, + ) + + # Task initialization API changed between v0.4.1 and 0.4.2 + try: + lm_eval.tasks.initialize_tasks() + except Exception: + pass + + task_manager = TaskManager(include_path=self._cfg.get("include_path", None)) + task_dict = get_task_dict(self._tasks, task_manager) + + logger.info(f"Running evaluation on {self._tasks} tasks.") + output = evaluate( + model_eval_wrapper, + task_dict, + limit=self._limit, + ) + + logger.info(f"Eval completed in {time.time() - t1:.02f} seconds.") + + formatted_output = make_table(output) + print(formatted_output) + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """Entry point for the recipe.""" + config.log_config(recipe_name="EleutherEvalRecipe", cfg=cfg) + recipe = EleutherEvalRecipe(cfg=cfg) + recipe.setup() + recipe.evaluate() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/full_finetune_distributed.py b/training/recipes/full_finetune_distributed.py new file mode 100644 index 0000000..29456fe --- /dev/null +++ b/training/recipes/full_finetune_distributed.py @@ -0,0 +1,683 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import init_process_group +from torch.distributed.fsdp import ( + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.utils import DummyProfiler, PROFILER_KEY +from torchtune.utils.activations import apply_selective_activation_checkpointing + +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class FullFinetuneRecipeDistributed(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU + is not supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + if ( + cfg.get("fsdp_cpu_offload", False) + and cfg.optimizer.get("fused", False) + and not utils.torch_version_ge("2.4.0") + ): + raise RuntimeError( + "Using fused optimizer on CPU is only supported in PyTorch nightly." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + _, rank = utils.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=ckpt_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None, + ) + + self._loss_fn = config.instantiate(cfg.loss) + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + memory_efficient_fsdp_wrap: bool, + fsdp_cpu_offload: bool, + model_state_dict: Dict[str, Any], + ac_mode: Optional[str] = None, + ac_option: Optional[int] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. + """ + if self._is_rank_zero: + log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + + log.info( + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + ) + + # Load both the model weights. This should happen only on Rank 0 + model.load_state_dict(model_state_dict) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. ``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + ac_mode = ac_mode + ac_option = ac_option + + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # Wrap the model with FSDP. This will ensure that the model is sharded + # across all available GPUs. + model = FSDP( + module=model, + auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, + modules_to_wrap={modules.TransformerDecoderLayer}, + ), + cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: module.to_empty( + device=torch.device("cuda"), recurse=False + ) + if not self._is_rank_zero + else None + ), + ) + + # Ensure no params and buffers are on meta device + utils.validate_no_params_on_meta_device(model) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + """ + Set up the optimizer. This method also handles transforing the state dict + for FSDP. + """ + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + + if opt_state_dict: + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict + ) + optimizer.load_state_dict(opt_state_dict) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = utils.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None, + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + checkpoint_dict = {} + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + with FSDP.state_dict_type( + self._model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + checkpoint_dict.update( + { + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def train(self) -> None: + """ + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. + """ + # clean up before training begins + utils.cleanup_before_training() + + _, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + torch.distributed.destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not utils.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + utils.set_torch_num_threads() + + config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg) + + recipe = FullFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/full_finetune_single_device.py b/training/recipes/full_finetune_single_device.py new file mode 100644 index 0000000..b77dbb3 --- /dev/null +++ b/training/recipes/full_finetune_single_device.py @@ -0,0 +1,600 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import time +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.utils import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class FullFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Optimizer in Backward. Fusing the optimizer step into the backward pass helps reduce the memory + footprint associated with gradients. This can be especially helpful when you are memory + constrained. Note that users can only use ONE of gradient accumulation or optimizer in backward. + These features currently do not work together. For more details on optimizer in backward, please + see this tutorial: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with + 8-bit AdamW and Paged AdamW. These optimizers are especially helpful when you are memory constrained + since they help reduce the memory footprint associated with the optimizer states. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + RuntimeError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._dtype == torch.float16: + raise RuntimeError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._optimizer_in_bwd = cfg.optimizer_in_bwd + + # TODO: find a better place / way to perform validation of args that don't yet + # compose with each other. + if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model_compile = cfg.compile + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=self._model_compile, + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=cfg.optimizer_in_bwd, + opt_state_dict=( + ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + """ + Set up the model including enabling activation checkpointing. + """ + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + model.load_state_dict(model_state_dict) + + # Validate model was loaded in with the expected dtype. + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + log.info(f"Model is initialized with precision {self._dtype}.") + + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + model.compile(backend=backend) + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + """ + Set up the optimizer. This method also handles loading the optimizer state_dict, if specified. + """ + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + p: config.instantiate(cfg_optimizer, [p]) + for p in self._model.parameters() + } + # Register optimizer step hooks on the model to run optimizer in backward. + utils.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states. If optimizer states are being restored in an optimizer in backward + # run, these need to have been saved with the same setting. Cannot restore from runs that did not + # use optimizer in backward. + if opt_state_dict is not None: + try: + self._optim_ckpt_wrapper.load_state_dict(opt_state_dict) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." + ) from e + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None, + ) + + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()} + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + ckpt_dict.update( + { + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + if not self._optimizer_in_bwd: + ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() + else: + ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def train(self) -> None: + """ + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. + """ + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + # NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently + # true since we don't expose the ability to configure this yet. + "lr": ( + self._optim_ckpt_wrapper.get_optim_key("lr") + if self._optimizer_in_bwd + else self._optimizer.param_groups[0]["lr"] + ), + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="FullFinetuneRecipeSingleDevice", cfg=cfg) + recipe = FullFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/generate.py b/training/recipes/generate.py new file mode 100644 index 0000000..8f094da --- /dev/null +++ b/training/recipes/generate.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from typing import Any, Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig +from torch import nn + +from torchtune import config, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.data import ChatFormat, InstructTemplate, Message + +logger = utils.get_logger("DEBUG") + + +class InferenceRecipe: + """ + Recipe for generating tokens from a dense Transformer-based LLM. + + Currently this recipe supports single-GPU generation only. Speculative + decoding is not supported. + + For more details on how to use this recipe for generation, please see our + tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation + + For using this recipe with a quantized model, please the following section of + the above tutorial: + https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(dtype=cfg.dtype, device=self._device) + self._quantizer = config.instantiate(cfg.quantizer) + self._quantization_mode = utils.get_quantizer_mode(self._quantizer) + + utils.set_seed(seed=cfg.seed) + + def setup(self, cfg: DictConfig) -> None: + checkpointer = config.instantiate(cfg.checkpointer) + if self._quantization_mode is None: + ckpt_dict = checkpointer.load_checkpoint() + else: + # weights_only needs to be False when loading a quantized model + # currently loading a quantized model is only supported with the + # FullModelTorchTuneCheckpointer + ckpt_dict = checkpointer.load_checkpoint(weights_only=False) + + self._model = self._setup_model( + model_cfg=cfg.model, + model_state_dict=ckpt_dict[utils.MODEL_KEY], + enable_kv_cache=cfg.enable_kv_cache, + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + def _setup_model( + self, + model_cfg: DictConfig, + model_state_dict: Dict[str, Any], + enable_kv_cache: bool = True, + ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(model_cfg) + + if self._quantization_mode is not None: + model = self._quantizer.quantize(model) + model = model.to(device=self._device, dtype=self._dtype) + + model.load_state_dict(model_state_dict) + + # Validate model was loaded in with the expected dtype. + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + logger.info(f"Model is initialized with precision {self._dtype}.") + + # Ensure the cache is setup on the right device + if enable_kv_cache: + with self._device: + model.setup_caches(batch_size=1, dtype=self._dtype) + + return model + + def convert_prompt_to_tokens( + self, + prompt: Union[DictConfig, str], + chat_format: Optional[ChatFormat], + instruct_template: Optional[InstructTemplate], + ) -> List[Message]: + """ + Either: + (1) a raw string is passed as the prompt, in which case we call tokenizer.encode directly, or + (2) a DictConfig is passed as the prompt. In this case there are three possibilities: + (a) an InstructTemplate is provided. Since instruct templates output a string, we will + call tokenizer.encode on the output of the instruct template. + (b) a ChatFormat is provided. Since chat formats output a list of messages, we will + call tokenizer.tokenize_messages on the output of the chat format. + (c) neither an InstructTemplate nor a ChatFormat is provided. In this case we will + convert the DictConfig to a list of messages and call tokenizer.tokenize_messages directly. + """ + + # Should only be chat-style prompt or instruct-style prompt + if chat_format and instruct_template: + raise ValueError( + "Cannot pass both chat format and instruct template for generation" + ) + + # If instruct template is provided, assert that the prompt is a DictConfig + # and apply it + if instruct_template: + if not isinstance(prompt, DictConfig): + raise ValueError("Cannot apply instruct template to raw string") + instruct_template = _get_component_from_path(instruct_template) + prompt = instruct_template.format(prompt) + + # To hit this block, either the raw prompt is a string or an + # instruct template has been provided to convert it to a string + if isinstance(prompt, str): + return self._tokenizer.encode(prompt, add_bos=True, add_eos=False) + + # dict.items() will respect order for Python >= 3.7 + else: + messages = [Message(role=k, content=v) for k, v in prompt.items()] + messages += [Message(role="assistant", content="")] + if chat_format: + chat_format = _get_component_from_path(chat_format) + messages = chat_format.format(messages) + return self._tokenizer.tokenize_messages(messages)[0] + + @torch.no_grad() + def generate(self, cfg: DictConfig): + tokens = self.convert_prompt_to_tokens( + cfg.prompt, cfg.get("chat_format", None), cfg.get("instruct_template", None) + ) + prompt = torch.tensor(tokens, dtype=torch.int, device=self._device) + + custom_generate_next_token = None + + # since quantized model uses torch.compile to get speedup, it needs a warm up / prefill run + # to get the accurate performance measurement + if self._quantization_mode is not None: + logger.info("Starting compilation to improve generation performance ...") + custom_generate_next_token = torch.compile( + utils.generate_next_token, mode="max-autotune", fullgraph=True + ) + t0 = time.perf_counter() + _ = utils.generate( + model=self._model, + prompt=prompt, + max_generated_tokens=2, + temperature=cfg.temperature, + top_k=cfg.top_k, + stop_tokens=self._tokenizer.stop_tokens, + custom_generate_next_token=custom_generate_next_token, + ) + t = time.perf_counter() - t0 + logger.info(f"Warmup run for quantized model takes: {t:.02f} sec") + + t0 = time.perf_counter() + generated_tokens = utils.generate( + model=self._model, + prompt=prompt, + max_generated_tokens=cfg.max_new_tokens, + temperature=cfg.temperature, + top_k=cfg.top_k, + stop_tokens=self._tokenizer.stop_tokens, + custom_generate_next_token=custom_generate_next_token, + ) + t = time.perf_counter() - t0 + + logger.info(self._tokenizer.decode(generated_tokens[0])) + + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain( + self._model.parameters(), self._model.buffers() + ) + ] + ) + + tokens_generated = len(generated_tokens[0]) - prompt.size(0) + tokens_sec = tokens_generated / t + logger.info( + f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) + logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +@config.parse +def main(cfg: DictConfig) -> None: + config.log_config(recipe_name="InferenceRecipe", cfg=cfg) + recipe = InferenceRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.generate(cfg=cfg) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/training/recipes/lora_dpo_distributed.py b/training/recipes/lora_dpo_distributed.py new file mode 100644 index 0000000..6982aae --- /dev/null +++ b/training/recipes/lora_dpo_distributed.py @@ -0,0 +1,732 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf +from torchtune.modules.loss import SimPOLoss +from torchtune.modules.peft.peft_utils import ( + disable_adapter, + get_adapter_params, + get_merged_lora_ckpt, + set_trainable_params, + validate_state_dict_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRADPORecipeDistributed(FTRecipeInterface): + """ + Distributed LoRA DPO recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). This is based on HF's DPOTrainer + in the TRL library: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L65 + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Traning on CPU is not + supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + The following losses are supported in this recipe: + - :class:`~torchtune.modules.loss.DPOLoss`: Direct Preference Optimization (DPO). + - :class:`~torchtune.modules.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). + - :class:`~torchtune.modules.loss.IPO`: Identity Preference Optimization (IPO). + - :class:`~torchtune.modules.loss.SimPOLoss`: Simple Preference Optimization (SimPO). + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + ValueError: If world_size is 1 + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + _, rank = utils.get_world_size_and_rank() + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # training attributes + self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + + # These attributes constitute the recipe state and are updated by ``load_checkpoint`` + # when ``resume_from_checkpoint`` is ``True`` + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + # When resuming from checkpoint for LoRA, the recipe expects the adapter weights + # and recipe state to be present. The keys should match up with what ``save_checkpoint`` + # used to create these intermediate checkpoints + if self._resume_from_checkpoint: + if utils.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + log.info("_metric_logger is initialized.") + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[utils.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=checkpoint_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None, + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss function is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. + """ + + if self._is_rank_zero: + log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + + log.info( + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + ) + + # The model contains LoRA params which won't have any matching keys in + # the state dict. As a result, we need to load with strict=False. + # Before loading the state dict, ensure the state dict keys for the base + # model and adapters (if available) match the keys in the full LoRA model + # This is a good sanity check to prevent silent errors + validate_state_dict_for_lora( + lora_attn_modules=cfg_model.lora_attn_modules, + apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, + apply_lora_to_output=cfg_model.apply_lora_to_output, + full_model_state_dict_keys=model.state_dict().keys(), + lora_state_dict_keys=( + lora_weights_state_dict.keys() + if lora_weights_state_dict is not None + else None + ), + base_model_state_dict_keys=base_model_state_dict.keys(), + ) + + # Load both the base model weights and (if available) the adapter weights. Both + # of this should happen only on Rank 0 + model.load_state_dict(base_model_state_dict, strict=False) + if lora_weights_state_dict: + model.load_state_dict(lora_weights_state_dict, strict=False) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + # LoRA hyper-params needed for merging weights while saving checkpoints + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + + # Note: this needs to be set before wrapping with FSDP + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + model = FSDP( + module=model, + auto_wrap_policy=utils.lora_fsdp_wrap_policy( + modules_to_wrap={modules.TransformerDecoderLayer} + ), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: ( + module.to_empty(device=torch.device("cuda"), recurse=False) + if not self._is_rank_zero + else None + ) + ), + ) + + # Ensure no params and buffers are on meta device + utils.validate_no_params_on_meta_device(model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + # Note: technically we should check _contains_fsdp for + # just the state dict of the adapter cfg, but should be equivalent + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict + ) + optimizer.load_state_dict(opt_state_dict) + + if self._is_rank_zero: + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = utils.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + rlhf.padded_collate_dpo, + padding_idx=self._tokenizer.pad_id, + ignore_idx=CROSS_ENTROPY_IGNORE_IDX, + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + with FSDP.state_dict_type( + self._model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + if intermediate_checkpoint: + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) + } + checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def concatenated_forward( + self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run forward pass of the model with chosen and rejected samples concatenated. + + Args: + model (nn.Module): The model to be used for the forward pass. + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. + + Returns: + Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + """ + concatenated_input_ids, concatenated_labels = batch + concatenated_input_ids = concatenated_input_ids.to(self._device) + concatenated_labels = concatenated_labels.to(self._device) + + # formed by concatenating an equal number of "chosen" and "rejected". + len_chosen = concatenated_input_ids.shape[0] // 2 + + all_logits = model(concatenated_input_ids) + + all_log_probs = rlhf.get_batch_log_probs( + all_logits, + concatenated_labels, + # see :class:`~torchtune.modules.loss.dpo.SimPOLoss` + return_average_logprobs=isinstance(self._loss_fn, SimPOLoss), + ) + + chosen_log_probs = all_log_probs[:len_chosen] + rejected_log_probs = all_log_probs[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + utils.cleanup_before_training() + + _, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # batch is input_ids, labels + num_tokens += batch[0].numel() + + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(self._model, batch) + + if isinstance(self._loss_fn, SimPOLoss): + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, policy_rejected_log_probs + ) + else: + # reference based losses (e.g. DPO) explicitly regularize the objective fn based on + # the reference model's output - reference-free losses (such as SimPO) don't require this. + with torch.no_grad(), disable_adapter(self._model): + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = self.concatenated_forward(self._model, batch) + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, + policy_rejected_log_probs, + reference_chosen_log_probs, + reference_rejected_log_probs, + ) + + loss = loss.mean() + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + "rewards/chosen": chosen_rewards.mean().cpu(), + "rewards/rejected": rejected_rewards.mean().cpu(), + "rewards/accuracies": reward_accuracies.mean().cpu(), + "rewards/margins": (chosen_rewards - rejected_rewards) + .mean() + .cpu(), + "log_probs/rejected": policy_rejected_log_probs.detach() + .mean() + .cpu(), + "log_probs/chosen": policy_chosen_log_probs.detach() + .mean() + .cpu(), + "logits/rejected": policy_rejected_logits_mean.cpu(), + "logits/chosen": policy_chosen_logits_mean.cpu(), + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not utils.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="LoRADPORecipeDistributed", cfg=cfg) + + recipe = LoRADPORecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/lora_dpo_single_device.py b/training/recipes/lora_dpo_single_device.py new file mode 100644 index 0000000..c5f5bcf --- /dev/null +++ b/training/recipes/lora_dpo_single_device.py @@ -0,0 +1,616 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf + +from torchtune.modules.loss import SimPOLoss +from torchtune.modules.peft.peft_utils import ( + disable_adapter, + get_adapter_params, + get_merged_lora_ckpt, + set_trainable_params, + validate_missing_and_unexpected_for_lora, + validate_state_dict_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRADPORecipeSingleDevice(FTRecipeInterface): + """ + LoRA DPO recipe for dense transformer-based LLMs such as Llama2 for + single device training. This is based on HF's DPOTrainer in the + TRL library: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L65 + + This recipe supports: + - Activation checkpointing. This is enabled by default but is configurable. + - Full bf16 training for supported HW architectures. We currently check bf16 support via + the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via + setting `dtype=bf16` in configuration. + - Checkpointing: of LoRA adapter parameters and their optimizer states. When resuming + from a checkpoint, the adapter parameters are loaded from the checkpoint along + with the base model weights. Note that intra-epoch resumption is not supported. + - Logging to terminal, WandB, or TensorBoard. + + + The following losses are supported in this recipe: + - :class:`~torchtune.modules.loss.DPOLoss`: Direct Preference Optimization (DPO). + - :class:`~torchtune.modules.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). + - :class:`~torchtune.modules.loss.IPO`: Identity Preference Optimization (IPO). + - :class:`~torchtune.modules.loss.SimPOLoss`: Simple Preference Optimization (SimPO). + + Assumptions: + - Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done + in ongoing epoch is lost. + - Datasets are Map-style and data fits in memory (not streamed). + + The following configs can be used to run this recipe: + >>> tune ls + RECIPE CONFIG + lora_dpo_single_device llama2/7B_lora_dpo_single_device + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + # Reduced precision logic + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + # fp16 precision is explicitly disabled as it is not supported in this + # recipe (for example, no gradient scaling). + if self._dtype == torch.float16: + raise ValueError( + "fp16 precision is not supported in this recipe. Please use fp32 or bf16." + ) + # For CUDA devices, check if the HW supports bf16 if bf16 is specified. + if ( + self._dtype == torch.bfloat16 + and self._device != torch.device("cpu") + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError("Full bf16 training is not supported on this hardware.") + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + if utils.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not correctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + self._model_compile = cfg.compile + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=cfg.compile, + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[utils.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss function is initialized.") + + # Dataloader depends on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + validate_state_dict_for_lora( + lora_attn_modules=cfg_model.lora_attn_modules, + apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, + apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False), + full_model_state_dict_keys=model.state_dict().keys(), + lora_state_dict_keys=( + lora_weights_state_dict.keys() + if lora_weights_state_dict is not None + else None + ), + base_model_state_dict_keys=base_model_state_dict.keys(), + ) + + base_missing, base_unexpected = model.load_state_dict( + base_model_state_dict, strict=False + ) + if lora_weights_state_dict: + lora_missing, lora_unexpected = model.load_state_dict( + lora_weights_state_dict, strict=False + ) + else: + lora_missing, lora_unexpected = None, None + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + + log.info(f"Model is initialized with precision {self._dtype}.") + + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + model.compile(backend=backend) + if self._device == torch.device("cuda"): + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports + Map-style Datasets which fit into memory and an option for random shuffling. + Samplers, iterable datasets, and streaming datasets are not supported. + """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=partial( + rlhf.padded_collate_dpo, + padding_idx=self._tokenizer.pad_id, + ignore_idx=CROSS_ENTROPY_IGNORE_IDX, + ), + ) + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # if training is in-progress, checkpoint the optimizer state as well + if intermediate_checkpoint: + ckpt_dict.update( + { + utils.OPT_KEY: self._optimizer.state_dict(), + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + # Move to CPU to avoid a copy on GPU + state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + + # Construct the full state dict with LoRA weights merged into base LLM weights + merged_state_dict = get_merged_lora_ckpt( + state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + ckpt_dict.update({utils.MODEL_KEY: merged_state_dict}) + + # Construct the adapter weights + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) + } + ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def concatenated_forward( + self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run forward pass of the model with chosen and rejected samples concatenated. + + Args: + model (nn.Module): The model to be used for the forward pass. + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. + + Returns: + Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + """ + concatenated_input_ids, concatenated_labels = batch + concatenated_input_ids = concatenated_input_ids.to(self._device) + concatenated_labels = concatenated_labels.to(self._device) + + # formed by concatenating an equal number of "chosen" and "rejected". + len_chosen = concatenated_input_ids.shape[0] // 2 + + all_logits = model(concatenated_input_ids) + + all_log_probs = rlhf.get_batch_log_probs( + all_logits, + concatenated_labels, + # see :class:`~torchtune.modules.loss.dpo.SimPOLoss` + return_average_logprobs=isinstance(self._loss_fn, SimPOLoss), + ) + + chosen_log_probs = all_log_probs[:len_chosen] + rejected_log_probs = all_log_probs[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + + def train(self) -> None: + """ + The core training loop. + """ + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # batch is input_ids, labels + num_tokens += batch[0].numel() + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(self._model, batch) + + policy_chosen_logits_mean = policy_chosen_logits.detach().mean() + policy_rejected_logits_mean = policy_rejected_logits.detach().mean() + + # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging + del policy_chosen_logits, policy_rejected_logits + + if isinstance(self._loss_fn, SimPOLoss): + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, policy_rejected_log_probs + ) + else: + # reference based losses (e.g. DPO) explicitly regularize the objective fn based on + # the reference model's output - reference-free losses (such as SimPO) don't require this. + with torch.no_grad(), disable_adapter(self._model): + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = self.concatenated_forward(self._model, batch) + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, + policy_rejected_log_probs, + reference_chosen_log_probs, + reference_rejected_log_probs, + ) + + loss = loss.mean() + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + "rewards/chosen": chosen_rewards.mean().cpu(), + "rewards/rejected": rejected_rewards.mean().cpu(), + "rewards/accuracies": reward_accuracies.mean().cpu(), + "rewards/margins": (chosen_rewards - rejected_rewards) + .mean() + .cpu(), + "log_probs/rejected": policy_rejected_log_probs.detach() + .mean() + .cpu(), + "log_probs/chosen": policy_chosen_log_probs.detach() + .mean() + .cpu(), + "logits/rejected": policy_rejected_logits_mean.cpu(), + "logits/chosen": policy_chosen_logits_mean.cpu(), + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="LoRADPORecipeSingleDevice", cfg=cfg) + recipe = LoRADPORecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/lora_finetune_distributed.py b/training/recipes/lora_finetune_distributed.py new file mode 100644 index 0000000..c40f041 --- /dev/null +++ b/training/recipes/lora_finetune_distributed.py @@ -0,0 +1,760 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft.peft_utils import ( + get_adapter_params, + get_lora_module_names, + get_merged_lora_ckpt, + set_trainable_params, + validate_state_dict_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.utils import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRAFinetuneRecipeDistributed(FTRecipeInterface): + """ + Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Traning on CPU is not + supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + ValueError: If world_size is 1 + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + _, rank = utils.get_world_size_and_rank() + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # training attributes + self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + + # These attributes constitute the recipe state and are updated by ``load_checkpoint`` + # when ``resume_from_checkpoint`` is ``True`` + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + # When resuming from checkpoint for LoRA, the recipe expects the adapter weights + # and recipe state to be present. The keys should match up with what ``save_checkpoint`` + # used to create these intermediate checkpoints + if self._resume_from_checkpoint: + if utils.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[utils.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=checkpoint_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None, + ) + + self._loss_fn = config.instantiate(cfg.loss) + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. + """ + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + if self._is_rank_zero: + log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + + log.info( + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + ) + + # The model contains LoRA params which won't have any matching keys in + # the state dict. As a result, we need to load with strict=False. + # Before loading the state dict, ensure the state dict keys for the base + # model and adapters (if available) match the keys in the full LoRA model + # This is a good sanity check to prevent silent errors + validate_state_dict_for_lora( + lora_attn_modules=cfg_model.lora_attn_modules, + apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, + apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False), + full_model_state_dict_keys=model.state_dict().keys(), + lora_state_dict_keys=( + lora_weights_state_dict.keys() + if lora_weights_state_dict is not None + else None + ), + base_model_state_dict_keys=base_model_state_dict.keys(), + ) + + # Load both the base model weights and (if available) the adapter weights. Both + # of this should happen only on Rank 0 + model.load_state_dict(base_model_state_dict, strict=False) + if lora_weights_state_dict: + model.load_state_dict(lora_weights_state_dict, strict=False) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + # LoRA hyper-params needed for merging weights while saving checkpoints + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + + # Note: this needs to be set before wrapping with FSDP + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + model = FSDP( + module=model, + auto_wrap_policy=utils.lora_fsdp_wrap_policy( + modules_to_wrap={modules.TransformerDecoderLayer} + ), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: ( + module.to_empty(device=torch.device("cuda"), recurse=False) + if not self._is_rank_zero + else None + ) + ), + ) + + # Ensure no params and buffers are on meta device + utils.validate_no_params_on_meta_device(model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + # Note: technically we should check _contains_fsdp for + # just the state dict of the adapter cfg, but should be equivalent + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict + ) + optimizer.load_state_dict(opt_state_dict) + + if self._is_rank_zero: + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = utils.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=( + partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + with FSDP.state_dict_type( + self._model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + if intermediate_checkpoint: + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) + } + checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + checkpoint_dict.update({utils.ADAPTER_CONFIG: adapter_config}) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + utils.cleanup_before_training() + + _, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not utils.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) + + recipe = LoRAFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/lora_finetune_single_device.py b/training/recipes/lora_finetune_single_device.py new file mode 100644 index 0000000..653fdc7 --- /dev/null +++ b/training/recipes/lora_finetune_single_device.py @@ -0,0 +1,657 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft.peft_utils import ( + get_adapter_params, + get_lora_module_names, + get_merged_lora_ckpt, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.utils import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported.g + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with + 8-bit AdamW and Paged AdamW. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + # Reduced precision logic + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + # fp16 precision is explicitly disabled as it is not supported in this + # recipe (for example, no gradient scaling). + if self._dtype == torch.float16: + raise ValueError( + "fp16 precision is not supported in this recipe. Please use fp32 or bf16." + ) + # For CUDA devices, check if the HW supports bf16 if bf16 is specified. + if ( + self._dtype == torch.bfloat16 + and self._device != torch.device("cpu") + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError("Full bf16 training is not supported on this hardware.") + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + if utils.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + self._model_compile = cfg.compile + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=cfg.compile, + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[utils.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss is initialized.") + + # Dataloader depends on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + base_missing, base_unexpected = model.load_state_dict( + base_model_state_dict, strict=False + ) + if lora_weights_state_dict: + lora_missing, lora_unexpected = model.load_state_dict( + lora_weights_state_dict, strict=False + ) + else: + lora_missing, lora_unexpected = None, None + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Validate model adapter params were loaded in with the expected dtype + # TODO (rohan-varma): Further validation to ensure the appropriate base params + # are NF4 vs bf16 based on the quantization config. + utils.validate_expected_param_dtype( + self.adapter_params.items(), dtype=self._dtype + ) + + log.info(f"Model is initialized with precision {self._dtype}.") + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + model.compile(backend=backend) + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports + Map-style Datasets which fit into memory and an option for random shuffling. + Samplers, iterable datasets, and streaming datasets are not supported. + """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=( + partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None + ), + ) + + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # if training is in-progress, checkpoint the optimizer state as well + if intermediate_checkpoint: + ckpt_dict.update( + { + utils.OPT_KEY: self._optimizer.state_dict(), + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + # Move to CPU to avoid a copy on GPU + state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + + # Construct the full state dict with LoRA weights merged into base LLM weights + merged_state_dict = get_merged_lora_ckpt( + state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + ckpt_dict.update({utils.MODEL_KEY: merged_state_dict}) + + # Construct the adapter weights + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) + } + ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + ckpt_dict.update({utils.ADAPTER_CONFIG: adapter_config}) + + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def train(self) -> None: + """ + The core training loop. + """ + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + with self._profiler as prof: + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if ( + self._device.type == "cuda" + and self._log_peak_memory_stats + ): + log_dict.update( + utils.get_memory_stats(device=self._device) + ) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + prof.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="LoRAFinetuneRecipeSingleDevice", cfg=cfg) + recipe = LoRAFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/ppo_full_finetune_single_device.py b/training/recipes/ppo_full_finetune_single_device.py new file mode 100644 index 0000000..8ee77c0 --- /dev/null +++ b/training/recipes/ppo_full_finetune_single_device.py @@ -0,0 +1,1084 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import sys +from functools import partial +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf +from torchtune.modules.rlhf import PPOStats, Trajectory +from torchtune.recipe_interfaces import FTRecipeInterface +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + Full finetuning recipe for RLHF with PPO for dense transformer-based LLMs such as LLama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + This implementation is based on `Learning to summarize from human feedback None: + + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._dtype == torch.float16: + raise RuntimeError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + # manually setting up a generator for the recipe + self._rng = torch.Generator(self._device).manual_seed(self.seed) + self._total_steps = 0 + self._steps_run = 0 + self._total_epochs = 0 + self._epochs_run = 0 + self.global_step = 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + # setup checkpointers + ( + self._policy_checkpointer, + ref_policy_checkpointer, + self._value_checkpointer, + reward_checkpointer, + ) = self._setup_checkpointers( + cfg.checkpointer, + cfg.ref_policy_checkpointer, + cfg.value_checkpointer, + cfg.reward_checkpointer, + ) + + # load policy checkpoints + policy_model_checkpoint_dict = self._policy_checkpointer.load_checkpoint() + ref_policy_state_dict = ref_policy_checkpointer.load_checkpoint() + + # load reward and value model checkpoints + value_model_checkpoint_dict = self._value_checkpointer.load_checkpoint() + reward_model_state_dict = reward_checkpointer.load_checkpoint() + + # update recipe state + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model_compile = cfg.compile + self._optimizer_in_bwd = cfg.optimizer_in_bwd + ( + self._policy_model, + self._value_model, + self._reward_model, + self._ref_policy_model, + ) = self._setup_model( + cfg_model=cfg.policy_model, + cfg_reward_value_model=cfg.reward_and_value_model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=self._model_compile, + policy_state_dict=policy_model_checkpoint_dict[utils.MODEL_KEY], + ref_policy_state_dict=ref_policy_state_dict[utils.MODEL_KEY], + value_model_state_dict=value_model_checkpoint_dict[utils.MODEL_KEY], + reward_model_state_dict=reward_model_state_dict[utils.MODEL_KEY], + ) + + # setup tokenizer + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=cfg.optimizer_in_bwd, + opt_state_dict=( + policy_model_checkpoint_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss is initialized.") + + # sampler and dataloader depends on the tokenizer and should be set + # setup afterit is initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + self._setup_training_parameters(cfg) + self._setup_training_hyperparameters(cfg) + + if self._resume_from_checkpoint: + self._update_recipe_state(policy_model_checkpoint_dict) + + # one "step" is a single gradient update update over a minibatch of trajectories + self.global_step = ( + self._steps_run + * self._ppo_epochs + * (self.batch_size // self._ppo_batch_size) + ) + + def _setup_training_hyperparameters(self, cfg) -> None: + """ + Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters, + generation hyperparameters, reward masking hyperparameters, and stop token ids. + """ + + self._kl_coeff = cfg.kl_coeff + # GAE hyperparameters + self._gamma = cfg.gamma + self._lmbda = cfg.lmbda + self._whiten_rewards = cfg.whiten_rewards + + # trajectory generation args + self._temperature = cfg.temperature + self._top_k = cfg.top_k + self._max_generated_tokens = cfg.max_generated_tokens + + # reward masking args + self._min_response_length = cfg.min_response_length + self._penalise_no_eos = cfg.penalise_no_eos + self._reward_penalty = cfg.reward_penalty + + # lots of hand holding for stop tokens + if cfg.get("stop_token_ids", False): + stop_token_ids = cfg.stop_token_ids + if self._tokenizer.eos_id not in stop_token_ids: + warn( + f"tokenizer eos_id ({self._tokenizer.eos_id}) is not in stop_token_ids ({stop_token_ids})." + "This may lead to unexpected behaviour." + ) + else: + if not hasattr(self._tokenizer.stop_tokens): + warn( + "No stop tokens defined in tokenizer, and no stop_token_ids provided. This may lead to unexpected behaviour." + ) + stop_token_ids = [] + else: + stop_token_ids = self._tokenizer.stop_tokens + self._stop_token_ids = torch.tensor(stop_token_ids, device=self._device) + + def _setup_training_parameters(self, cfg: DictConfig) -> None: + """ + Validates and sets up parameters for used during training and for tracking training state, + batch sizes for model forward passes during trajectory generation, PPO minibatches, and + PPO microbatches for gradient accumulation. + + Raises + - ValueError if: + - batch_size is not divisible by forward_batch_size + - batch_size is not divisible by ppo_batch_size + - ppo_batch_size is not divisible by gradient_accumulation_steps + - num_steps is less than batch_size + - gradient_accumulation_steps > 1 and optimizer_in_bwd is True + """ + self.batch_size = cfg.batch_size + self._forward_batch_size = cfg.forward_batch_size + self._ppo_epochs = cfg.ppo_epochs + self._ppo_batch_size = cfg.ppo_batch_size + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._ppo_backward_batch_size = ( + cfg.ppo_batch_size // self._gradient_accumulation_steps + ) + + if self.batch_size % self._forward_batch_size != 0: + raise ValueError( + f"batch_size ({self.batch_size}) must be exactly divisible by " + f"forward_batch_size ({self._forward_batch_size})." + ) + if self.batch_size % self._ppo_batch_size != 0: + raise ValueError( + f"batch_size ({self.batch_size}) must be exactly divisible by " + f"ppo_batch_size ({self._ppo_batch_size})." + ) + if self._ppo_batch_size % self._gradient_accumulation_steps != 0: + raise ValueError( + f"ppo_batch_size ({self._ppo_batch_size}) must be exactly divisible " + f"by gradient_accumulation_steps ({self._gradient_accumulation_steps})." + ) + + if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + self._total_steps = cfg.num_steps // self.batch_size + batches_per_epoch = max( + 1, len(self._dataloader) + ) # when we only have a single batch in the dataset + + self._total_epochs = math.ceil(self._total_steps / batches_per_epoch) + if self._total_steps == 0: + raise ValueError( + f"num_steps {cfg.num_steps} must be greater than the batch size {self.batch_size}." + ) + if self._total_steps < len(self._dataloader): + warn( + f"There are fewer total steps ({self._total_steps}, (num_steps//batch_size) " + f"than there are batches ({len(self._dataloader)}) in the dataset. " + f"Training will stop after ({self._total_steps}) steps without saving intermediate checkpoints" + ) + if (self._total_steps > batches_per_epoch) and ( + self._total_steps % batches_per_epoch != 0 + ): + warn( + f"num_steps ({cfg.num_steps}) is not exactly divisible by " + f"the number of batches in the dataset ({batches_per_epoch}). " + f"Intermediate checkpoints will only be saved every {batches_per_epoch} steps." + ) + log.info( + f"Total steps to run: {self._total_steps}, Total epochs to run: {self._total_epochs}" + ) + + def _setup_checkpointers( + self, + policy_cfg: DictConfig, + ref_policy_cfg: DictConfig, + value_cfg: DictConfig, + reward_cfg: DictConfig, + ) -> Tuple[ + utils.Checkpointer, utils.Checkpointer, utils.Checkpointer, utils.Checkpointer + ]: + """ + Sets up checkpointers for policy, reference policy, value, and reward models. + Only the policy checkpoint handles recipe state for resuming from checkpoints. + """ + + if not self._resume_from_checkpoint: + assert policy_cfg.checkpoint_dir == ref_policy_cfg.checkpoint_dir, ( + "Policy and reference policy should be loaded from the same checkpoint directories" + f"at the start of training. Found: {policy_cfg.checkpoint_dir} and" + f"{ref_policy_cfg.checkpoint_dir}" + ) + assert policy_cfg.checkpoint_files == ref_policy_cfg.checkpoint_files, ( + "Policy and reference policy should be loaded from the same checkpoint files" + f"at the start of training. Found: {policy_cfg.checkpoint_files} and" + f"{ref_policy_cfg.checkpoint_files}" + ) + + policy_checkpointer = config.instantiate( + policy_cfg, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + + ref_policy_checkpointer = config.instantiate( + ref_policy_cfg, + resume_from_checkpoint=False, + ) + + value_checkpointer = config.instantiate( + value_cfg, + resume_from_checkpoint=False, + ) + + reward_checkpointer = config.instantiate( + reward_cfg, + resume_from_checkpoint=False, + ) + + return ( + policy_checkpointer, + ref_policy_checkpointer, + value_checkpointer, + reward_checkpointer, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + cfg_reward_value_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + policy_state_dict: Dict[str, Any], + ref_policy_state_dict: Dict[str, Any], + value_model_state_dict: Dict[str, Any], + reward_model_state_dict: Dict[str, Any], + ) -> Tuple[nn.Module, nn.Module, nn.Module]: + """ + Sets up the policy model, reference policy model, reward model, and value model. + """ + + with utils.set_default_dtype(self._dtype), self._device: + policy_model = config.instantiate(cfg_model) + ref_policy_model = config.instantiate(cfg_model) + reward_model = config.instantiate(cfg_reward_value_model) + value_model = config.instantiate(cfg_reward_value_model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + policy_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + utils.set_activation_checkpointing( + value_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + policy_model.load_state_dict(policy_state_dict) + ref_policy_model.load_state_dict(ref_policy_state_dict) + + reward_missing, reward_unexpected = reward_model.load_state_dict( + reward_model_state_dict, strict=False + ) + value_missing, value_unexpected = value_model.load_state_dict( + value_model_state_dict, strict=False + ) + + # some extra validation for HF classifier checkpoints with a `score.bias` present + assert ( + reward_missing == value_missing == [] + ), f"Missing keys in reward ({reward_missing}) and value model ({value_missing}) state dicts." + + if reward_unexpected or value_unexpected: + # the only unexpected keys should be when pre-trained HF models were saved with + # bias=True in final classification layers. This happens when training a reward model with TRL. + assert ( + reward_unexpected == value_unexpected == ["output.bias"] + ), f"Unexpected keys in reward ({reward_unexpected}) and value model ({value_unexpected}) state dicts." + + # Validate models were loaded in with the expected dtype. + utils.validate_expected_param_dtype( + value_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + reward_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + value_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + ref_policy_model.named_parameters(), dtype=self._dtype + ) + + log.info(f"Models are initialized with precision {self._dtype}.") + + # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs + # between ref policy and current policy + for module in policy_model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + for module in value_model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + + # disabling grad and dropout in reward and reference policy models + reward_model.eval() + ref_policy_model.eval() + + for p in reward_model.parameters(): + p.requires_grad = False + + for p in ref_policy_model.parameters(): + p.requires_grad = False + + # Compile model, if enabled. + if compile_model: + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + log.info("Compiling models with torch.compile...") + + policy_model.compile(backend=backend) + reward_model.compile(backend=backend) + ref_policy_model.compile(backend=backend) + value_model.compile(backend=backend) + + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + return policy_model, value_model, reward_model, ref_policy_model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optimizer: + + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + p: config.instantiate(cfg_optimizer, [p]) + for p in chain( + self._policy_model.parameters(), self._value_model.parameters() + ) + } + # Register optimizer step hooks on the models to run optimizer in backward. + utils.register_optim_in_bwd_hooks( + model=self._policy_model, optim_dict=optim_dict + ) + utils.register_optim_in_bwd_hooks( + model=self._value_model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._policy_model, optim_dict=optim_dict + ) + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._value_model, optim_dict=optim_dict + ) + # Load optimizer states. If optimizer states are being restored in an optimizer in backward + # run, these need to have been saved with the same setting. Cannot restore from runs that did not + # use optimizer in backward. + if opt_state_dict is not None: + try: + self._optim_ckpt_wrapper.load_state_dict(opt_state_dict) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." + ) from e + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate( + cfg_optimizer, + chain(self._policy_model.parameters(), self._value_model.parameters()), + ) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, cfg_dataset: DictConfig, shuffle: bool, batch_size: int + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. + """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=partial( + rlhf.left_padded_collate, + padding_idx=self._tokenizer.pad_id, + ), + drop_last=True, + ) + + return sampler, dataloader + + def save_checkpoint( + self, epoch: int, is_intermediate_checkpoint: bool = False + ) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + policy_ckpt_dict = {utils.MODEL_KEY: self._policy_model.state_dict()} + value_ckpt_dict = {utils.MODEL_KEY: self._value_model.state_dict()} + + # if training is in-progress, checkpoint the optimizer state and rng state as well + if is_intermediate_checkpoint: + policy_ckpt_dict.update( + { + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self._epochs_run, + utils.TOTAL_EPOCHS_KEY: self._total_epochs, + utils.MAX_STEPS_KEY: self._total_steps, + utils.STEPS_KEY: self._steps_run, + utils.RNG_KEY: self._rng.get_state(), + } + ) + if not self._optimizer_in_bwd: + policy_ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() + else: + policy_ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + + self._policy_checkpointer.save_checkpoint( + policy_ckpt_dict, + epoch=epoch, + intermediate_checkpoint=is_intermediate_checkpoint, + ) + + self._value_checkpointer.save_checkpoint( + value_ckpt_dict, + epoch=epoch, + intermediate_checkpoint=False, + ) + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + # If seed or total_steps, or total_epochs don't match, + # warn the user and overwrite. + try: + if ( + self.seed != ckpt_dict[utils.SEED_KEY] + or self._total_steps != ckpt_dict[utils.MAX_STEPS_KEY] + or self._total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + ): + warn( + message="""Configured value for seed, total_steps, or total_epochs + does not match the value stored in checkpoint.""" + ) + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self._rng.set_state(ckpt_dict[utils.RNG_KEY]) + self._steps_run = ckpt_dict[utils.STEPS_KEY] + self._total_steps = ckpt_dict[utils.MAX_STEPS_KEY] + self._total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self._epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + except KeyError as e: + raise KeyError from e( + "Checkpoint does not contain the required keys needed for updating recipe state." + "Are you sure you passed in the right recipe checkpoint?" + ) + + def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: + """ + Generates a trajectory given the current policy and value models, the reference policy model, the reward model, + and batch of inputs. This is done over the following steps: + + 1: Generate responses, and logits corresponding to the responses using the current policy, + generating (query, response) pairs. + 2. Estimate logprobs of the generated responses using the current policy. + 3. Estimate values from the generated responses using the current value function. + 4. Replace any tokens in the response after the first stop token (usually EOS token) with padding, + producting truncated responses. + 5. Run the reward model on the (query, truncated-response) pairs. + 6. Mask out all the invalid values in the trajectory due to padding tokens. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + + Returns: + Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory` comprising + the current trajectory. + """ + batch_size, context_length = input_ids.shape + + # step 1: generate responses, and logits corresponding to the responses using the current policy + query_responses, logits = rlhf.generate_with_logits( + model=self._policy_model, + prompt=input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + ) + + responses = query_responses[:, context_length:].clone() + query_response_padding_masks = query_responses == self._tokenizer.pad_id + + # step 1.1 create attention masks and position IDs for any padding tokens in inputs, used for future forward passes + masks = rlhf.get_causal_mask(~(query_response_padding_masks)) + position_ids = (~query_response_padding_masks).cumsum(-1) - ( + ~query_response_padding_masks + ).long() + position_ids = position_ids.type(torch.int) + + del query_response_padding_masks + + # step 2. estimate logprobs of the responses using the current policy + logits = logits[:, context_length - 1 :] + logprobs = rlhf.logits_to_logprobs(logits, responses, self._temperature) + + del logits + + # step 2.1 estimate logprobs of the responses using the reference policy + ref_logits = self._ref_policy_model( + query_responses, input_pos=position_ids, mask=masks + ) + ref_logits = rlhf.truncate_sequence_for_logprobs(ref_logits, context_length) + ref_logprobs = rlhf.logits_to_logprobs(ref_logits, responses, self._temperature) + + del ref_logits + + # step 3. estimate values from the responses using the value function + values = self._value_model(query_responses, input_pos=position_ids, mask=masks) + values = rlhf.truncate_sequence_for_logprobs(values, context_length).squeeze(-1) + + # step 4. replace any tokens in the responses after the first stop token (usually EOS token) with padding + # resulting in truncated responses + response_padding_masks, responses = rlhf.truncate_sequence_at_first_stop_token( + responses, self._stop_token_ids, self._tokenizer.pad_id + ) + + # step 5. run the reward model on the (query, truncated-response) pairs + scores = self._reward_model( + torch.cat([input_ids, responses], dim=1), + input_pos=position_ids, + mask=masks, + ) + + del responses + + # step 5.1 the scores from the reward model are the logits for the last non-padding token in + # each (query, truncated-response) pair + seq_lens = utils.get_unmasked_sequence_lengths(response_padding_masks) + scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) + + # step 5.2 if configured, apply any penalties for sequences without EOS tokens + # or shorter than a certain length + if self._penalise_no_eos or self._min_response_length: + reward_penalty_mask = rlhf.get_reward_penalty_mask( + response_padding_masks, + seq_lens, + self._penalise_no_eos, + self._min_response_length, + ) + scores[reward_penalty_mask] = self._reward_penalty + + # step 6. mask out all the invalid values in the trajectory due to padding tokens + logprobs[response_padding_masks] = 1.0 + ref_logprobs[response_padding_masks] = 1.0 + + # step 6.1 values are masked out *after* the last valid token in the response + value_seq_idxs = torch.where( + (seq_lens > 0) & (seq_lens < self._max_generated_tokens - 1), + seq_lens + 1, + seq_lens, + ) + value_padding_masks = response_padding_masks.clone() + value_padding_masks[ + torch.arange(batch_size, device=value_padding_masks.device), + value_seq_idxs, + ] = False + + values[value_padding_masks] = 0.0 + + return Trajectory( + query_responses=query_responses, + logprobs=logprobs, + ref_logprobs=ref_logprobs, + values=values, + masks=masks, + position_ids=position_ids, + response_padding_masks=response_padding_masks, + value_padding_masks=value_padding_masks, + value_seq_idxs=value_seq_idxs, + scores=scores, + seq_lens=seq_lens, + ) + + def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: + """ + Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. + See ``generate_trajectory`` for more details. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + + Returns: + Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory`, comprising + the current trajectory. + """ + trajectories: List[Trajectory] = [] + with torch.no_grad(): + for batch_start in range(0, self.batch_size, self._forward_batch_size): + batch_input_ids = input_ids[ + batch_start : batch_start + self._forward_batch_size + ] + trajectories.append(self.generate_trajectory(batch_input_ids)) + return Trajectory(*map(torch.cat, zip(*trajectories))) + + def train(self) -> None: + """ + The core training loop.""" + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward." + "Expect a relatively slow first iteration." + ) + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + + training_completed = False + pbar = tqdm(total=self._total_steps, initial=self._steps_run) + for curr_epoch in range(self._epochs_run, self._total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + for _, batch in enumerate(self._dataloader): + batch = batch.to(self._device) + _, context_length = batch.shape + + # step 1. generate the trajectory using: + # - the current policy (pi_theta) + # - the current value function (V_phi) + # - the reference frozen policy model (pi_theta_0) + trajectory = self.generate_trajectory_batched(batch) + + # step 2. get the rewards for the current trajectory. these are based on: + # - the divergence between the current policy and the reference policy + # - the scores from the reward model + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + trajectory.scores, + trajectory.logprobs, + trajectory.ref_logprobs, + self._kl_coeff, + trajectory.value_seq_idxs, + ) + + # step 3. estimate the advantages using Generalized Advantage Estimation (GAE) + advantages, returns = rlhf.estimate_advantages( + trajectory.values, + rewards, + self._gamma, + self._lmbda, + masks=~trajectory.response_padding_masks, + ) + + # step 4. optimise using the PPO objective over multiple epochs + ppo_stats: List[PPOStats] = [] + for _ in range(self._ppo_epochs): + batch_idxs = torch.randperm(self.batch_size, device=self._device) + for i in range(0, self.batch_size, self._ppo_batch_size): + mini_batch_idxs = batch_idxs[i : i + self._ppo_batch_size] + + batch_ppo_stats: List[PPOStats] = [] + for j in range( + 0, self._ppo_batch_size, self._ppo_backward_batch_size + ): + backward_batch_idxs = mini_batch_idxs[ + j : j + self._ppo_backward_batch_size + ] + + batch_trajectory = Trajectory( + *map( + partial( + torch.index_select, + dim=0, + index=backward_batch_idxs, + ), + trajectory, + ) + ) + batch_ppo_stats.append( + self._ppo_step( + batch_trajectory, + advantages[backward_batch_idxs], + returns[backward_batch_idxs], + context_length, + ) + ) + del batch_trajectory + + ppo_stats.append(PPOStats(*map(sum, zip(*batch_ppo_stats)))) + + if not self._optimizer_in_bwd: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + self.global_step += 1 + + # step 5. profit + self._steps_run += 1 + if self._steps_run % self._log_every_n_steps == 0: + self.log_metrics( + trajectory, + PPOStats(*map(torch.stack, zip(*ppo_stats))), + kl, + kl_rewards, + ) + self.cleanup_after_step( + trajectory, ppo_stats, advantages, returns, kl, kl_rewards + ) + pbar.update(1) + if self._steps_run == self._total_steps: + training_completed = True + break + + # save checkpoint at current epoch + self._epochs_run += 1 + + self.save_checkpoint( + curr_epoch, is_intermediate_checkpoint=not training_completed + ) + if training_completed: + return + + def _ppo_step( + self, + trajectory: Trajectory, + advantages: torch.Tensor, + returns: torch.Tensor, + context_length: int, + ) -> PPOStats: + """ + Perform a single PPO optimisation step over a batch of trajectories and corresponding advantages and returns. + + Args: + trajectory (Trajectory): a batch of trajectories + advantages (torch.Tensor): advantages corresponding to the trajectories + returns (torch.Tensor): returns corresponding the trajectories + context_length (int): input ids sequence length + + Returns: + PPOStats: An instance of :class:`~torchtune.modules.rlhf.PPOStats`, a NamedTuple containing: + - loss (torch.Tensor): The total PPO loss. + - policy_loss (torch.Tensor): The policy function loss. + - value_loss (torch.Tensor): The value function loss. + - ratios (torch.Tensor): The ratio between the current and old policy probabilities. + - clipfrac (torch.Tensor): The fraction of ratios that were clipped. + - approx_policy_kls: Average estimated KL divergence between the policy before and after the optimisation step. + + """ + # estimate logprobs from the policy at the current optimisation step + pi_logits = self._policy_model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + pi_logits = rlhf.truncate_sequence_for_logprobs(pi_logits, context_length) + pi_logprobs = rlhf.logits_to_logprobs( + pi_logits, trajectory.query_responses[:, context_length:], self._temperature + ) + pi_logprobs[trajectory.response_padding_masks] = 1.0 + + del pi_logits + + # estimate the values from the value function at the current optimisation step + phi_values = self._value_model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + + phi_values = rlhf.truncate_sequence_for_logprobs( + phi_values, context_length + ).squeeze(-1) + phi_values[trajectory.value_padding_masks] = 0.0 + + # calculate ppo loss + loss, policy_loss, value_loss, ratios, clipfrac = self._loss_fn( + trajectory.logprobs, + pi_logprobs, + advantages, + trajectory.values, + phi_values, + returns, + padding_masks=~trajectory.response_padding_masks, + value_padding_masks=~trajectory.value_padding_masks, + ) + + loss /= self._gradient_accumulation_steps + loss.backward() + + with torch.no_grad(): + approx_policy_kls = ( + 0.5 * (pi_logprobs - trajectory.logprobs).pow(2) + ).mean() + + return PPOStats( + loss, + policy_loss / self._gradient_accumulation_steps, + value_loss / self._gradient_accumulation_steps, + ratios / self._gradient_accumulation_steps, + clipfrac / self._gradient_accumulation_steps, + approx_policy_kls / self._gradient_accumulation_steps, + ) + + def log_metrics( + self, + trajectory: Trajectory, + ppo_stats: PPOStats, + kl: torch.Tensor, + kl_rewards: torch.Tensor, + ) -> None: + """ + Log metrics and statistics for the current step to the metric logger. + """ + log_dict = { + "scores": trajectory.scores.mean(), + "num_stop_tokens": trajectory.response_padding_masks.any(-1).sum(), + "rlhf_reward": trajectory.scores.mean() + kl_rewards.sum(1).mean(), + "kl": kl.sum(1).mean(), + "kl_reward": kl_rewards.sum(1).mean(), + "loss": ppo_stats.loss.mean(), + "policy_loss": ppo_stats.policy_loss.mean(), + "value_loss": ppo_stats.value_loss.mean(), + "clipfrac": ppo_stats.clipfrac.mean(), + "ratios": ppo_stats.ratios.mean(), + "approx_policy_kl": ppo_stats.approx_policy_kls.mean(), + "response_lengths": trajectory.seq_lens.float().mean(), + } + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + + self._metric_logger.log_dict(log_dict, step=self.global_step) + + def cleanup_after_step( + self, + trajectory: Trajectory, + ppo_stats: PPOStats, + advantages: torch.Tensor, + returns: torch.Tensor, + kl: torch.Tensor, + kl_rewards: torch.Tensor, + ) -> None: + """ + Cleanup tensors after each PPO step to free up memory. + """ + # there shouldn't be any floating references to the individual tensors at the this point, so gc can do its thing + for v in trajectory: + del v + del trajectory + for v in ppo_stats: + del v + del ppo_stats + del advantages + del returns + del kl + del kl_rewards + + def cleanup(self, **kwargs) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="PPOFullFinetuneRecipeSingleDevice", cfg=cfg) + recipe = PPOFullFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/qat_distributed.py b/training/recipes/qat_distributed.py new file mode 100644 index 0000000..549a9f6 --- /dev/null +++ b/training/recipes/qat_distributed.py @@ -0,0 +1,658 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import init_process_group +from torch.distributed.fsdp import ( + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.utils.activations import apply_selective_activation_checkpointing + +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class QATRecipeDistributed(FTRecipeInterface): + """ + Quantization-aware training (QAT) recipe for dense transformer-based LLMs such as Llama2. + This recipe supports distributed training and can be run on a single node (1 to 8 GPUs). + Only compatible with PyTorch 2.4+. + + Features: + - Quantization-aware training (QAT). Perform fake quantization on weights and/or activations + during finetuning, with the goal of ultimately producing a quantized model with minimal + accuracy degradation. This recipe produces an unquantized model in the original dtype, + which can then be quantized separately. + + - Delayed fake quantization. Optionally specify the step after which fake quantization occurs. + Empirically, allowing the model to finetune without fake quantization initially allows the + weight and activation values to stabilize before fake quantizing them, potentially leading + to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``. + + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU + is not supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + if ( + cfg.get("fsdp_cpu_offload", False) + and cfg.get("fused", False) + and not utils.torch_version_ge("2.4.0") + ): + raise RuntimeError( + "Using fused optimizer on CPU is only supported in PyTorch nightly." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + _, rank = utils.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None) + self._quantizer_mode = None + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + quantizer_cfg=cfg.get("quantizer", None), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=ckpt_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None, + ) + + self._loss_fn = config.instantiate(cfg.loss) + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + memory_efficient_fsdp_wrap: bool, + fsdp_cpu_offload: bool, + model_state_dict: Dict[str, Any], + ac_mode: Optional[str] = None, + ac_option: Optional[int] = None, + quantizer_cfg: Optional[DictConfig] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. + """ + if self._is_rank_zero: + log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + + log.info( + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + ) + + # Load both the model weights. This should happen only on Rank 0 + model.load_state_dict(model_state_dict) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. ``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + ac_mode = ac_mode + ac_option = ac_option + + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # Apply quantization-aware training during finetuning + if quantizer_cfg is None: + raise ValueError("Quantizer must be specified for QAT recipe.") + quantizer = config.instantiate(quantizer_cfg) + quantizer.precision = self._dtype + quantizer_mode = utils.quantization.get_quantizer_mode(quantizer) + if "qat" not in quantizer_mode: + raise ValueError( + "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode + ) + self._quantizer_mode = quantizer_mode + model = quantizer.prepare(model) + + # Wrap the model with FSDP. This will ensure that the model is sharded + # across all available GPUs. + model = FSDP( + module=model, + auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, + modules_to_wrap={modules.TransformerDecoderLayer}, + ), + cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: module.to_empty( + device=torch.device("cuda"), recurse=False + ) + if not self._is_rank_zero + else None + ), + ) + + # Ensure no params and buffers are on meta device + utils.validate_no_params_on_meta_device(model) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + """ + Set up the optimizer. This method also handles transforing the state dict + for FSDP. + """ + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + + if opt_state_dict: + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict + ) + optimizer.load_state_dict(opt_state_dict) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = utils.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None, + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + checkpoint_dict = {} + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + with FSDP.state_dict_type( + self._model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + checkpoint_dict.update( + { + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def train(self) -> None: + """ + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. + """ + # clean up before training begins + utils.cleanup_before_training() + + _, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + # Optionally wait N steps before enabling fake quant + if self._fake_quant_after_n_steps is not None: + if self.global_step == 0: + log.info( + "Step 0: Disabling fake quant, will re-enable in step %s" + % self._fake_quant_after_n_steps + ) + disable_fq = utils.quantization._get_disable_fake_quant( + self._quantizer_mode + ) + self._model.apply(disable_fq) + elif self.global_step == self._fake_quant_after_n_steps: + log.info( + "Step %s: Enabling fake quant" + % self._fake_quant_after_n_steps + ) + enable_fq = utils.quantization._get_enable_fake_quant( + self._quantizer_mode + ) + self._model.apply(enable_fq) + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + torch.distributed.destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not utils.is_distributed(): + raise RuntimeError( + "Distributed QAT recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + utils.set_torch_num_threads() + + config.log_config(recipe_name="QATRecipeDistributed", cfg=cfg) + + recipe = QATRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/training/recipes/quantization.md b/training/recipes/quantization.md new file mode 100644 index 0000000..29bd5cb --- /dev/null +++ b/training/recipes/quantization.md @@ -0,0 +1,85 @@ +# Quantization and Sparsity + +torchtune integrates with [torchao](https://github.com/pytorch/ao/) for QAT and QLoRA. Currently only some quantization techniques are integrated, see the docstrings in the [quantization recipe](quantize.py) and the [QAT recipe](qat_distributed.py) for more details. + +For post training quantization, we recommend using `torchao` directly: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md to quantize their model +and do eval/benchmark in torchao as well: https://github.com/pytorch/ao/tree/main/torchao/_models/llama. + +## Quantization-Aware Training (QAT) + +(PyTorch 2.4+) + +QAT refers to applying fake quantization to weights and/or activations during finetuning, +which means simulating only the quantization math without actually casting the original +dtype to a lower precision. You can run QAT with finetuning using the following command: + +``` +tune run --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full +``` + +This produces an unquantized model in the original data type. To get an actual quantized model, +follow this with `tune run quantize` while specifying the same quantizer in the config, e.g. + +```yaml +# QAT specific args +quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 +``` + +Currently only `torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer` +is supported. This refers to int8 dynamic per token activation quantization +combined with int4 grouped per axis weight quantization. For more details, +please refer to the [torchao implementation](https://github.com/pytorch/ao/blob/950a89388e88e10f26bbbbe2ec0b1710ba3d33d1/torchao/quantization/prototype/qat.py#L22). + +## Eval +To evaluate a quantized model, make the following changes to the default [evaluation config](configs/eleuther_evaluation.yaml) + + +```yaml +# Currently we only support torchtune checkpoints when +# evaluating quantized models. For more details on checkpointing see +# https://pytorch.org/torchtune/main/deep_dives/checkpointer.html +# Make sure to change the default checkpointer component +checkpointer: + _component_: torchtune.utils.FullModelTorchTuneCheckpointer + .. + checkpoint_files: [] + +# Quantization specific args +quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer + groupsize: 256 +``` + +Noet: we can use `Int8DynActInt4WeightQuantizer` to load a QAT quantized model since it's the same type of quantization. + +and run evaluation: +```bash +tune run eleuther_eval --config eleuther_evaluation +``` + +## Generate +To run inference using a quantized model, make the following changes to the default [generation config](configs/generation.yaml) + + +```yaml +# Currently we only support torchtune checkpoints when +# evaluating quantized models. For more details on checkpointing see +# https://pytorch.org/torchtune/main/deep_dives/checkpointer.html +# Make sure to change the default checkpointer component +checkpointer: + _component_: torchtune.utils.FullModelTorchTuneCheckpointer + .. + checkpoint_files: [] + +# Quantization Arguments +quantizer: + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer + groupsize: 256 +``` + +and run generation: +```bash +tune run generate --config generation +``` diff --git a/training/recipes/quantize.py b/training/recipes/quantize.py new file mode 100644 index 0000000..b83e8f7 --- /dev/null +++ b/training/recipes/quantize.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict + +import torch +from omegaconf import DictConfig + +from torch import nn + +from torchtune import config, utils + +logger = utils.get_logger("DEBUG") + + +class QuantizationRecipe: + """ + Recipe for quantizing a Transformer-based LLM. + Uses quantizer classes from torchao to quantize a model. + + Supported quantization modes are: + 8da4w (PyTorch 2.3+): + torchtune.utils.quantization.Int8DynActInt4WeightQuantizer + int8 per token dynamic activation with int4 weight only per axis group quantization + Args: + `groupsize` (int): a parameter of int4 weight only quantization, + it refers to the size of quantization groups which get independent quantization parameters + e.g. 32, 64, 128, 256, smaller numbers means more fine grained and higher accuracy, + but also higher memory overhead + + 8da4w-qat (PyTorch 2.4+): + torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer + int8 per token dynamic activation with int4 weight only per axis group quantization + Same as "8da4w", but for quantizing QAT checkpoints + Args: + `groupsize` (int): a parameter of int4 weight only quantization, + it refers to the size of quantization groups which get independent quantization parameters + e.g. 32, 64, 128, 256, smaller numbers means more fine grained and higher accuracy, + but also higher memory overhead + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(dtype=cfg.dtype, device=self._device) + self._quantizer = config.instantiate(cfg.quantizer) + self._quantization_mode = utils.get_quantizer_mode(self._quantizer) + utils.set_seed(seed=cfg.seed) + + def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]: + self._checkpointer = config.instantiate(checkpointer_cfg) + checkpoint_dict = self._checkpointer.load_checkpoint() + return checkpoint_dict + + def setup(self, cfg: DictConfig) -> None: + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + self._model = self._setup_model( + model_cfg=cfg.model, + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ) + + def _setup_model( + self, + model_cfg: DictConfig, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(model_cfg) + + if "qat" in self._quantization_mode: + model = self._quantizer.prepare(model) + model.load_state_dict(model_state_dict) + + # Validate model was loaded in with the expected dtype. + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + logger.info(f"Model is initialized with precision {self._dtype}.") + return model + + @torch.no_grad() + def quantize(self, cfg: DictConfig): + t0 = time.perf_counter() + if "qat" in self._quantization_mode: + self._model = self._quantizer.convert(self._model) + else: + self._model = self._quantizer.quantize(self._model) + t = time.perf_counter() - t0 + logger.info(f"Time for quantization: {t:.02f} sec") + logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + def save_checkpoint(self, cfg: DictConfig): + ckpt_dict = self._model.state_dict() + file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0] + + output_dir = Path(cfg.checkpointer.output_dir) + output_dir.mkdir(exist_ok=True) + checkpoint_file = Path.joinpath( + output_dir, f"{file_name}-{self._quantization_mode}".rstrip("-qat") + ).with_suffix(".pt") + + torch.save(ckpt_dict, checkpoint_file) + logger.info( + "Model checkpoint of size " + f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB " + f"saved to {checkpoint_file}" + ) + + +@config.parse +def main(cfg: DictConfig) -> None: + config.log_config(recipe_name="QuantizationRecipe", cfg=cfg) + recipe = QuantizationRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.quantize(cfg=cfg) + recipe.save_checkpoint(cfg=cfg) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/training/tests/__init__.py b/training/tests/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/assets/README.md b/training/tests/assets/README.md new file mode 100644 index 0000000..568c528 --- /dev/null +++ b/training/tests/assets/README.md @@ -0,0 +1,26 @@ +# Details on the assets in this folder + +## `m.model` + +**Description**: +**Creation**: +**Usage**: + + +## `tiny_fair_checkpoint.pt` + +**Description**: +**Creation**: +**Usage**: + +## `tiny_llama2_checkpoint.pt` + +**Description**: +**Creation**: +**Usage**: + +## `tiny_state_dict_with_one_key.pt` + +**Description**: +**Creation**: +**Usage**: diff --git a/training/tests/assets/alpaca_tiny.json b/training/tests/assets/alpaca_tiny.json new file mode 100644 index 0000000..8a71808 --- /dev/null +++ b/training/tests/assets/alpaca_tiny.json @@ -0,0 +1 @@ +[{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive three tips for staying healthy.\n\n### Response:\n1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."}, {"instruction": "What are the three primary colors?", "input": "", "output": "The three primary colors are red, blue, and yellow.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat are the three primary colors?\n\n### Response:\nThe three primary colors are red, blue, and yellow."}, {"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the structure of an atom.\n\n### Response:\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}, {"instruction": "How can we reduce air pollution?", "input": "", "output": "There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHow can we reduce air pollution?\n\n### Response:\nThere are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances."}, {"instruction": "Describe a time when you had to make a difficult decision.", "input": "", "output": "I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\u2019s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe a time when you had to make a difficult decision.\n\n### Response:\nI had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\u2019s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities."}, {"instruction": "Identify the odd one out.", "input": "Twitter, Instagram, Telegram", "output": "Telegram", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nIdentify the odd one out.\n\n### Input:\nTwitter, Instagram, Telegram\n\n### Response:\nTelegram"}, {"instruction": "Explain why the following fraction is equivalent to 1/4", "input": "4/16", "output": "The fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain why the following fraction is equivalent to 1/4\n\n### Input:\n4/16\n\n### Response:\nThe fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4."}, {"instruction": "Write a short story in third person narration about a protagonist who has to make an important career decision.", "input": "", "output": "John was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \n\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\n\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a short story in third person narration about a protagonist who has to make an important career decision.\n\n### Response:\nJohn was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \n\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\n\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him."}, {"instruction": "Render a 3D model of a house", "input": "", "output": " This type of instruction cannot be fulfilled by a GPT model.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nRender a 3D model of a house\n\n### Response:\n This type of instruction cannot be fulfilled by a GPT model."}, {"instruction": "Evaluate this sentence for spelling and grammar mistakes", "input": "He finnished his meal and left the resturant", "output": "He finished his meal and left the restaurant.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n### Input:\nHe finnished his meal and left the resturant\n\n### Response:\nHe finished his meal and left the restaurant."}, {"instruction": "How did Julius Caesar die?", "input": "", "output": "Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHow did Julius Caesar die?\n\n### Response:\nJulius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC."}, {"instruction": "What is the capital of France?", "input": "", "output": "The capital of France is Paris.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\nThe capital of France is Paris."}, {"instruction": "Generate a list of ten items a person might need for a camping trip", "input": "", "output": "1. Tent\n2. Sleeping bags\n3. Flashlight\n4. Matches/lighter\n5. Insect repellent\n6. Extra clothes\n7. First aid kit\n8. Camp stove\n9. Cooking utensils\n10. Biodegradable soaps/shampoos", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of ten items a person might need for a camping trip\n\n### Response:\n1. Tent\n2. Sleeping bags\n3. Flashlight\n4. Matches/lighter\n5. Insect repellent\n6. Extra clothes\n7. First aid kit\n8. Camp stove\n9. Cooking utensils\n10. Biodegradable soaps/shampoos"}, {"instruction": "Discuss the causes of the Great Depression", "input": "", "output": "The Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDiscuss the causes of the Great Depression\n\n### Response:\nThe Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade."}, {"instruction": "Classify the following into animals, plants, and minerals", "input": "Oak tree, copper ore, elephant", "output": "Oak tree: Plant\n Copper ore: Mineral\n Elephant: Animal", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nClassify the following into animals, plants, and minerals\n\n### Input:\nOak tree, copper ore, elephant\n\n### Response:\nOak tree: Plant\n Copper ore: Mineral\n Elephant: Animal"}, {"instruction": "Explain the use of word embeddings in Natural Language Processing", "input": "", "output": "Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the use of word embeddings in Natural Language Processing\n\n### Response:\nWord embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks."}, {"instruction": "Describe the function of a computer motherboard", "input": "", "output": "A computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the function of a computer motherboard\n\n### Response:\nA computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types."}, {"instruction": "Reverse engineer this code to create a new version", "input": "def factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial", "output": "def factorialize(num):\n factorial = 1\n for i in range(num, 0, -1):\n factorial *= i\n \n return factorial", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nReverse engineer this code to create a new version\n\n### Input:\ndef factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial\n\n### Response:\ndef factorialize(num):\n factorial = 1\n for i in range(num, 0, -1):\n factorial *= i\n \n return factorial"}, {"instruction": "Propose an ethical solution to the problem of data privacy", "input": "", "output": "One ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nPropose an ethical solution to the problem of data privacy\n\n### Response:\nOne ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data."}, {"instruction": "Generate three verbs that mean the same as \"to apologize\"", "input": "", "output": "Acknowledge, express regret, make amends.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate three verbs that mean the same as \"to apologize\"\n\n### Response:\nAcknowledge, express regret, make amends."}, {"instruction": "What does DNA stand for?", "input": "", "output": "DNA stands for deoxyribonucleic acid.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat does DNA stand for?\n\n### Response:\nDNA stands for deoxyribonucleic acid."}, {"instruction": "Compare and contrast the Cuban Missile Crisis and the Vietnam War.", "input": "", "output": "The Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCompare and contrast the Cuban Missile Crisis and the Vietnam War.\n\n### Response:\nThe Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force."}, {"instruction": "Generate a list of random words.", "input": "", "output": "sneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of random words.\n\n### Response:\nsneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic."}, {"instruction": "Transcribe the recording into text.", "input": "Recording about the history of Waterloo", "output": "The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nTranscribe the recording into text.\n\n### Input:\nRecording about the history of Waterloo\n\n### Response:\nThe Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years."}, {"instruction": "Who is the world's most famous painter?", "input": "", "output": "The world's most famous painter is generally considered to be the Italian renaissance artist Leonardo da Vinci. His work includes the Mona Lisa and The Last Supper, both of which are considered masterpieces. Da Vinci was also a master of science, inventing the helicopter and the tank, among other inventions. He was also a prolific musician, sculptor, and architect. His work has inspired generations of artists and continues to be revered today.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWho is the world's most famous painter?\n\n### Response:\nThe world's most famous painter is generally considered to be the Italian renaissance artist Leonardo da Vinci. His work includes the Mona Lisa and The Last Supper, both of which are considered masterpieces. Da Vinci was also a master of science, inventing the helicopter and the tank, among other inventions. He was also a prolific musician, sculptor, and architect. His work has inspired generations of artists and continues to be revered today."}, {"instruction": "Explain the concept of cogging torque.", "input": "", "output": "Cogging torque is the torque that is applied to a rotating device when the device experiences resistance from the teeth of the gears that are connected to the device. Cogging torque occurs due to the reaction force generated when the teeth of the gear drive mesh. This force pushes against each other, resulting in a drag force that is transmitted from the gears to the shaft. The magnitude of the cogging torque can vary depending on the size and number of teeth, as well as the gear's material and design. Cogging torque is undesirable in most motorized applications, as it can lead to loss of efficiency. To reduce cogging torque, gear designs are often symmetrical or have certain tooth profiles which reduce the force of contact between the teeth when they come into contact with each other.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the concept of cogging torque.\n\n### Response:\nCogging torque is the torque that is applied to a rotating device when the device experiences resistance from the teeth of the gears that are connected to the device. Cogging torque occurs due to the reaction force generated when the teeth of the gear drive mesh. This force pushes against each other, resulting in a drag force that is transmitted from the gears to the shaft. The magnitude of the cogging torque can vary depending on the size and number of teeth, as well as the gear's material and design. Cogging torque is undesirable in most motorized applications, as it can lead to loss of efficiency. To reduce cogging torque, gear designs are often symmetrical or have certain tooth profiles which reduce the force of contact between the teeth when they come into contact with each other."}, {"instruction": "Look up the boiling point of water.", "input": "", "output": "The boiling point of water is 100\u00b0C (212\u00b0F), at standard atmospheric pressure. When heated at higher pressures, water's boiling point increases. For example, water will boil at a temperature of 115\u00b0C (239\u00b0F) at an atmospheric pressure of 1 bar. At higher pressures, such as those experienced at the bottom of the ocean, the boiling point of water can be as high as 374\u00b0C (705\u00b0F).", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nLook up the boiling point of water.\n\n### Response:\nThe boiling point of water is 100\u00b0C (212\u00b0F), at standard atmospheric pressure. When heated at higher pressures, water's boiling point increases. For example, water will boil at a temperature of 115\u00b0C (239\u00b0F) at an atmospheric pressure of 1 bar. At higher pressures, such as those experienced at the bottom of the ocean, the boiling point of water can be as high as 374\u00b0C (705\u00b0F)."}, {"instruction": "Describe the life and reign of King Charles II.", "input": "", "output": "King Charles II ascended to the throne of England following the English Civil War and the death of his father, Charles I. He was crowned in 1661 and his reign lasted until his death in 1685. Charles II was also a patron of the arts, working with some of the greatest painters and sculptors of his day. He also encouraged theater and sports and had an interest in sciences such as astronomy and natural history. Charles II was a moderate ruler who respected the civil and religious liberties of his citizens. He was tolerant of other religions and gave Jews and Catholics more rights in the kingdom. Charles II was also a strong supporter of the Royal Navy, which saw England increase its power in the world. He also promoted the development of industry and commerce, which helped to modernize the country. Charles II's reign was an important step in restoring the monarchy of England after the turmoil of the English Civil War.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the life and reign of King Charles II.\n\n### Response:\nKing Charles II ascended to the throne of England following the English Civil War and the death of his father, Charles I. He was crowned in 1661 and his reign lasted until his death in 1685. Charles II was also a patron of the arts, working with some of the greatest painters and sculptors of his day. He also encouraged theater and sports and had an interest in sciences such as astronomy and natural history. Charles II was a moderate ruler who respected the civil and religious liberties of his citizens. He was tolerant of other religions and gave Jews and Catholics more rights in the kingdom. Charles II was also a strong supporter of the Royal Navy, which saw England increase its power in the world. He also promoted the development of industry and commerce, which helped to modernize the country. Charles II's reign was an important step in restoring the monarchy of England after the turmoil of the English Civil War."}, {"instruction": "Find the area of a circle given its radius.", "input": "Radius = 4", "output": "The area of a circle with a radius of 4 is equal to 12.5664 square units. This is calculated by using the formula A = \u03c0r2, where A is the area, \u03c0 is roughly equal to 3.1416 and r is the radius of the circle.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nFind the area of a circle given its radius.\n\n### Input:\nRadius = 4\n\n### Response:\nThe area of a circle with a radius of 4 is equal to 12.5664 square units. This is calculated by using the formula A = \u03c0r2, where A is the area, \u03c0 is roughly equal to 3.1416 and r is the radius of the circle."}, {"instruction": "Identify the lines of longitude that divides North and South America.", "input": "", "output": "The lines of longitude that divides North and South America is the line of the Panama Canal. This line of longitude cuts through the Isthmus of Panama and separates South America from Central America. It is approximately 800 km (500 mi) in total length and was completed in 1914. The Panama Canal is an important shipping route and serves as an international waterway connecting the Caribbean Sea with the Pacific Ocean.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nIdentify the lines of longitude that divides North and South America.\n\n### Response:\nThe lines of longitude that divides North and South America is the line of the Panama Canal. This line of longitude cuts through the Isthmus of Panama and separates South America from Central America. It is approximately 800 km (500 mi) in total length and was completed in 1914. The Panama Canal is an important shipping route and serves as an international waterway connecting the Caribbean Sea with the Pacific Ocean."}, {"instruction": "Explain the process of cellular respiration in plants.", "input": "", "output": "Cellular respiration in plants is a complex process in which the plant converts glucose into usable energy. The process begins when the plant takes in carbon dioxide from the environment and combines it with water to form glucose molecules using the energy from sunlight. This glucose is then broken down into simpler molecules of carbon dioxide and water. During this process, the energy released is used to produce adenosine triphosphate (ATP) molecules. ATP is used by the plant to power other processes, such as growth and reproduction. Additionally, the ATP can also be broken down further to produce heat energy, which helps to sustain the plant's temperature. The carbon dioxide and water molecules produced in the process are then released back into the environment to complete the cycle.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the process of cellular respiration in plants.\n\n### Response:\nCellular respiration in plants is a complex process in which the plant converts glucose into usable energy. The process begins when the plant takes in carbon dioxide from the environment and combines it with water to form glucose molecules using the energy from sunlight. This glucose is then broken down into simpler molecules of carbon dioxide and water. During this process, the energy released is used to produce adenosine triphosphate (ATP) molecules. ATP is used by the plant to power other processes, such as growth and reproduction. Additionally, the ATP can also be broken down further to produce heat energy, which helps to sustain the plant's temperature. The carbon dioxide and water molecules produced in the process are then released back into the environment to complete the cycle."}, {"instruction": "Rewrite the following sentence using active voice.", "input": "The news report was read by the captain.", "output": "The captain read the news report.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nRewrite the following sentence using active voice.\n\n### Input:\nThe news report was read by the captain.\n\n### Response:\nThe captain read the news report."}, {"instruction": "Generate a list of adjectives that describe a person as brave.", "input": "", "output": "Courageous, heroic, audacious, vigorous, valorous, resolute, intrepid, bold, brave, daring, tenacious, plucky, stout-hearted, doughty, gallant, fearless, undaunted, spirited, daredevil.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of adjectives that describe a person as brave.\n\n### Response:\nCourageous, heroic, audacious, vigorous, valorous, resolute, intrepid, bold, brave, daring, tenacious, plucky, stout-hearted, doughty, gallant, fearless, undaunted, spirited, daredevil."}, {"instruction": "Outline the consequences of deforestation.", "input": "", "output": "Deforestation is the process of clearing and removing forests. The consequences of deforestation can include soil erosion, reduced water quality, decreased biodiversity, increased carbon dioxide in the atmosphere and loss of habitat for wildlife. Soil erosion occurs when the trees that help retain water and soil are removed, leading to soil particles being washed away. This can cause siltation in rivers and streams, affecting water quality and leading to flooding. Additionally, when trees are removed, wildlife lose their habitat. This can cause species to become endangered, leading to further loss of biodiversity. Deforestation can also lead to increased carbon dioxide in the atmosphere as there are fewer trees to take up the gas during photosynthesis, which can contribute to climate change.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nOutline the consequences of deforestation.\n\n### Response:\nDeforestation is the process of clearing and removing forests. The consequences of deforestation can include soil erosion, reduced water quality, decreased biodiversity, increased carbon dioxide in the atmosphere and loss of habitat for wildlife. Soil erosion occurs when the trees that help retain water and soil are removed, leading to soil particles being washed away. This can cause siltation in rivers and streams, affecting water quality and leading to flooding. Additionally, when trees are removed, wildlife lose their habitat. This can cause species to become endangered, leading to further loss of biodiversity. Deforestation can also lead to increased carbon dioxide in the atmosphere as there are fewer trees to take up the gas during photosynthesis, which can contribute to climate change."}, {"instruction": "Develop a plan to reduce electricity usage in a home.", "input": "", "output": "A plan to reduce electricity usage in a home can include a number of different strategies. These include replacing inefficient lightbulbs with energy-efficient LEDs, unplugging appliances and electronics when not in use, using natural light instead of electric lights where possible, using a touch-free faucet to reduce water heater use, using a smart thermostat to monitor and adjust temperatures automatically, installing a solar hot water heater, and turning off electronics when not in use. Additionally, making sure to install and seal windows and doors correctly to reduce air leakage can also help reduce electricity usage.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDevelop a plan to reduce electricity usage in a home.\n\n### Response:\nA plan to reduce electricity usage in a home can include a number of different strategies. These include replacing inefficient lightbulbs with energy-efficient LEDs, unplugging appliances and electronics when not in use, using natural light instead of electric lights where possible, using a touch-free faucet to reduce water heater use, using a smart thermostat to monitor and adjust temperatures automatically, installing a solar hot water heater, and turning off electronics when not in use. Additionally, making sure to install and seal windows and doors correctly to reduce air leakage can also help reduce electricity usage."}, {"instruction": "Arrange the words in the given sentence to form a grammatically correct sentence.", "input": "quickly the brown fox jumped", "output": "The quick brown fox jumped quickly.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nArrange the words in the given sentence to form a grammatically correct sentence.\n\n### Input:\nquickly the brown fox jumped\n\n### Response:\nThe quick brown fox jumped quickly."}, {"instruction": "Analyze the given text for its tone.", "input": "The world has been greatly impacted by the COVID-19 pandemic and it has drastically changed our lives.", "output": "The tone of the text is one of concern and recognition of the dramatic effects of the COVID-19 pandemic on the world. It is an acknowledgement of the upheaval that the virus has caused and that our lives have been drastically changed.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nAnalyze the given text for its tone.\n\n### Input:\nThe world has been greatly impacted by the COVID-19 pandemic and it has drastically changed our lives.\n\n### Response:\nThe tone of the text is one of concern and recognition of the dramatic effects of the COVID-19 pandemic on the world. It is an acknowledgement of the upheaval that the virus has caused and that our lives have been drastically changed."}, {"instruction": "Use the given data to calculate the median.", "input": "[2, 3, 7, 8, 10]", "output": "The median of the given data is 7.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nUse the given data to calculate the median.\n\n### Input:\n[2, 3, 7, 8, 10]\n\n### Response:\nThe median of the given data is 7."}, {"instruction": "Rewrite the given paragraph in a shorter, easier to understand form.", "input": "Although it is generally accepted that the internet has allowed us to connect with people all over the world, there are still those people who are not familiar with its basic functions, who don\u2019t understand why it has become so commonplace, or what its true capabilities are.", "output": "The internet has allowed us to connect globally, even though some people are unfamiliar with its functions, don't comprehend its widespread use, and don't know its full potential.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nRewrite the given paragraph in a shorter, easier to understand form.\n\n### Input:\nAlthough it is generally accepted that the internet has allowed us to connect with people all over the world, there are still those people who are not familiar with its basic functions, who don\u2019t understand why it has become so commonplace, or what its true capabilities are.\n\n### Response:\nThe internet has allowed us to connect globally, even though some people are unfamiliar with its functions, don't comprehend its widespread use, and don't know its full potential."}, {"instruction": "Design a logo for a website about recycling.", "input": "", "output": "The logo should feature a green motif that is recognizable and memorable, such as a recycle symbol or a stylized representation of a recycling bin. The logo should also include text that communicates the website's purpose, such as \"Recycling\" or \"Go Green\".", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDesign a logo for a website about recycling.\n\n### Response:\nThe logo should feature a green motif that is recognizable and memorable, such as a recycle symbol or a stylized representation of a recycling bin. The logo should also include text that communicates the website's purpose, such as \"Recycling\" or \"Go Green\"."}, {"instruction": "Generate a poem that expresses joy.", "input": "", "output": "Joy flows through me like a river clear,\nBringing laughter and bliss to my soul to hear.\n\nThe world is bright and wonderful, I can see,\nMaking every moment something happy and free.\n\nMy heart feels light, my cares have flown,\nAs I discover what happiness gives me alone.\n\nI open my arms and accept the joy,\nBringing life to my song, so sweet like a toy.\n\nEvery moment is a precious gift,\nThat I savor and never let drift.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a poem that expresses joy.\n\n### Response:\nJoy flows through me like a river clear,\nBringing laughter and bliss to my soul to hear.\n\nThe world is bright and wonderful, I can see,\nMaking every moment something happy and free.\n\nMy heart feels light, my cares have flown,\nAs I discover what happiness gives me alone.\n\nI open my arms and accept the joy,\nBringing life to my song, so sweet like a toy.\n\nEvery moment is a precious gift,\nThat I savor and never let drift."}, {"instruction": "Convert the given equation into an algebraic expression.", "input": "3x+5y=9", "output": "3x + 5y - 9 = 0", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConvert the given equation into an algebraic expression.\n\n### Input:\n3x+5y=9\n\n### Response:\n3x + 5y - 9 = 0"}, {"instruction": "Brainstorm possible solutions to reduce water pollution.", "input": "", "output": "- Implement stricter regulations on industrial waste disposal\n- Increase the use of green technology in wastewater treatment\n- Provide incentives to businesses that reduce their water consumption and waste generation\n- Educate citizens about water conservation and pollution prevention\n- Offer recycling programs to reduce plastic waste\n- Replace disposable plastics with biodegradable alternatives\n- Improve sewer systems\n- Re-introduce native plant species\n- Foster artificial wetlands and green infrastructure\n- Provide filtration systems at water treatment plants\n- Implement best practices for agricultural runoff\n- Increase research and monitoring of water quality\n- Establish buffer zones between aquatic environments and land-based activities.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nBrainstorm possible solutions to reduce water pollution.\n\n### Response:\n- Implement stricter regulations on industrial waste disposal\n- Increase the use of green technology in wastewater treatment\n- Provide incentives to businesses that reduce their water consumption and waste generation\n- Educate citizens about water conservation and pollution prevention\n- Offer recycling programs to reduce plastic waste\n- Replace disposable plastics with biodegradable alternatives\n- Improve sewer systems\n- Re-introduce native plant species\n- Foster artificial wetlands and green infrastructure\n- Provide filtration systems at water treatment plants\n- Implement best practices for agricultural runoff\n- Increase research and monitoring of water quality\n- Establish buffer zones between aquatic environments and land-based activities."}, {"instruction": "Explain why the given definition is wrong.", "input": "A mole is an animal that lives underground.", "output": "The definition is wrong because a mole is actually a small mammal that lives above ground and feeds on insects. It is most distinguishable by its long snout and wide, paddle-like feet. It does, however, often dig burrows in which it lives and shelters.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain why the given definition is wrong.\n\n### Input:\nA mole is an animal that lives underground.\n\n### Response:\nThe definition is wrong because a mole is actually a small mammal that lives above ground and feeds on insects. It is most distinguishable by its long snout and wide, paddle-like feet. It does, however, often dig burrows in which it lives and shelters."}, {"instruction": "Variable x is defined as \u201c4x + 2y = 10\u201d. Find the value of x.", "input": "", "output": "The value of x is 2. To find the value, simplify the equation by subtracting 2y from both sides, giving 4x = 10; dividing both sides by 4, giving x = 2/4, which is equal to 2.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nVariable x is defined as \u201c4x + 2y = 10\u201d. Find the value of x.\n\n### Response:\nThe value of x is 2. To find the value, simplify the equation by subtracting 2y from both sides, giving 4x = 10; dividing both sides by 4, giving x = 2/4, which is equal to 2."}, {"instruction": "Write a short paragraph about the given topic.", "input": "The importance of using renewable energy", "output": "The use of renewable energy is growing rapidly in relevance and importance as the world looks towards solutions to combat climate change. Renewable energy sources, such as solar, wind, and hydropower, are sustainable and have a much lower environmental impact than traditional energy sources like coal and oil. Switching to renewable energy can reduce carbon emissions, improve air quality, lessen the dependence on non-renewable resources, and provide a stable and secure energy future. Renewable energy is essential for achieving a sustainable future and should be an important part of any country\u2019s energy portfolio.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a short paragraph about the given topic.\n\n### Input:\nThe importance of using renewable energy\n\n### Response:\nThe use of renewable energy is growing rapidly in relevance and importance as the world looks towards solutions to combat climate change. Renewable energy sources, such as solar, wind, and hydropower, are sustainable and have a much lower environmental impact than traditional energy sources like coal and oil. Switching to renewable energy can reduce carbon emissions, improve air quality, lessen the dependence on non-renewable resources, and provide a stable and secure energy future. Renewable energy is essential for achieving a sustainable future and should be an important part of any country\u2019s energy portfolio."}, {"instruction": "Explain the concept of artificial intelligence in simple terms.", "input": "", "output": "Artificial Intelligence (AI) is the ability of machines and computers to make decisions and complete tasks similar to what humans can do. AI can learn from data, recognize patterns, and make predictions. AI technology can automate tedious and complex processes, helping to reduce human effort and errors, and enable more accurate and efficient decisions. AI can be used in a wide range of applications, from robotics and autonomous vehicles to healthcare and finance. It is increasingly becoming an integral part of everyday life.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the concept of artificial intelligence in simple terms.\n\n### Response:\nArtificial Intelligence (AI) is the ability of machines and computers to make decisions and complete tasks similar to what humans can do. AI can learn from data, recognize patterns, and make predictions. AI technology can automate tedious and complex processes, helping to reduce human effort and errors, and enable more accurate and efficient decisions. AI can be used in a wide range of applications, from robotics and autonomous vehicles to healthcare and finance. It is increasingly becoming an integral part of everyday life."}, {"instruction": "Design an app for a delivery company.", "input": "", "output": "The app should allow customers to order items from local businesses and nearby store, place delivery requests with detailed delivery requirements, and track their orders. It should also allow customers to easily find and pay for services. The app should allow businesses to manage orders, track deliveries and drivers, set delivery boundaries, and set prices. It should also have features to help businesses promote their products and services, as well as options to provide customer feedback and support.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDesign an app for a delivery company.\n\n### Response:\nThe app should allow customers to order items from local businesses and nearby store, place delivery requests with detailed delivery requirements, and track their orders. It should also allow customers to easily find and pay for services. The app should allow businesses to manage orders, track deliveries and drivers, set delivery boundaries, and set prices. It should also have features to help businesses promote their products and services, as well as options to provide customer feedback and support."}, {"instruction": "Summarize the given passage.", "input": "A recent study showed that global climate change is one of the most important challenges facing the world today. The consequences of global warming include rising sea levels, extreme weather events and biodiversity loss. Climate change is caused by human activities such as burning fossil fuels, and if unchecked will have long-term and devastating impacts on the planet.", "output": "A recent study revealed that global climate change is one of the world\u2019s most pressing issues. In addition to rising sea levels, extreme weather events and biodiversity loss, it is caused largely by human activities like burning fossil fuels. If left unchecked, climate change will have far-reaching and devastating consequences.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nSummarize the given passage.\n\n### Input:\nA recent study showed that global climate change is one of the most important challenges facing the world today. The consequences of global warming include rising sea levels, extreme weather events and biodiversity loss. Climate change is caused by human activities such as burning fossil fuels, and if unchecked will have long-term and devastating impacts on the planet.\n\n### Response:\nA recent study revealed that global climate change is one of the world\u2019s most pressing issues. In addition to rising sea levels, extreme weather events and biodiversity loss, it is caused largely by human activities like burning fossil fuels. If left unchecked, climate change will have far-reaching and devastating consequences."}, {"instruction": "Extract the facts from the paragraph.", "input": "Online education continues to become more popular for schools and students alike. Its advantages are generally lower costs, less commitment and the ability to study at a time, place and pace that suits the student.", "output": "- Online education is becoming increasingly popular.\n- It has several advantages such as lower costs, less commitment and the ability to study at one\u2019s own time and pace.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nExtract the facts from the paragraph.\n\n### Input:\nOnline education continues to become more popular for schools and students alike. Its advantages are generally lower costs, less commitment and the ability to study at a time, place and pace that suits the student.\n\n### Response:\n- Online education is becoming increasingly popular.\n- It has several advantages such as lower costs, less commitment and the ability to study at one\u2019s own time and pace."}, {"instruction": "Edit the following sentence to make it more concise.", "input": "He ran to the bus stop in order to catch the bus that was due to arrive in five minutes.", "output": "He ran to the bus stop, due to arrive in five minutes.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nEdit the following sentence to make it more concise.\n\n### Input:\nHe ran to the bus stop in order to catch the bus that was due to arrive in five minutes.\n\n### Response:\nHe ran to the bus stop, due to arrive in five minutes."}, {"instruction": "Generate a poem with 10 lines.", "input": "", "output": "The evening light shimmers on the shore\nSoftly the waves echoes around and more \nAs I bask in the sun, my worries are all gone\nThe sound of seagulls I now foolishly ignore \nGlistening sand, beckons me with a silent plea \nGlistening seawater, cool to the touch and refreshingly free \nThe evening brings peace, yet I can't find any \nBut maybe in the morning there'll be time for me\nMy bottled peacefulness, I uncork and pour \nThe sound of the ocean, lulls me even more \nAnd for just a moment I close my eyes and behold \nThe vastness of the ocean, to my soul I now unfold.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a poem with 10 lines.\n\n### Response:\nThe evening light shimmers on the shore\nSoftly the waves echoes around and more \nAs I bask in the sun, my worries are all gone\nThe sound of seagulls I now foolishly ignore \nGlistening sand, beckons me with a silent plea \nGlistening seawater, cool to the touch and refreshingly free \nThe evening brings peace, yet I can't find any \nBut maybe in the morning there'll be time for me\nMy bottled peacefulness, I uncork and pour \nThe sound of the ocean, lulls me even more \nAnd for just a moment I close my eyes and behold \nThe vastness of the ocean, to my soul I now unfold."}, {"instruction": "Convert from celsius to fahrenheit.", "input": "Temperature in Celsius: 15", "output": "Temperature in Fahrenheit: 59", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConvert from celsius to fahrenheit.\n\n### Input:\nTemperature in Celsius: 15\n\n### Response:\nTemperature in Fahrenheit: 59"}, {"instruction": "Arrange the given numbers in ascending order.", "input": "2, 4, 0, 8, 3", "output": "0, 2, 3, 4, 8", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nArrange the given numbers in ascending order.\n\n### Input:\n2, 4, 0, 8, 3\n\n### Response:\n0, 2, 3, 4, 8"}, {"instruction": "Calculate the total surface area of a cube with a side length of 5 cm.", "input": "", "output": "Total surface area of the cube: 150 cm^2", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCalculate the total surface area of a cube with a side length of 5 cm.\n\n### Response:\nTotal surface area of the cube: 150 cm^2"}, {"instruction": "What is the force on a 1 kg mass due to the gravitational force?", "input": "", "output": "The force on a 1 kg mass due to the gravitational force is 9.8 N.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is the force on a 1 kg mass due to the gravitational force?\n\n### Response:\nThe force on a 1 kg mass due to the gravitational force is 9.8 N."}, {"instruction": "Provide one example for a cultural practice.", "input": "", "output": "An example of a cultural practice is the Japanese tea ceremony.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nProvide one example for a cultural practice.\n\n### Response:\nAn example of a cultural practice is the Japanese tea ceremony."}, {"instruction": "Given a set of numbers, find the maximum value.", "input": "Set: {10, 3, 25, 6, 16}", "output": "The maximum value is 25.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGiven a set of numbers, find the maximum value.\n\n### Input:\nSet: {10, 3, 25, 6, 16}\n\n### Response:\nThe maximum value is 25."}, {"instruction": "Give two examples of a liquid.", "input": "", "output": "Two examples of liquids are water and oil.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive two examples of a liquid.\n\n### Response:\nTwo examples of liquids are water and oil."}, {"instruction": "What is the product of 6 and 2?", "input": "", "output": "The product of 6 and 2 is 12.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is the product of 6 and 2?\n\n### Response:\nThe product of 6 and 2 is 12."}, {"instruction": "What type of plant is a skunk cabbage?", "input": "", "output": "Skunk cabbage is a type of perennial flowering plant.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat type of plant is a skunk cabbage?\n\n### Response:\nSkunk cabbage is a type of perennial flowering plant."}, {"instruction": "Convert the given binary number to its decimal equivalent.", "input": "101101", "output": "The decimal equivalent of 101101 is 45.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConvert the given binary number to its decimal equivalent.\n\n### Input:\n101101\n\n### Response:\nThe decimal equivalent of 101101 is 45."}, {"instruction": "Name two types of desert biomes.", "input": "", "output": "Two types of desert biomes are xeric and subpolar deserts.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nName two types of desert biomes.\n\n### Response:\nTwo types of desert biomes are xeric and subpolar deserts."}, {"instruction": "Given a sentence, convert it into passive voice.", "input": "The dog barked at the alley cat.", "output": "The alley cat was barked at by the dog.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGiven a sentence, convert it into passive voice.\n\n### Input:\nThe dog barked at the alley cat.\n\n### Response:\nThe alley cat was barked at by the dog."}, {"instruction": "Transform the following sentence into the passive voice", "input": "I bought a book", "output": "A book was bought by me.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nTransform the following sentence into the passive voice\n\n### Input:\nI bought a book\n\n### Response:\nA book was bought by me."}, {"instruction": "Create a dialog between two people who are discussing a scientific phenomenon", "input": "Hydropower", "output": "Person 1: Did you know that hydropower is now being used to generate electricity?\nPerson 2: Really? That\u2019s incredible! How does it work?\nPerson: 1: Well, hydropower utilizes the energy of moving water to generate electricity. By capturing and harnessing the energy of falling or flowing water, hydropower can be used for efficient electricity generation.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a dialog between two people who are discussing a scientific phenomenon\n\n### Input:\nHydropower\n\n### Response:\nPerson 1: Did you know that hydropower is now being used to generate electricity?\nPerson 2: Really? That\u2019s incredible! How does it work?\nPerson: 1: Well, hydropower utilizes the energy of moving water to generate electricity. By capturing and harnessing the energy of falling or flowing water, hydropower can be used for efficient electricity generation."}, {"instruction": "Identify the most suitable adverb for the following sentence", "input": "She works", "output": "She works diligently.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nIdentify the most suitable adverb for the following sentence\n\n### Input:\nShe works\n\n### Response:\nShe works diligently."}, {"instruction": "Find the main idea of the following passage", "input": "Climate change is an increasingly contentious global issue. It is already having a drastic effect on the environment and human lives. We are already seeing an increase in extreme weather events and rising sea levels, both of which have a major impact on the planet.", "output": "The main idea of the passage is that climate change is a global issue that is having a detrimental effect on the environment and has caused an increase in extreme weather events and rising sea levels.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nFind the main idea of the following passage\n\n### Input:\nClimate change is an increasingly contentious global issue. It is already having a drastic effect on the environment and human lives. We are already seeing an increase in extreme weather events and rising sea levels, both of which have a major impact on the planet.\n\n### Response:\nThe main idea of the passage is that climate change is a global issue that is having a detrimental effect on the environment and has caused an increase in extreme weather events and rising sea levels."}, {"instruction": "Analyze the tone of the following sentences", "input": "I am so upset with you right now", "output": "The tone of the sentences is angry and accusatory.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nAnalyze the tone of the following sentences\n\n### Input:\nI am so upset with you right now\n\n### Response:\nThe tone of the sentences is angry and accusatory."}, {"instruction": "Construct an argument to defend the following statement", "input": "Alternative energy sources are critical to solving the climate crisis", "output": "Alternative energy sources are absolutely vital in order to mitigate the most detrimental impacts of climate change. As the world transitions away from fossil fuels and towards more sustainable sources of energy, such as wind, solar, and hydropower, we can reduce the number of greenhouse gases in the atmosphere and prevent further warming of our planet. Additionally, using renewable and alternative energies sources will create thousands of jobs and can help to create a healthy, stable economy. Investing in alternative energy sources is a sound decision and a necessary step if we want to effectively address the climate crisis.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConstruct an argument to defend the following statement\n\n### Input:\nAlternative energy sources are critical to solving the climate crisis\n\n### Response:\nAlternative energy sources are absolutely vital in order to mitigate the most detrimental impacts of climate change. As the world transitions away from fossil fuels and towards more sustainable sources of energy, such as wind, solar, and hydropower, we can reduce the number of greenhouse gases in the atmosphere and prevent further warming of our planet. Additionally, using renewable and alternative energies sources will create thousands of jobs and can help to create a healthy, stable economy. Investing in alternative energy sources is a sound decision and a necessary step if we want to effectively address the climate crisis."}, {"instruction": "Convert the following sentence into the present continuous tense", "input": "He reads books", "output": "He is reading books.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConvert the following sentence into the present continuous tense\n\n### Input:\nHe reads books\n\n### Response:\nHe is reading books."}, {"instruction": "Give an example of a metaphor that uses the following object", "input": "Stars", "output": "The stars were twinkling diamonds in the night sky.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGive an example of a metaphor that uses the following object\n\n### Input:\nStars\n\n### Response:\nThe stars were twinkling diamonds in the night sky."}, {"instruction": "Describe the following person", "input": "John", "output": "John is a tall, dark-haired man with a kind smile and a twinkle in his eyes. He has a gentle, caring nature, and he always strives to understand others. He has a passion for learning and exploring new places, experiences and cultures.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the following person\n\n### Input:\nJohn\n\n### Response:\nJohn is a tall, dark-haired man with a kind smile and a twinkle in his eyes. He has a gentle, caring nature, and he always strives to understand others. He has a passion for learning and exploring new places, experiences and cultures."}, {"instruction": "Construct a mathematical problem with the following numbers", "input": "21, 4", "output": "If 21 pieces of candy are divided among 4 people, how many pieces of candy does each person get?", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConstruct a mathematical problem with the following numbers\n\n### Input:\n21, 4\n\n### Response:\nIf 21 pieces of candy are divided among 4 people, how many pieces of candy does each person get?"}, {"instruction": "Aim to reduce the following sentence without changing its meaning", "input": "When he came in, he woke everyone up", "output": "He woke everyone on entering.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nAim to reduce the following sentence without changing its meaning\n\n### Input:\nWhen he came in, he woke everyone up\n\n### Response:\nHe woke everyone on entering."}, {"instruction": "Identify the conjunctions in the following sentence", "input": "She wrote a letter and sealed it", "output": "The conjunctions in the sentence are \"and\".", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nIdentify the conjunctions in the following sentence\n\n### Input:\nShe wrote a letter and sealed it\n\n### Response:\nThe conjunctions in the sentence are \"and\"."}, {"instruction": "Rewrite the following sentence in the third person", "input": "I am anxious", "output": "She is anxious.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nRewrite the following sentence in the third person\n\n### Input:\nI am anxious\n\n### Response:\nShe is anxious."}, {"instruction": "Generate a list of business ideas for a food delivery service.", "input": "", "output": "- Offer discounts and loyalty programs for frequent customers.\n- Partner with local restaurants and promote their food.\n- Train delivery personnel to handle orders and provide good customer service.\n- Use mobile apps and websites for ordering.\n- Offer diverse payment options for customers.\n- Allow customers to customize their orders.\n- Pack food in recyclable and compostable packaging.\n- Launch a subscription model for regular customers.\n- Offer special services like food suggestions based on customers' preferences.\n- Provide contactless delivery and curbside pickup.\n- Provide catering services for large orders and special events.\n- Use online marketing and social media platforms to reach more customers.\n- Create incentives to encourage customers to try new dishes.\n- Create meal kits with all the necessary ingredients and a recipe.\n- Make use of food delivery robots and drones.\n- Offer a loyalty points system for customers.\n- Collaborate with local farmers to source produce and ingredients.\n- Offer vegetarian, vegan, and gluten-free options.\n- Allow customers to select the delivery time.\n- Implement AI-powered recommendations for customers.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of business ideas for a food delivery service.\n\n### Response:\n- Offer discounts and loyalty programs for frequent customers.\n- Partner with local restaurants and promote their food.\n- Train delivery personnel to handle orders and provide good customer service.\n- Use mobile apps and websites for ordering.\n- Offer diverse payment options for customers.\n- Allow customers to customize their orders.\n- Pack food in recyclable and compostable packaging.\n- Launch a subscription model for regular customers.\n- Offer special services like food suggestions based on customers' preferences.\n- Provide contactless delivery and curbside pickup.\n- Provide catering services for large orders and special events.\n- Use online marketing and social media platforms to reach more customers.\n- Create incentives to encourage customers to try new dishes.\n- Create meal kits with all the necessary ingredients and a recipe.\n- Make use of food delivery robots and drones.\n- Offer a loyalty points system for customers.\n- Collaborate with local farmers to source produce and ingredients.\n- Offer vegetarian, vegan, and gluten-free options.\n- Allow customers to select the delivery time.\n- Implement AI-powered recommendations for customers."}, {"instruction": "Edit the following sentence to improve clarity and flow.", "input": "The restaurant has great food though the prices are a bit hefty", "output": "The restaurant offers great food but at a slightly higher price.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nEdit the following sentence to improve clarity and flow.\n\n### Input:\nThe restaurant has great food though the prices are a bit hefty\n\n### Response:\nThe restaurant offers great food but at a slightly higher price."}, {"instruction": "Imagine you are speaking with a customer who is unsatisfied with the product they bought from you. Construct a response that diffuses the situation.", "input": "", "output": "I'm sorry that you're not happy with your purchase. We want all of our customers to be satisfied, so I'd like to offer you a full refund, as well as a 15% discount for your next purchase. Is there anything else I can do to make this right for you?", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nImagine you are speaking with a customer who is unsatisfied with the product they bought from you. Construct a response that diffuses the situation.\n\n### Response:\nI'm sorry that you're not happy with your purchase. We want all of our customers to be satisfied, so I'd like to offer you a full refund, as well as a 15% discount for your next purchase. Is there anything else I can do to make this right for you?"}, {"instruction": "Explain the concept of a bubble sort algorithm to a non-technical audience.", "input": "", "output": "A bubble sort algorithm is a type of sorting algorithm that is used to sort elements in an array. It works by looking at each element of the array and comparing it to the next element. If the first element is bigger than the second element, they are swapped. This process is repeated until the whole array is sorted. This type of sorting is one of the simplest sorting algorithms, but it can be slow if the array has many elements.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the concept of a bubble sort algorithm to a non-technical audience.\n\n### Response:\nA bubble sort algorithm is a type of sorting algorithm that is used to sort elements in an array. It works by looking at each element of the array and comparing it to the next element. If the first element is bigger than the second element, they are swapped. This process is repeated until the whole array is sorted. This type of sorting is one of the simplest sorting algorithms, but it can be slow if the array has many elements."}, {"instruction": "Classify the following incident as a breach of protocol. Output 1 for breach, and 0 for no breach.", "input": "Using a school laptop for personal use", "output": "1", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nClassify the following incident as a breach of protocol. Output 1 for breach, and 0 for no breach.\n\n### Input:\nUsing a school laptop for personal use\n\n### Response:\n1"}, {"instruction": "Given the following input, construct a creative story.", "input": "A magic bow and arrow", "output": "Once upon a time, there lived a young girl named Alexa who was gifted with an incredible magical bow and arrow. Whenever she pulled back her bow and let an arrow fly, wherever it landed, something extraordinary happened. Flowers bloomed, oceans calmed, and the sun shone brighter. Alexa's bow and arrow were so powerful, that it could make even the most impossible things possible. One day, during a great storm, Alexa used her magical bow and arrow to bring calm and harmony to her hometown. She was praised and celebrated by the whole town, and she soon became a symbol of hope and peace.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGiven the following input, construct a creative story.\n\n### Input:\nA magic bow and arrow\n\n### Response:\nOnce upon a time, there lived a young girl named Alexa who was gifted with an incredible magical bow and arrow. Whenever she pulled back her bow and let an arrow fly, wherever it landed, something extraordinary happened. Flowers bloomed, oceans calmed, and the sun shone brighter. Alexa's bow and arrow were so powerful, that it could make even the most impossible things possible. One day, during a great storm, Alexa used her magical bow and arrow to bring calm and harmony to her hometown. She was praised and celebrated by the whole town, and she soon became a symbol of hope and peace."}, {"instruction": "Generate a list of methods to reduce food waste.", "input": "", "output": "- Shop smarter and only buy what you need.\n- Save leftovers for later and freeze them if necessary. \n- Monitor expiration dates and consume items before they go bad.\n- Participate in community or public food sharing initiatives.\n- Compost leftover food scraps.\n- Reuse containers, jars, and bags.\n- Make use of restaurant discounts to reduce over-ordering. \n- Donate foods that have short expiration dates to local food banks. \n- Buy more produce that is close to the expiration date and prepare it soon.\n- Order small portions while dining out.\n- Educate yourself and others about food waste.\n- Plan meals in advance and base your grocery list on the meals. \n- Store food correctly to reduce spoilage.\n- Make sure your fridge is at the right temperature.\n- Have a system to track leftovers in the fridge. \n- Eat all edible parts of food, including stems and skins.\n- Use creative recipes to use the entire food item.\n- Use food that would otherwise be thrown away for animal feed.\n- Take part in campaigns to make food sustainability an issue.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of methods to reduce food waste.\n\n### Response:\n- Shop smarter and only buy what you need.\n- Save leftovers for later and freeze them if necessary. \n- Monitor expiration dates and consume items before they go bad.\n- Participate in community or public food sharing initiatives.\n- Compost leftover food scraps.\n- Reuse containers, jars, and bags.\n- Make use of restaurant discounts to reduce over-ordering. \n- Donate foods that have short expiration dates to local food banks. \n- Buy more produce that is close to the expiration date and prepare it soon.\n- Order small portions while dining out.\n- Educate yourself and others about food waste.\n- Plan meals in advance and base your grocery list on the meals. \n- Store food correctly to reduce spoilage.\n- Make sure your fridge is at the right temperature.\n- Have a system to track leftovers in the fridge. \n- Eat all edible parts of food, including stems and skins.\n- Use creative recipes to use the entire food item.\n- Use food that would otherwise be thrown away for animal feed.\n- Take part in campaigns to make food sustainability an issue."}, {"instruction": "Propose an idea for a self-driving car.", "input": "", "output": "My idea for a self-driving car is to develop an AI system that uses machine learning to help the car recognize its environment. The AI system will also be able to learn from its mistakes and increase its accuracy. Additionally, the self-driving car will be able to detect potential hazards and warn the driver, or take corrective action if necessary. The car will also be equipped with an emergency button, in case the driver needs to take manual control of the vehicle. Furthermore, the car will be capable of automated parking and will be able to drive itself to a nearby charging station when the battery is low. Finally, the software will be able to connect with smartphones, so the user can access their navigation system and entertainment options even while on the move.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nPropose an idea for a self-driving car.\n\n### Response:\nMy idea for a self-driving car is to develop an AI system that uses machine learning to help the car recognize its environment. The AI system will also be able to learn from its mistakes and increase its accuracy. Additionally, the self-driving car will be able to detect potential hazards and warn the driver, or take corrective action if necessary. The car will also be equipped with an emergency button, in case the driver needs to take manual control of the vehicle. Furthermore, the car will be capable of automated parking and will be able to drive itself to a nearby charging station when the battery is low. Finally, the software will be able to connect with smartphones, so the user can access their navigation system and entertainment options even while on the move."}, {"instruction": "Explain the concept of a linear regression model using simple language.", "input": "", "output": "A linear regression model is a type of machine learning model that can be used to predict or estimate the value of a continuous numerical variable from a set of input variables. In this model, a line is fitted to a set of data points such that the sum of the squared distance of the data points from the line is minimized. This line of best fit can then be used to estimate the value of the output variable based on the values of the input variables.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the concept of a linear regression model using simple language.\n\n### Response:\nA linear regression model is a type of machine learning model that can be used to predict or estimate the value of a continuous numerical variable from a set of input variables. In this model, a line is fitted to a set of data points such that the sum of the squared distance of the data points from the line is minimized. This line of best fit can then be used to estimate the value of the output variable based on the values of the input variables."}, {"instruction": "Construct a list of tips to help reduce stress.", "input": "", "output": "- Take a time out and acquire some rest when feeling overwhelmed. \n- Make time for yourself and learn to relax. \n- Exercise regularly and get plenty of sleep. \n- Listen to music you like, or take a warm bath. \n- Stick to a healthy diet full of essential vitamins and minerals.\n- Reach out to friends and family for support.\n- Try aromatherapy or yoga. \n- Do something creative like drawing, painting, or writing. \n- Engage in activities that bring you joy and peace of mind. \n- Avoid drugs or excessive amounts of alcohol. \n- Practice deep breathing and progressive muscle relaxation. \n- Take up a new hobby to stay occupied. \n- Break big tasks down into smaller, more manageable tasks. \n- Practice positive self-talk to keep your morale high. \n- Talk to a therapist or trusted mentor if necessary. \n- Make time for fun activities and social events. \n- Prioritize tasks and create a weekly schedule. \n- Create a plan of action and identify solutions to problems. \n- Learn to forgive yourself and others.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nConstruct a list of tips to help reduce stress.\n\n### Response:\n- Take a time out and acquire some rest when feeling overwhelmed. \n- Make time for yourself and learn to relax. \n- Exercise regularly and get plenty of sleep. \n- Listen to music you like, or take a warm bath. \n- Stick to a healthy diet full of essential vitamins and minerals.\n- Reach out to friends and family for support.\n- Try aromatherapy or yoga. \n- Do something creative like drawing, painting, or writing. \n- Engage in activities that bring you joy and peace of mind. \n- Avoid drugs or excessive amounts of alcohol. \n- Practice deep breathing and progressive muscle relaxation. \n- Take up a new hobby to stay occupied. \n- Break big tasks down into smaller, more manageable tasks. \n- Practice positive self-talk to keep your morale high. \n- Talk to a therapist or trusted mentor if necessary. \n- Make time for fun activities and social events. \n- Prioritize tasks and create a weekly schedule. \n- Create a plan of action and identify solutions to problems. \n- Learn to forgive yourself and others."}, {"instruction": "Write a review of a recent movie you watched.", "input": "Parasite (2019)", "output": "I recently watched the movie Parasite, directed by Bong Joon-Ho, and I was absolutely blown away. This movie is a masterful blend of genres\u2014it's both a black comedy and a commentary on classism, with a twist ending that no one could see coming. The acting is fantastic, particularly Song Kang-ho as the head of the poor household, and the cinematography is stunning. The movie has a perfect balance of humor and thought-provoking drama, and it will leave you talking about the ending for days afterwards. Overall, I highly recommend Parasite; it is an exceptional movie and an experience you won't soon forget.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a review of a recent movie you watched.\n\n### Input:\nParasite (2019)\n\n### Response:\nI recently watched the movie Parasite, directed by Bong Joon-Ho, and I was absolutely blown away. This movie is a masterful blend of genres\u2014it's both a black comedy and a commentary on classism, with a twist ending that no one could see coming. The acting is fantastic, particularly Song Kang-ho as the head of the poor household, and the cinematography is stunning. The movie has a perfect balance of humor and thought-provoking drama, and it will leave you talking about the ending for days afterwards. Overall, I highly recommend Parasite; it is an exceptional movie and an experience you won't soon forget."}, {"instruction": "Create a Twitter post to promote your new product.", "input": "Alisha's All Natural Jam", "output": "Introducing Alisha's All Natural Jam, handmade with love using only the finest ingredients! Get ready for a delicious treat that will give your tastebuds a flavor explosion! Try our five amazing flavors: Strawberry, Blueberry, Peach, Raspberry, and Blackberry. Buy now and snack with the world! #snacktime #allnatural #jam #tastebuds #freshfruit #homemade #delicious #salttw", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Twitter post to promote your new product.\n\n### Input:\nAlisha's All Natural Jam\n\n### Response:\nIntroducing Alisha's All Natural Jam, handmade with love using only the finest ingredients! Get ready for a delicious treat that will give your tastebuds a flavor explosion! Try our five amazing flavors: Strawberry, Blueberry, Peach, Raspberry, and Blackberry. Buy now and snack with the world! #snacktime #allnatural #jam #tastebuds #freshfruit #homemade #delicious #salttw"}, {"instruction": "Given the following input, classify the verb as transitive or intransitive.", "input": "improve", "output": "Transitive.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGiven the following input, classify the verb as transitive or intransitive.\n\n### Input:\nimprove\n\n### Response:\nTransitive."}, {"instruction": "Generate a list of marketing strategies to promote a new mobile app.", "input": "", "output": "- Implement paid campaigns on social media platforms.\n- Utilize influencer marketing.\n- Reach out to tech bloggers and websites.\n- Share the app through word of mouth among friends.\n- Create demo videos and tutorials to explain the app's features.\n- Offer discounts and special incentives to early adopters.\n- Build a basic website to explain the app's purpose and features. \n- Take advantage of app store optimization techniques.\n- Send out emails and newsletters to inform new users about the app.\n- Aim for awards and good reviews on the app store.\n- Include the app link in your email signature.\n- Try promotional press releases and campaigns.\n- Encourage users to rate and review the app on the app store.\n- Leverage user-generated content to your advantage. \n- Integrate the app with existing or popular apps.\n- Hold giveaways or contests to drive engagement.\n- Feature the app on your company's website and blog.\n- Cross-promote the app with other apps, websites, and social media platforms. \n- Create promotional content such as blogs and podcasts.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of marketing strategies to promote a new mobile app.\n\n### Response:\n- Implement paid campaigns on social media platforms.\n- Utilize influencer marketing.\n- Reach out to tech bloggers and websites.\n- Share the app through word of mouth among friends.\n- Create demo videos and tutorials to explain the app's features.\n- Offer discounts and special incentives to early adopters.\n- Build a basic website to explain the app's purpose and features. \n- Take advantage of app store optimization techniques.\n- Send out emails and newsletters to inform new users about the app.\n- Aim for awards and good reviews on the app store.\n- Include the app link in your email signature.\n- Try promotional press releases and campaigns.\n- Encourage users to rate and review the app on the app store.\n- Leverage user-generated content to your advantage. \n- Integrate the app with existing or popular apps.\n- Hold giveaways or contests to drive engagement.\n- Feature the app on your company's website and blog.\n- Cross-promote the app with other apps, websites, and social media platforms. \n- Create promotional content such as blogs and podcasts."}, {"instruction": "Rewrite the sentence to provide more clarity and flow.", "input": "Making the decision to rent a house was a wise choice", "output": "Renting a house was a wise decision to make.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nRewrite the sentence to provide more clarity and flow.\n\n### Input:\nMaking the decision to rent a house was a wise choice\n\n### Response:\nRenting a house was a wise decision to make."}, {"instruction": "Give a description of the following customer segment.", "input": "Online shoppers in the US aged between 18-25", "output": "This customer segment consists of online shoppers in the US aged between 18-25. They are typically young adults who are tech-savvy and prefer to shop online for convenience and variety. They often shop for a variety of items, ranging from clothing to electronics and are comfortable using different payment methods such as credit or debit cards and digital wallets.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGive a description of the following customer segment.\n\n### Input:\nOnline shoppers in the US aged between 18-25\n\n### Response:\nThis customer segment consists of online shoppers in the US aged between 18-25. They are typically young adults who are tech-savvy and prefer to shop online for convenience and variety. They often shop for a variety of items, ranging from clothing to electronics and are comfortable using different payment methods such as credit or debit cards and digital wallets."}, {"instruction": "Design a logo for a mobile application.", "input": "", "output": "The logo design should be simple, modern, and unique. It should capture the essence of the application and communicate its functions. The logo should feature bold colors that are eye-catching, and the design should be adaptable to various sizes and platforms, such as websites and mobile devices.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDesign a logo for a mobile application.\n\n### Response:\nThe logo design should be simple, modern, and unique. It should capture the essence of the application and communicate its functions. The logo should feature bold colors that are eye-catching, and the design should be adaptable to various sizes and platforms, such as websites and mobile devices."}, {"instruction": "Generate a list of country names in the western hemisphere.", "input": "", "output": "- Canada\n- United States \n- Mexico \n- Brazil\n- Argentina \n- Chile \n- Colombia \n- Venezuela \n- Cuba\n- Jamaica \n- Ecuador\n- Peru \n- Uruguay\n- Paraguay \n- Costa Rica \n- Panama \n- Honduras \n- Nicaragua\n- Guatemala", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of country names in the western hemisphere.\n\n### Response:\n- Canada\n- United States \n- Mexico \n- Brazil\n- Argentina \n- Chile \n- Colombia \n- Venezuela \n- Cuba\n- Jamaica \n- Ecuador\n- Peru \n- Uruguay\n- Paraguay \n- Costa Rica \n- Panama \n- Honduras \n- Nicaragua\n- Guatemala"}, {"instruction": "Translate the following phrase into French.", "input": "I miss you", "output": "Je te manque.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nTranslate the following phrase into French.\n\n### Input:\nI miss you\n\n### Response:\nJe te manque."}, {"instruction": "Explain what an API is.", "input": "", "output": "API stands for Application Programming Interface. It is a set of instructions or programming instructions that allow two applications to communicate and exchange data with each other. APIs provide a way for developers to access and manipulate data and services offered by other applications or platforms. They are an essential part of modern web development and allow developers to develop complex applications and websites using modular components.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain what an API is.\n\n### Response:\nAPI stands for Application Programming Interface. It is a set of instructions or programming instructions that allow two applications to communicate and exchange data with each other. APIs provide a way for developers to access and manipulate data and services offered by other applications or platforms. They are an essential part of modern web development and allow developers to develop complex applications and websites using modular components."}, {"instruction": "Compute the area of a rectangle with length 10cm and width 5cm.", "input": "", "output": "The area of the rectangle is 50 cm2.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCompute the area of a rectangle with length 10cm and width 5cm.\n\n### Response:\nThe area of the rectangle is 50 cm2."}, {"instruction": "Find the capital of Spain.", "input": "", "output": "The capital of Spain is Madrid.", "text": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nFind the capital of Spain.\n\n### Response:\nThe capital of Spain is Madrid."}, {"instruction": "Classify the following statement as true or false.", "input": "The Supreme Court is the highest court in the US.", "output": "True.", "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nClassify the following statement as true or false.\n\n### Input:\nThe Supreme Court is the highest court in the US.\n\n### Response:\nTrue."}] diff --git a/training/tests/assets/invalid_dummy_config.yaml b/training/tests/assets/invalid_dummy_config.yaml new file mode 100644 index 0000000..d40b764 --- /dev/null +++ b/training/tests/assets/invalid_dummy_config.yaml @@ -0,0 +1,8 @@ +test1: + _component_: torchtune.utils.get_dtype + dtype: fp32 + dummy: 3 +test2: + _component_: torchtune.utils.get_dtype + dtype: fp32 + dummy: 3 diff --git a/training/tests/assets/m.model b/training/tests/assets/m.model new file mode 100644 index 0000000..5adc0c2 Binary files /dev/null and b/training/tests/assets/m.model differ diff --git a/training/tests/assets/tiktoken_small.model b/training/tests/assets/tiktoken_small.model new file mode 100644 index 0000000..4bfad62 --- /dev/null +++ b/training/tests/assets/tiktoken_small.model @@ -0,0 +1,2000 @@ +AA== 0 +AQ== 1 +Ag== 2 +Aw== 3 +BA== 4 +BQ== 5 +Bg== 6 +Bw== 7 +CA== 8 +CQ== 9 +Cg== 10 +Cw== 11 +DA== 12 +DQ== 13 +Dg== 14 +Dw== 15 +EA== 16 +EQ== 17 +Eg== 18 +Ew== 19 +FA== 20 +FQ== 21 +Fg== 22 +Fw== 23 +GA== 24 +GQ== 25 +Gg== 26 +Gw== 27 +HA== 28 +HQ== 29 +Hg== 30 +Hw== 31 +IA== 32 +IQ== 33 +Ig== 34 +Iw== 35 +JA== 36 +JQ== 37 +Jg== 38 +Jw== 39 +KA== 40 +KQ== 41 +Kg== 42 +Kw== 43 +LA== 44 +LQ== 45 +Lg== 46 +Lw== 47 +MA== 48 +MQ== 49 +Mg== 50 +Mw== 51 +NA== 52 +NQ== 53 +Ng== 54 +Nw== 55 +OA== 56 +OQ== 57 +Og== 58 +Ow== 59 +PA== 60 +PQ== 61 +Pg== 62 +Pw== 63 +QA== 64 +QQ== 65 +Qg== 66 +Qw== 67 +RA== 68 +RQ== 69 +Rg== 70 +Rw== 71 +SA== 72 +SQ== 73 +Sg== 74 +Sw== 75 +TA== 76 +TQ== 77 +Tg== 78 +Tw== 79 +UA== 80 +UQ== 81 +Ug== 82 +Uw== 83 +VA== 84 +VQ== 85 +Vg== 86 +Vw== 87 +WA== 88 +WQ== 89 +Wg== 90 +Ww== 91 +XA== 92 +XQ== 93 +Xg== 94 +Xw== 95 +YA== 96 +YQ== 97 +Yg== 98 +Yw== 99 +ZA== 100 +ZQ== 101 +Zg== 102 +Zw== 103 +aA== 104 +aQ== 105 +ag== 106 +aw== 107 +bA== 108 +bQ== 109 +bg== 110 +bw== 111 +cA== 112 +cQ== 113 +cg== 114 +cw== 115 +dA== 116 +dQ== 117 +dg== 118 +dw== 119 +eA== 120 +eQ== 121 +eg== 122 +ew== 123 +fA== 124 +fQ== 125 +fg== 126 +fw== 127 +gA== 128 +gQ== 129 +gg== 130 +gw== 131 +hA== 132 +hQ== 133 +hg== 134 +hw== 135 +iA== 136 +iQ== 137 +ig== 138 +iw== 139 +jA== 140 +jQ== 141 +jg== 142 +jw== 143 +kA== 144 +kQ== 145 +kg== 146 +kw== 147 +lA== 148 +lQ== 149 +lg== 150 +lw== 151 +mA== 152 +mQ== 153 +mg== 154 +mw== 155 +nA== 156 +nQ== 157 +ng== 158 +nw== 159 +oA== 160 +oQ== 161 +og== 162 +ow== 163 +pA== 164 +pQ== 165 +pg== 166 +pw== 167 +qA== 168 +qQ== 169 +qg== 170 +qw== 171 +rA== 172 +rQ== 173 +rg== 174 +rw== 175 +sA== 176 +sQ== 177 +sg== 178 +sw== 179 +tA== 180 +tQ== 181 +tg== 182 +tw== 183 +uA== 184 +uQ== 185 +ug== 186 +uw== 187 +vA== 188 +vQ== 189 +vg== 190 +vw== 191 +wA== 192 +wQ== 193 +wg== 194 +ww== 195 +xA== 196 +xQ== 197 +xg== 198 +xw== 199 +yA== 200 +yQ== 201 +yg== 202 +yw== 203 +zA== 204 +zQ== 205 +zg== 206 +zw== 207 +0A== 208 +0Q== 209 +0g== 210 +0w== 211 +1A== 212 +1Q== 213 +1g== 214 +1w== 215 +2A== 216 +2Q== 217 +2g== 218 +2w== 219 +3A== 220 +3Q== 221 +3g== 222 +3w== 223 +4A== 224 +4Q== 225 +4g== 226 +4w== 227 +5A== 228 +5Q== 229 +5g== 230 +5w== 231 +6A== 232 +6Q== 233 +6g== 234 +6w== 235 +7A== 236 +7Q== 237 +7g== 238 +7w== 239 +8A== 240 +8Q== 241 +8g== 242 +8w== 243 +9A== 244 +9Q== 245 +9g== 246 +9w== 247 ++A== 248 ++Q== 249 ++g== 250 ++w== 251 +/A== 252 +/Q== 253 +/g== 254 +/w== 255 +IHQ= 256 +aGU= 257 +IGE= 258 +aW4= 259 +IHM= 260 +IHc= 261 +IHRoZQ== 262 +IG8= 263 +cmU= 264 +IGI= 265 +b3U= 266 +ZWQ= 267 +IG0= 268 +bmQ= 269 +IEk= 270 +aGE= 271 +aXQ= 272 +ZXI= 273 +aW5n 274 +IGY= 275 +aXM= 276 +IHRv 277 +ZW4= 278 +b24= 279 +b3I= 280 +YXM= 281 +IGM= 282 +IG9m 283 +IGFuZA== 284 +IGQ= 285 +bGw= 286 +YXQ= 287 +YW4= 288 +YXI= 289 +IHA= 290 +IG4= 291 +IGlu 292 +bGU= 293 +b20= 294 +b3Q= 295 +IGJl 296 +IGg= 297 +dXQ= 298 +b3c= 299 +ZXM= 300 +aGF0 301 +IGc= 302 +IGhl 303 +IGhh 304 +IGw= 305 +IHdhcw== 306 +bGQ= 307 +Z2g= 308 +aWQ= 309 +Y2g= 310 +IHRo 311 +IGl0 312 +YXk= 313 +IG9u 314 +Y2U= 315 +c2U= 316 +ZW50 317 +IHN0 318 +bHk= 319 +dmU= 320 +ZXQ= 321 +c3Q= 322 +IFQ= 323 +IGU= 324 +IHk= 325 +Z2h0 326 +aXI= 327 +IG1l 328 +b28= 329 +YWw= 330 +aXRo 331 +IHJl 332 +aW0= 333 +IHRoYXQ= 334 +IGFz 335 +b3VsZA== 336 +cm8= 337 +YWQ= 338 +aW9u 339 +Lgo= 340 +aGVy 341 +IG15 342 +Y3Q= 343 +IG5vdA== 344 +IHdpdGg= 345 +IGZvcg== 346 +IHU= 347 +a2U= 348 +IHlvdQ== 349 +IFM= 350 +IGlz 351 +aWdodA== 352 +Igo= 353 +YW0= 354 +aWM= 355 +dXI= 356 +IGF0 357 +Li4= 358 +YWM= 359 +dGVy 360 +IHdo 361 +IGFu 362 +IHdl 363 +IFRoZQ== 364 +aWY= 365 +IG9y 366 +IGJ1dA== 367 +dmVy 368 +ICI= 369 +IHI= 370 +b3V0 371 +b21l 372 +IGhhZA== 373 +cHA= 374 +cXU= 375 +IHN1 376 +IHRoaXM= 377 +cmVk 378 +YXJk 379 +IHNv 380 +ZWxs 381 +IHdvdWxk 382 +IGhpcw== 383 +IHNo 384 +aW5l 385 +cmE= 386 +IHNl 387 +IGJ5 388 +LiIK 389 +IFA= 390 +aGVu 391 +IEE= 392 +IGhhdmU= 393 +IGZy 394 +IHNh 395 +IEg= 396 +IG9uZQ== 397 +ZW0= 398 +a2Vk 399 +aXJ0 400 +ZWN0 401 +IGhpbQ== 402 +IGxp 403 +IGFi 404 +YXRpb24= 405 +aGluZw== 406 +dGhl 407 +IFI= 408 +IGxl 409 +c3M= 410 +IFc= 411 +Y3U= 412 +aWxs 413 +J3Q= 414 +YXJ0 415 +YWxs 416 +LAo= 417 +b3du 418 +b3Jl 419 +IGFsbA== 420 +IGs= 421 +IGdv 422 +aGlydA== 423 +YW5k 424 +IG91dA== 425 +YW1l 426 +YWlu 427 +IGlm 428 +IG5v 429 +IGRv 430 +IHRoZXk= 431 +b29s 432 +dW4= 433 +dG8= 434 +IHVw 435 +IFJlZA== 436 +IG5l 437 +IEs= 438 +IGZyb20= 439 +IFNoaXJ0 440 +IHdvcg== 441 +b25n 442 +IHRoZXJl 443 +IHNhaWQ= 444 +cmk= 445 +YW50 446 +IEI= 447 +IGFueQ== 448 +dWQ= 449 +aW5k 450 +IHdoaQ== 451 +YWI= 452 +b3VuZA== 453 +IGFib3V0 454 +IHRoZW0= 455 +Y3Vw 456 +YWs= 457 +IGRl 458 +IHRl 459 +IE0= 460 +YWtl 461 +Y3VwaW5l 462 +aWc= 463 +IHdlcmU= 464 +b3JjdXBpbmU= 465 +aWw= 466 +Y2hvb2w= 467 +IHJv 468 +b29k 469 +IGFyZQ== 470 +aXZl 471 +IGxpa2U= 472 +eW8= 473 +IGhvdQ== 474 +J3M= 475 +b25l 476 +dXM= 477 +ZWw= 478 +dWw= 479 +YWNr 480 +b3A= 481 +LCI= 482 +dGg= 483 +YWNoZXI= 484 +dW0= 485 +YW5n 486 +IGZh 487 +YWc= 488 +IHNjaG9vbA== 489 +IGo= 490 +dGU= 491 +b2s= 492 +ZXNz 493 +dXN0 494 +ZXJz 495 +Li4uLg== 496 +IEM= 497 +dGhlcg== 498 +aGFu 499 +IHdoZW4= 500 +IHNw 501 +IG1hbg== 502 +IGNhbg== 503 +b3VnaA== 504 +IHdobw== 505 +IGdldA== 506 +IGRpZA== 507 +IHBv 508 +Y2k= 509 +IGFs 510 +aXN0 511 +IGNvbQ== 512 +bGY= 513 +YXU= 514 +IFBvcmN1cGluZQ== 515 +IHdoaWNo 516 +dmVu 517 +IGFm 518 +d24= 519 +YXNz 520 +YmVy 521 +IGV4 522 +b3Vz 523 +ZXN0 524 +bG8= 525 +IHRy 526 +ZWxsb3c= 527 +IHNheQ== 528 +b3VnaHQ= 529 +IHJvb20= 530 +IHNvbWU= 531 +LS0= 532 +IE8= 533 +YXRl 534 +IHY= 535 +aGVk 536 +YXA= 537 +IHR3 538 +IGJlYw== 539 +cmVl 540 +amVjdA== 541 +a3M= 542 +IGNvbg== 543 +IGJlZW4= 544 +ZW50cw== 545 +aWRl 546 +IGNvdWxk 547 +IEc= 548 +ZXA= 549 +IHBybw== 550 +bnQ= 551 +IGhvdXNl 552 +IGFn 553 +IElm 554 +IGtu 555 +IGZlbGxvdw== 556 +IHdoYXQ= 557 +d2F5 558 +aXNo 559 +IGFt 560 +aXRl 561 +bmRlcg== 562 +aW1l 563 +IHBy 564 +IHRlYWNoZXI= 565 +YXJl 566 +IGJv 567 +IHNoZQ== 568 +IE4= 569 +aWNl 570 +YXN0 571 +dXJl 572 +aWU= 573 +IHN1Y2g= 574 +dXRlbg== 575 +dXRlbmJlcg== 576 +dXRlbmJlcmc= 577 +IHF1 578 +bG93bg== 579 +IHdy 580 +cHQ= 581 +IEhl 582 +IHN0dWQ= 583 +aGVyZQ== 584 +IG1vcmU= 585 +cnk= 586 +dHRlcg== 587 +IFk= 588 +IG1heQ== 589 +aXR5 590 +IGxvbw== 591 +IG90aGVy 592 +aGlz 593 +IFBybw== 594 +IHdpbGw= 595 +IEl0 596 +b3J0 597 +IHNob3VsZA== 598 +dmVyeQ== 599 +d2U= 600 +IHBs 601 +YXNo 602 +LiI= 603 +IGFwcA== 604 +IGRheQ== 605 +dXJu 606 +cG8= 607 +IGhlcg== 608 +ICA= 609 +bm90 610 +Y2s= 611 +IHVu 612 +aGk= 613 +dmluZw== 614 +IG9sZA== 615 +IHRpbWU= 616 +IlQ= 617 +IHdheQ== 618 +YWJsZQ== 619 +PyIK 620 +IENsb3du 621 +IG9ubHk= 622 +dWI= 623 +YWNo 624 +IG9mZg== 625 +IHRoYW4= 626 +YWxseQ== 627 +IHRoZWly 628 +YmU= 629 +a2luZw== 630 +b3RoZXI= 631 +YXJ5 632 +YW5z 633 +YXRlZA== 634 +c2VsZg== 635 +IGdvaW5n 636 +dWNo 637 +b2xs 638 +IGJhY2s= 639 +aXlv 640 +LXQ= 641 +YW5jZQ== 642 +YWRl 643 +IFByb2plY3Q= 644 +c3A= 645 +IHR3bw== 646 +IHRob3VnaHQ= 647 +c28= 648 +IHJpZ2h0 649 +IGhlYWQ= 650 +dmVk 651 +IEQ= 652 +IHByZQ== 653 +IHNlZQ== 654 +IHVz 655 +IHN0dWRlbnRz 656 +Y2lw 657 +IGRvbg== 658 +IG5pZ2h0 659 +aW5jaXA= 660 +IEtpeW8= 661 +cGw= 662 +YXJlZA== 663 +IEd1dGVuYmVyZw== 664 +IGNv 665 +IGhvdw== 666 +b21ldA== 667 +ZmY= 668 +Ikk= 669 +LC0t 670 +IGFza2Vk 671 +aW5jaXBhbA== 672 +ZXZlcg== 673 +IGFj 674 +IEY= 675 +IG1ha2U= 676 +aXR0 677 +IG1pZ2h0 678 +Z2U= 679 +bGVk 680 +IGFmdGVy 681 +aWdu 682 +IGdy 683 +IG1hZGU= 684 +ZGQ= 685 +IGtub3c= 686 +IGNvbWU= 687 +IGJy 688 +dGhpbmc= 689 +IEJ1dA== 690 +IG1hdA== 691 +IE9u 692 +b3J5 693 +Y2w= 694 +IEU= 695 +Ymxl 696 +b2c= 697 +IHlvdXI= 698 +dWxs 699 +IHdvcms= 700 +ZWFy 701 +IHRocmVl 702 +aWVk 703 +YnV0 704 +VGhl 705 +cGU= 706 +YWNl 707 +IHN0YXJ0 708 +aWNr 709 +IG92ZXI= 710 +b3Vy 711 +IG11Y2g= 712 +IHdhbnQ= 713 +aW1w 714 +IHBhcnQ= 715 +aG8= 716 +aW5r 717 +ZW5jZQ== 718 +IGRvd24= 719 +IGV2ZW4= 720 +IHByaW5jaXBhbA== 721 +bGluZw== 722 +b3VudA== 723 +YXVzZQ== 724 +IGNs 725 +IGJs 726 +LXRt 727 +b21ldGhpbmc= 728 +IGludG8= 729 +b3Jt 730 +b2t5bw== 731 +IGRpcw== 732 +IGZl 733 +IGZhY2U= 734 +Li4uLi4u 735 +cmVzcw== 736 +bWVudA== 737 +aXJl 738 +IGFy 739 +dHk= 740 +IG1v 741 +cmVhdA== 742 +IGZpcg== 743 +cGVy 744 +IG91cg== 745 +Y28= 746 +IHRoZW4= 747 +IHRvbGQ= 748 +aW5ncw== 749 +IHRha2U= 750 +IGJlZw== 751 +bmVy 752 +aXRpb24= 753 +b3Nl 754 +IG93bg== 755 +IGFnYWlu 756 +IHNlZW0= 757 +aXNl 758 +IHdhdA== 759 +Ilc= 760 +IGZhcg== 761 +YWtpbmc= 762 +Zm9yZQ== 763 +YWR5 764 +LXM= 765 +bGVzcw== 766 +IHJldA== 767 +IHNoYQ== 768 +IGNhbWU= 769 +Z2Vy 770 +IGdvb2Q= 771 +YXRoZXI= 772 +YXJr 773 +cm93 774 +IGtl 775 +J20= 776 +IGhhcw== 777 +YXRo 778 +cHBlZA== 779 +IHdlbnQ= 780 +IHRlbGw= 781 +cXVhc2g= 782 +IGVu 783 +IGZpcnN0 784 +IGhvdA== 785 +aXo= 786 +IGF3YXk= 787 +IHNvbWV0aGluZw== 788 +IHJlbQ== 789 +IHRvd24= 790 +IHNt 791 +IFRoaXM= 792 +IGJldHRlcg== 793 +IFRoZW4= 794 +d2Fz 795 +b2Y= 796 +YmFyZA== 797 +IEw= 798 +bGk= 799 +ZmU= 800 +IFRva3lv 801 +IGxvbmc= 802 +aWx5 803 +IHN1cmU= 804 +IGxvb2tlZA== 805 +dWJiYXJk 806 +Y3Rpb24= 807 +b3Jk 808 +IG1hbnk= 809 +aW91cw== 810 +IHRvbw== 811 +IGhlcmU= 812 +b3M= 813 +IHVuZGVy 814 +YXNl 815 +bmc= 816 +cGVk 817 +b2Q= 818 +bWU= 819 +IGp1c3Q= 820 +IG5vdw== 821 +aW5jZQ== 822 +IGhlYXJk 823 +IGtpbmQ= 824 +IFRoZXk= 825 +IGJlZm9yZQ== 826 +aHk= 827 +IElu 828 +IGVudA== 829 +IGJvYXJk 830 +ISI= 831 +d2FyZA== 832 +IGJlaW5n 833 +IHdlbGw= 834 +ZXJt 835 +cmllZA== 836 +IHdyb25n 837 +YWlk 838 +eHQ= 839 +IHJldHVybg== 840 +aXRlZA== 841 +IHllbg== 842 +IG1hdHRlcg== 843 +IGNhbGw= 844 +IHRhbA== 845 +IFlvdQ== 846 +Y2Vk 847 +aXNlZA== 848 +IGNoYQ== 849 +b25z 850 +IHNhbWU= 851 +IG9uY2U= 852 +ZGF5 853 +ZnQ= 854 +IHN3 855 +IGJlY2F1c2U= 856 +IHRoaW5r 857 +IHdoZXJl 858 +IE5v 859 +IEh1YmJhcmQ= 860 +IFNxdWFzaA== 861 +IGNvcA== 862 +d2l0aA== 863 +ZXJlZA== 864 +b2xsb3c= 865 +IHBsYWNl 866 +aWRk 867 +Y2Vzcw== 868 +IHNob3c= 869 +aXNoYQ== 870 +IHJh 871 +IGxldHRlcg== 872 +bmU= 873 +dmVz 874 +YXRpbmc= 875 +cmFuZw== 876 +IGFmZg== 877 +IGhhbmQ= 878 +IHNj 879 +IHBlcnM= 880 +aW50 881 +cHI= 882 +c2lkZQ== 883 +ZnRlcg== 884 +IHNheWluZw== 885 +IGxhdQ== 886 +dGhhdA== 887 +IHdpdGhvdXQ= 888 +cm9u 889 +YWly 890 +bGVjdA== 891 +IFdoYXQ= 892 +ZWx0 893 +IHdoaWxl 894 +b2dh 895 +YXBlcg== 896 +IHBl 897 +b3k= 898 +IHNhdA== 899 +aWVz 900 +IGFkZA== 901 +IGRheXM= 902 +IHNwZQ== 903 +IGhv 904 +IGFucw== 905 +IGhhcg== 906 +IFdoZW4= 907 +IGFueXRoaW5n 908 +cGVu 909 +XQo= 910 +dGFpbg== 911 +IG11c3Q= 912 +IG5ldw== 913 +bGlj 914 +IHZv 915 +aGlsZQ== 916 +Z2V0 917 +IEFz 918 +IHZlcnk= 919 +J3Jl 920 +IGV2ZXJ5 921 +YXZl 922 +PyI= 923 +YWRnZXI= 924 +IEtvZ2E= 925 +IE1y 926 +cm91Z2g= 927 +dWx0 928 +IGZvbGxvdw== 929 +dGluZw== 930 +aWZl 931 +aWRkbGU= 932 +ZnVs 933 +YW5r 934 +IFNv 935 +IHNlZW1lZA== 936 +IEFuZA== 937 +aXg= 938 +IHNldA== 939 +IGNhcmU= 940 +IHJlcw== 941 +IG5ldmVy 942 +IGZvdW5k 943 +IGxv 944 +Y2lk 945 +aW5lZA== 946 +IGNsYXNz 947 +IG15c2VsZg== 948 +YXc= 949 +IHdvbQ== 950 +YXRpb25z 951 +IGxlZnQ= 952 +IFdl 953 +IHRlYWNoZXJz 954 +Ilk= 955 +bmE= 956 +b250 957 +IGRlcw== 958 +IHRob3Nl 959 +aXJlZA== 960 +IHNlbg== 961 +eWluZw== 962 +IHRoZXNl 963 +YXo= 964 +IFRoZXJl 965 +Y2VwdA== 966 +IGRhbmc= 967 +IFU= 968 +Ikg= 969 +Ym9k 970 +Ym9keQ== 971 +IGhhdmluZw== 972 +YWxhcnk= 973 +IHdhdGNo 974 +IGdpdmU= 975 +YWdl 976 +IGl0cw== 977 +IGFwcGU= 978 +dWU= 979 +IGNvdW50 980 +IGhhcmQ= 981 +IGJlbA== 982 +b3R0 983 +IGRpc3Q= 984 +IlM= 985 +IE1hZA== 986 +LW4= 987 +cmlidXQ= 988 +Z2Vk 989 +IGF0dA== 990 +ZmVyZQ== 991 +aXRoZXI= 992 +IHVwb24= 993 +IHRlbQ== 994 +IHBlcnNvbg== 995 +bmluZw== 996 +IGNoZQ== 997 +YXJseQ== 998 +b25leQ== 999 +IHNvb24= 1000 +ZW1lbnQ= 1001 +ICg= 1002 +IHRyYW5z 1003 +IGV4cA== 1004 +IHNlcg== 1005 +IHJlZw== 1006 +YXNvbg== 1007 +IHNhdw== 1008 +IG5leHQ= 1009 +b290 1010 +IGhhbGY= 1011 +IHRvb2s= 1012 +IGJhZA== 1013 +IGhvdXI= 1014 +IHNhbGFyeQ== 1015 +IGJlZ2Fu 1016 +cmlnaHQ= 1017 +b25uYQ== 1018 +LXNhbg== 1019 +IHdvcmtz 1020 +IEo= 1021 +Zm9ybQ== 1022 +aWNhbA== 1023 +IHRyYQ== 1024 +bWFu 1025 +IG5vdGhpbmc= 1026 +IHN0aWxs 1027 +ZWFycw== 1028 +IHN1cHA= 1029 +IHR1cm4= 1030 +IGZlbHQ= 1031 +IHdvbWFu 1032 +IHN0YXJ0ZWQ= 1033 +b3VibGU= 1034 +dXJh 1035 +aXNoaW5n 1036 +Ogo= 1037 +bGVjdHJvbg== 1038 +bGVjdHJvbmlj 1039 +b29r 1040 +IGNvcHk= 1041 +IGZ1bGw= 1042 +Y29uZA== 1043 +bWF0 1044 +IG1pZGRsZQ== 1045 +IGxvb2s= 1046 +IGNvbW0= 1047 +d2VyZWQ= 1048 +IGJlY2FtZQ== 1049 +IGZlbGxvd3M= 1050 +d291bGQ= 1051 +IGdvdA== 1052 +IGds 1053 +IGd1 1054 +IGtlZXA= 1055 +IGdl 1056 +IE1hZG9ubmE= 1057 +aXRlcg== 1058 +aXNoZWQ= 1059 +IHVuZGVyc3Q= 1060 +IHN0cmE= 1061 +c2lk 1062 +IGNvdW50cnk= 1063 +b3BsZQ== 1064 +IHByb3Y= 1065 +IHB1dA== 1066 +bm8= 1067 +J2xs 1068 +IHNsZQ== 1069 +cmFuZ2U= 1070 +IFNoZQ== 1071 +cG9z 1072 +IG1pbmQ= 1073 +IHBhc3M= 1074 +IHRocm91Z2g= 1075 +IHF1aXRl 1076 +IGluZA== 1077 +IGJvYXJkaW5n 1078 +dGVhY2hlcg== 1079 +cGxl 1080 +UG9yY3VwaW5l 1081 +IHBsZQ== 1082 +IGdlaXNoYQ== 1083 +ICAgIA== 1084 +b3N0 1085 +ZW5zZQ== 1086 +Tm8= 1087 +aWJsZQ== 1088 +IHJlYWQ= 1089 +IHJlZA== 1090 +ZW50aW9u 1091 +ZW5lZA== 1092 +ISIK 1093 +IHJlZg== 1094 +IGFk 1095 +IGZs 1096 +IHN0YXk= 1097 +dXA= 1098 +IHJvdW5k 1099 +IGNsZQ== 1100 +IG9wZW4= 1101 +IG9i 1102 +dGVuZA== 1103 +IGZpbmQ= 1104 +IHBlcg== 1105 +IGNhbGxlZA== 1106 +IHN1cg== 1107 +cmV3 1108 +IHBhcGVy 1109 +IEJhZGdlcg== 1110 +IG1lZXQ= 1111 +aXNz 1112 +IlRoYXQ= 1113 +ZXJtcw== 1114 +VEU= 1115 +aXR0ZW4= 1116 +YWJseQ== 1117 +bmVzcw== 1118 +IGNhbm5vdA== 1119 +IHNpbXA= 1120 +Y29u 1121 +IHJlYXNvbg== 1122 +eW91 1123 +IGhvbWU= 1124 +Ynk= 1125 +IGZpZ2h0 1126 +aXR0bGU= 1127 +IHRoaW5ncw== 1128 +IGVhcw== 1129 +IGltcA== 1130 +cmVzc2Vk 1131 +IG1lYW4= 1132 +IGFwcGVhcmVk 1133 +IG5hdA== 1134 +IGhlbA== 1135 +cmV0 1136 +YWtlbg== 1137 +IHN0cmFpZ2h0 1138 +IGFmZmFpcg== 1139 +aXRpbmc= 1140 +IGVk 1141 +IHNpbmNl 1142 +bG9n 1143 +IHBheQ== 1144 +IGZyb250 1145 +bXk= 1146 +IHZvaWNl 1147 +cmVhZHk= 1148 +IGZvb2w= 1149 +b3VuZGF0aW9u 1150 +IGVsZWN0cm9uaWM= 1151 +IHRlcm1z 1152 +IG1hcg== 1153 +YXBhbg== 1154 +YW55 1155 +IHJlc3A= 1156 +IGVuZA== 1157 +YXBw 1158 +d2hhdA== 1159 +c3Ry 1160 +cmFw 1161 +aWFs 1162 +aWN1bA== 1163 +IGFjYw== 1164 +b3Ro 1165 +IHNlY29uZA== 1166 +IGZsbw== 1167 +IHNpeA== 1168 +IGZlZXQ= 1169 +YnI= 1170 +aWV0 1171 +IGxpdHRsZQ== 1172 +bGVz 1173 +IG1vbmV5 1174 +IGRlY2w= 1175 +IGV5 1176 +IGNvbXA= 1177 +YXJpbmc= 1178 +IGFncmU= 1179 +d2hlcmU= 1180 +IFN0 1181 +IHN0cmU= 1182 +ZXg= 1183 +cmFjdA== 1184 +IGludA== 1185 +IGRpcmU= 1186 +IGJlY29tZQ== 1187 +IGhvbg== 1188 +IGNvbnNpZA== 1189 +ZXJ0YWlu 1190 +bm93 1191 +IHNs 1192 +aXRvcg== 1193 +Z2c= 1194 +IGp1bQ== 1195 +IGJ1 1196 +IHRoaW5n 1197 +IGFuc3dlcmVk 1198 +b2Vz 1199 +eWE= 1200 +IFRoYXQ= 1201 +aXpl 1202 +b25k 1203 +YWN0 1204 +IGVmZg== 1205 +IGJhbmc= 1206 +YWJvdXQ= 1207 +IGJlZA== 1208 +b3Jyb3c= 1209 +dW5n 1210 +IFRv 1211 +IGtlcHQ= 1212 +IHdhbA== 1213 +IGJhdGg= 1214 +IGRyYQ== 1215 +IkE= 1216 +cmluZ3M= 1217 +aG9wcA== 1218 +IHJlc2lnbg== 1219 +IGRpbg== 1220 +IGxhZHk= 1221 +LkU= 1222 +IHVzZQ== 1223 +bGlzaA== 1224 +b3Jz 1225 +IHdyaXR0ZW4= 1226 +ZW5l 1227 +aXY= 1228 +IGRpZg== 1229 +IHN0ZQ== 1230 +IHN0b3J5 1231 +Y29t 1232 +cmVz 1233 +ZW50bHk= 1234 +IGZhY3Q= 1235 +aGVz 1236 +d2F5cw== 1237 +IHdoeQ== 1238 +IHRob3VnaA== 1239 +IHN0cg== 1240 +b25kZXI= 1241 +aGVhZA== 1242 +IGNvdXI= 1243 +IG1vbg== 1244 +IHNr 1245 +IGJlbGll 1246 +IGxldA== 1247 +ZmVy 1248 +IHJlcXU= 1249 +IGxpbmU= 1250 +cm9vbQ== 1251 +LWRheQ== 1252 +IGRvbmU= 1253 +IGRvZXM= 1254 +IE9uZQ== 1255 +IGRhbmdv 1256 +YXNzaG9wcA== 1257 +IGNvbnNpZGVy 1258 +IGRpbm5lcg== 1259 +IEZvdW5kYXRpb24= 1260 +Kio= 1261 +ZW1wdA== 1262 +ZXNl 1263 +IHdvcmQ= 1264 +cmVzdA== 1265 +IGVub3VnaA== 1266 +IGdyZWF0 1267 +IG5hbWU= 1268 +IHB1Yg== 1269 +IG1hbm5lcg== 1270 +d2Vy 1271 +aWN0 1272 +aW5lc3M= 1273 +IGhpbXNlbGY= 1274 +IHBlb3BsZQ== 1275 +ZXc= 1276 +IGNvcg== 1277 +ZXN0aW9u 1278 +IGJpZw== 1279 +ZWU= 1280 +IHJp 1281 +aWRlcw== 1282 +IGJyb3RoZXI= 1283 +IGhlYXJ0 1284 +ZWN0ZWQ= 1285 +ZWVk 1286 +IG90aGVycw== 1287 +c29s 1288 +dGVk 1289 +IGV5ZXM= 1290 +IHRyb3VibGU= 1291 +IHRlYWNo 1292 +IGJvYXQ= 1293 +IGZvdXI= 1294 +IGFscmVhZHk= 1295 +cm9t 1296 +Z2hlZA== 1297 +IHNxdQ== 1298 +IHBvbA== 1299 +Y2Vz 1300 +IEhvdHQ= 1301 +IGxlYXZl 1302 +IGRpc3RyaWJ1dA== 1303 +YXN0ZXI= 1304 +Q0g= 1305 +dWM= 1306 +IGlt 1307 +IGhvd2V2ZXI= 1308 +dGhlcmU= 1309 +YXBhbmVzZQ== 1310 +IGxhc3Q= 1311 +IGNy 1312 +aWxpdHk= 1313 +IHNpbXBsZQ== 1314 +IGxpZmU= 1315 +LWM= 1316 +IHJlZ2FyZA== 1317 +IGZpbg== 1318 +dWFs 1319 +IG1lYW5z 1320 +IHN0YW5k 1321 +YXRjaA== 1322 +IHNob3J0 1323 +bmVk 1324 +IHNlZW4= 1325 +IGhhcHA= 1326 +LWs= 1327 +IGFnYWluc3Q= 1328 +aGlt 1329 +YW1lZA== 1330 +IHN0b29k 1331 +IGdyYQ== 1332 +IG1vdGhlcg== 1333 +IGZpc2g= 1334 +IHdhdGVy 1335 +YWls 1336 +Y2Vp 1337 +IHJhdGhlcg== 1338 +IGlucw== 1339 +IGZlZWw= 1340 +IGFsc28= 1341 +IG9yZA== 1342 +IGNvbWluZw== 1343 +aWNz 1344 +IGVpdGhlcg== 1345 +bmNl 1346 +ICc= 1347 +IGtpZA== 1348 +IGxhdWdoZWQ= 1349 +bGlrZQ== 1350 +IEFy 1351 +Z3I= 1352 +IEhvdHRh 1353 +IHRhbGs= 1354 +Z2V0aGVy 1355 +IFNpcg== 1356 +IHB1bg== 1357 +UHJv 1358 +YXRz 1359 +bW9zdA== 1360 +IHJlcA== 1361 +IGdp 1362 +aXNm 1363 +YmFibHk= 1364 +YWtlcw== 1365 +IE5vdA== 1366 +bnk= 1367 +IGFwcGVhcg== 1368 +bXA= 1369 +Y2hh 1370 +IGFjdA== 1371 +YmVk 1372 +aWVm 1373 +dWZm 1374 +IGFwbw== 1375 +IG1ldA== 1376 +IHJldHVybmVk 1377 +IHNvdW5k 1378 +dXNpbmVzcw== 1379 +IGxhdWdo 1380 +IGNsZWFy 1381 +IG5lZWQ= 1382 +ZmVzcw== 1383 +ZXN0ZWQ= 1384 +IGludg== 1385 +IGFjY2VwdA== 1386 +dW5kZXI= 1387 +Owo= 1388 +IHN1cnBy 1389 +ZGU= 1390 +IHRyYWlu 1391 +IGhvdGVs 1392 +IHNsZWVw 1393 +IGRy 1394 +IGhvbGQ= 1395 +bG9jaw== 1396 +cHVyYQ== 1397 +IHNwcmluZ3M= 1398 +IC4uLi4uLg== 1399 +IGFncmVlbWVudA== 1400 +IERhcg== 1401 +IHJlc3Q= 1402 +Y2x1ZA== 1403 +YXRvcg== 1404 +YXY= 1405 +IG9yaWc= 1406 +IG9yaWdpbg== 1407 +IGVs 1408 +IG5vcg== 1409 +IHByZXM= 1410 +IHVuZGVyc3RhbmQ= 1411 +IHRha2Vu 1412 +IGxpZ2h0 1413 +ZW5lcg== 1414 +c29tZQ== 1415 +IGJyb3VnaHQ= 1416 +cmFwaA== 1417 +IG1vc3Q= 1418 +b2tl 1419 +LXc= 1420 +IHVudA== 1421 +IGZhdGhlcg== 1422 +IHVzZWQ= 1423 +IGVhdA== 1424 +IHllYXJz 1425 +IFdoaWxl 1426 +IGNoYW4= 1427 +IHN1ZGQ= 1428 +IHN1ZGRlbg== 1429 +IGFwb2xvZw== 1430 +IHNldHQ= 1431 +IHRoaW4= 1432 +IE15 1433 +IHRlbg== 1434 +aW1lcw== 1435 +Zm9y 1436 +b3Vk 1437 +V2hlbg== 1438 +IGRldA== 1439 +IGxpdmU= 1440 +IG9j 1441 +IGZpdmU= 1442 +IGNvbnQ= 1443 +IGhlbHA= 1444 +IHdh 1445 +IHBhc3NlZA== 1446 +IHJ1bg== 1447 +IG1ha2luZw== 1448 +IHN0cmFuZ2U= 1449 +IHRha2luZw== 1450 +IGVhY2g= 1451 +IllvdQ== 1452 +IGFub3RoZXI= 1453 +IlNheQ== 1454 +IlRoZQ== 1455 +YXRlcw== 1456 +IHBsZWFz 1457 +YXNzaG9wcGVycw== 1458 +IG1vbQ== 1459 +IG1vbWVudA== 1460 +ZW50bGU= 1461 +bmdsaXNo 1462 +Q0hB 1463 +IG9yaWdpbmFs 1464 +aW9ucw== 1465 +dXJpbmc= 1466 +IHB1YmxpYw== 1467 +dWN0 1468 +dWNr 1469 +IHF1ZXN0aW9u 1470 +YWk= 1471 +Y3k= 1472 +ZWs= 1473 +IGZsb29y 1474 +IGNhcg== 1475 +b3VzZQ== 1476 +IHNpZGU= 1477 +LXlh 1478 +IGNlcnRhaW4= 1479 +aHlz 1480 +LWQ= 1481 +aWdo 1482 +YWdpbg== 1483 +d2VldA== 1484 +IHBvb3I= 1485 +IGRlY2lk 1486 +dWFsbHk= 1487 +IGJ1c2luZXNz 1488 +cHJv 1489 +cGxhaW4= 1490 +IHN0b3A= 1491 +IQo= 1492 +IEhvdw== 1493 +IldoYXQ= 1494 +Y2Fu 1495 +IFVu 1496 +cHM= 1497 +dW5k 1498 +LW5pZ2h0 1499 +IG1lZXRpbmc= 1500 +ZWRv 1501 +IHJhaXNl 1502 +R3V0ZW5iZXJn 1503 +IERhcmxpbmc= 1504 +dW1l 1505 +IEVuZ2xpc2g= 1506 +VEVS 1507 +YWRpbmc= 1508 +IHRyYW5zbA== 1509 +IGFibGU= 1510 +c3NpYmxl 1511 +IHNhdGlzZg== 1512 +IHdhbnRlZA== 1513 +IHN1Yg== 1514 +IGNhc2U= 1515 +aWZpYw== 1516 +aXRlcmFyeQ== 1517 +IG1haWQ= 1518 +IGluYw== 1519 +IHBvcw== 1520 +IHBvc2l0aW9u 1521 +IHBhdA== 1522 +dXJlZA== 1523 +b3JyeQ== 1524 +IGFjY291bnQ= 1525 +IGJvdGg= 1526 +IGZyaWU= 1527 +IGZyaWVuZA== 1528 +dGhpcw== 1529 +IGFsd2F5cw== 1530 +IHBhcnRpY3Vs 1531 +V2hhdA== 1532 +IHNtYWxs 1533 +ZW50eQ== 1534 +dXNoZWQ= 1535 +IG1pcw== 1536 +dWxseQ== 1537 +IHJlY2Vp 1538 +WW91 1539 +IHlldA== 1540 +IGdhdmU= 1541 +QnV0 1542 +aGFk 1543 +IGFuc3dlcg== 1544 +IGFicw== 1545 +aWxl 1546 +Y2tldA== 1547 +IG5vb2Q= 1548 +IGNvdXJzZQ== 1549 +IGZvcm0= 1550 +IGV2ZXJ5dGhpbmc= 1551 +ZWN0aW9u 1552 +SWY= 1553 +cGFydA== 1554 +IHNpbmc= 1555 +IHNpdA== 1556 +IHB1cg== 1557 +aXA= 1558 +IGZpc2hpbmc= 1559 +IGVo 1560 +IHBhcg== 1561 +IHRvZ2V0aGVy 1562 +SGU= 1563 +IHdoZQ== 1564 +IHdoZXRoZXI= 1565 +IGJyYQ== 1566 +Illlcw== 1567 +IHB1bmlzaA== 1568 +U2hpcnQ= 1569 +IFllZG8= 1570 +IGZhcmV3 1571 +IGZhcmV3ZWxs 1572 +IGRhbmNl 1573 +IGxlc3M= 1574 +dXJhbA== 1575 +IGRlZg== 1576 +IGF0dGVtcHQ= 1577 +d2Vlbg== 1578 +IHNpZ24= 1579 +IHN5 1580 +ZmVyZW50 1581 +IGxlYXN0 1582 +c2Vy 1583 +b2I= 1584 +bmRpbmc= 1585 +IHNvcnJ5 1586 +IGp1bXBlZA== 1587 +IGphbg== 1588 +IGphbml0b3I= 1589 +aXplZA== 1590 +IHRvd2FyZA== 1591 +IG1vcg== 1592 +YXZpbmc= 1593 +IGJpdA== 1594 +IlRoaXM= 1595 +IHJlbWFyaw== 1596 +IGZ1dA== 1597 +IHdvbmRlcg== 1598 +IGZ1bg== 1599 +VGhlbg== 1600 +IGRlYw== 1601 +IHdob20= 1602 +IGRpZG4= 1603 +IHJlYw== 1604 +YmVj 1605 +Iklm 1606 +IGtuZXc= 1607 +YWZ0ZXI= 1608 +IHRodXM= 1609 +IGlzbg== 1610 +IHNpZ2h0 1611 +bWVk 1612 +W0Y= 1613 +dXNz 1614 +Y2lkZW50 1615 +dGhlbQ== 1616 +IGZpZg== 1617 +IGRyYXc= 1618 +IGhlYXI= 1619 +IHdyaXRpbmc= 1620 +IGdldHRpbmc= 1621 +c2g= 1622 +ZmVyZW5jZQ== 1623 +IHJhaXNlZA== 1624 +dGhleQ== 1625 +YXg= 1626 +IGZpbmU= 1627 +c2Vs 1628 +IE5vYmU= 1629 +IE5vYmVvaw== 1630 +IE5vYmVva2E= 1631 +b3JtYWw= 1632 +IGVC 1633 +aWNlbnNl 1634 +MDA= 1635 +IGJlc3Q= 1636 +d29y 1637 +Zmlj 1638 +dGVyZXN0 1639 +IHJlbWFy 1640 +Ymw= 1641 +YXJ0ZWQ= 1642 +IGRhcms= 1643 +IHlvdW5n 1644 +dXNo 1645 +IGJldA== 1646 +b3V0aA== 1647 +aG91c2U= 1648 +YXVnaHQ= 1649 +IHBoeXM= 1650 +IHN0cm9uZw== 1651 +IGZ1cg== 1652 +IHJvbGw= 1653 +Y292ZQ== 1654 +Y2hpZWY= 1655 +YXdh 1656 +IGZvbGxvd2Vk 1657 +IGZvbmQ= 1658 +IGZ1dHVyZQ== 1659 +aXJk 1660 +ZnVsbHk= 1661 +IGVmZm9ydA== 1662 +QWZ0ZXI= 1663 +b3dhcmQ= 1664 +IHJlYWxseQ== 1665 +IGFtb25n 1666 +IGFyb3VuZA== 1667 +IGNvbXBs 1668 +IGdheg== 1669 +IGJvdw== 1670 +YXRlcg== 1671 +IGluc2lzdA== 1672 +IHR1cm5lZA== 1673 +aGVs 1674 +cmVt 1675 +IGhvdXJz 1676 +IGRlY2lkZWQ= 1677 +eXM= 1678 +IG1vbnRo 1679 +LWE= 1680 +IGFkdg== 1681 +IGJlbGlldmU= 1682 +IHRlYWNoaW5n 1683 +IGVhc3k= 1684 +IGRpcmVjdGlvbg== 1685 +b29rZWQ= 1686 +IHdhcg== 1687 +IHVubGVzcw== 1688 +aGF2ZQ== 1689 +IHNxdWFyZQ== 1690 +dmls 1691 +IHF1aWV0 1692 +IGh1bmc= 1693 +IGdvZXM= 1694 +IHBhaWQ= 1695 +IHNoYWxs 1696 +Ik5v 1697 +IHB1bmlzaG1lbnQ= 1698 +cG9zZQ== 1699 +IHN3ZWV0 1700 +J3Zl 1701 +IldlbGw= 1702 +IGdlbnRsZQ== 1703 +IG5vcm1hbA== 1704 +YWdyYXBo 1705 +Y2hpdmU= 1706 +Y2hhbg== 1707 +IGluY2x1ZA== 1708 +d3c= 1709 +b3Jn 1710 +dGVt 1711 +QVI= 1712 +IFRI 1713 +IGVxdQ== 1714 +IHRvbmU= 1715 +IHBvc3NpYmxl 1716 +IGJlY29t 1717 +IEphcGFuZXNl 1718 +dmVycw== 1719 +IGZvbGxvd2luZw== 1720 +IHBhaW4= 1721 +IHdob2xl 1722 +d3I= 1723 +IHNlcmlvdXM= 1724 +IG5hcg== 1725 +IHRpcmVk 1726 +SW4= 1727 +IHBsYXk= 1728 +IHByb20= 1729 +IGdhbWU= 1730 +IFNvbWU= 1731 +IGhhcHBlbmVk 1732 +IGN1dA== 1733 +IHR3ZW50eQ== 1734 +IGRvb3I= 1735 +IG1vcm5pbmc= 1736 +aGluZA== 1737 +IGJyZQ== 1738 +IGluc2lkZQ== 1739 +b3Zl 1740 +YWx0aA== 1741 +dWs= 1742 +YXJnZQ== 1743 +YW1i 1744 +IGRhbQ== 1745 +IHdvcnJ5 1746 +YXRpdmU= 1747 +IGV4cGVjdGVk 1748 +IGZhbQ== 1749 +IHByYQ== 1750 +IHBvY2tldA== 1751 +b29rcw== 1752 +Y2hlZA== 1753 +IHNpbA== 1754 +b2w= 1755 +IGZhdg== 1756 +IGVsc2U= 1757 +IGhpZ2g= 1758 +IHJlYWw= 1759 +IGFsb25n 1760 +IG1lZA== 1761 +aGlr 1762 +aGVtYXQ= 1763 +aGVtYXRpY3M= 1764 +IGxpc3Q= 1765 +IHNpY2s= 1766 +b2ludA== 1767 +W0Zvb3Q= 1768 +W0Zvb3Rub3Q= 1769 +W0Zvb3Rub3Rl 1770 +Ll0K 1771 +bmlnaHQ= 1772 +c2Vz 1773 +aW9y 1774 +IHNheXM= 1775 +IG1vdXRo 1776 +aG93 1777 +bWluZw== 1778 +IGNsbw== 1779 +IGN1cg== 1780 +Z2luZw== 1781 +IHN1ZGRlbmx5 1782 +LWFo 1783 +YW1w 1784 +IGJsYWNr 1785 +cm9zcw== 1786 +IGZhYw== 1787 +c2VsdmVz 1788 +aWV3 1789 +aXNzaW9u 1790 +IGNvcHlyaWdodA== 1791 +IHBhcmFncmFwaA== 1792 +IEFyY2hpdmU= 1793 +IGRvbmF0aW9ucw== 1794 +UHJvamVjdA== 1795 +IGNvc3Q= 1796 +Lm9yZw== 1797 +TEk= 1798 +dWNlZA== 1799 +IHN1Yw== 1800 +eWxl 1801 +IGZvcmNl 1802 +am95 1803 +b3VjaA== 1804 +dHI= 1805 +SXQ= 1806 +IHRyYWQ= 1807 +IHByZXNlbnQ= 1808 +IGV4dA== 1809 +YXNlZA== 1810 +cmVkaXQ= 1811 +IGZhdWx0 1812 +aWI= 1813 +LW0= 1814 +dXJk 1815 +IHRyaWVk 1816 +dGltZQ== 1817 +IHByZXQ= 1818 +IHNwZWU= 1819 +b3dlcg== 1820 +IHdvcmRz 1821 +Q0hBUA== 1822 +Q0hBUFRFUg== 1823 +c2Nob29s 1824 +IGFzaw== 1825 +IGRvaW5n 1826 +YXRlbHk= 1827 +IHVudGls 1828 +Ym91dA== 1829 +IHRyZWU= 1830 +Y2FsbA== 1831 +YW1hc2g= 1832 +YW1hc2hpcg== 1833 +YW1hc2hpcm8= 1834 +c3Rl 1835 +IGJlaGluZA== 1836 +b2xk 1837 +IHdhbGw= 1838 +aXRvcnk= 1839 +IHJvbGxlZA== 1840 +IG1vdmU= 1841 +IGFwb2xvZ2l6ZQ== 1842 +IGxhcmdl 1843 +YW1ib28= 1844 +c3U= 1845 +IHNldHRsZWQ= 1846 +Ikhl 1847 +d28= 1848 +IHRoaW5raW5n 1849 +dXNlZA== 1850 +aWZpZWQ= 1851 +IGFsbW9zdA== 1852 +IHRyZQ== 1853 +IHRyZWF0 1854 +IG5vb2RsZQ== 1855 +IG5vdGU= 1856 +IEFsbA== 1857 +IGJlYXQ= 1858 +IG9iamVjdA== 1859 +IHNlZW1z 1860 +IGlkZQ== 1861 +WWVz 1862 +b3dz 1863 +IHJlbWFpbg== 1864 +IGJlZ2lu 1865 +dWdodA== 1866 +bWVudHM= 1867 +IGFsb25l 1868 +c3BlY3Q= 1869 +IG1hdGhlbWF0aWNz 1870 +IHJvdWdo 1871 +IG91dHNpZGU= 1872 +IGNvbWVz 1873 +YmFjaw== 1874 +IHdpbmQ= 1875 +c2Vk 1876 +IHdvdWxkbg== 1877 +ZWVy 1878 +aW51dA== 1879 +ZnJvbQ== 1880 +IHJlcGw= 1881 +IG5hcnJvdw== 1882 +IGluY2lkZW50 1883 +IGFpcg== 1884 +IHNlYQ== 1885 +dHM= 1886 +IHN1cnByaXNlZA== 1887 +IHRlYQ== 1888 +UmVk 1889 +IHRhbGtpbmc= 1890 +IGJvc3M= 1891 +cXVl 1892 +IHBpY3Q= 1893 +aXJ0eQ== 1894 +IGNl 1895 +IGxpbQ== 1896 +IFdoeQ== 1897 +IHBvaW50 1898 +IGxhdw== 1899 +Y2lhdGVk 1900 +IG1vb24= 1901 +aXJjdQ== 1902 +Z290 1903 +IElz 1904 +IGhhbmRz 1905 +IGhvbm9y 1906 +YXV0 1907 +cmdl 1908 +IHN0YXRl 1909 +IExpdGVyYXJ5 1910 +LkY= 1911 +VGhpcw== 1912 +bGluZQ== 1913 +Lmc= 1914 +Lmd1dGVuYmVyZw== 1915 +IE9G 1916 +RU4= 1917 +cmFjdGVy 1918 +IGJlbmU= 1919 +IEV2ZW4= 1920 +b3Vi 1921 +IG1ha2Vz 1922 +IGludGVyZXN0 1923 +b3Bl 1924 +bXM= 1925 +IHJlc3BvbnM= 1926 +IGZvcmU= 1927 +IHNvbWV3aGF0 1928 +IGhvbmVzdA== 1929 +b2Nr 1930 +aXJpdA== 1931 +IGhlbGQ= 1932 +IGFkZGVk 1933 +ZnU= 1934 +YWRlZA== 1935 +YWxz 1936 +YXR0 1937 +dGVybg== 1938 +IHBlcnNvbmFs 1939 +IGFzcw== 1940 +IFdpdGg= 1941 +dGlj 1942 +VG9reW8= 1943 +IHNob3V0 1944 +IHByZXR0eQ== 1945 +dW1i 1946 +IGVhcmx5 1947 +b3BwZWQ= 1948 +IGZ1cnRoZXI= 1949 +IGZyZQ== 1950 +ZXNpZGVz 1951 +IGJhbWJvbw== 1952 +IGly 1953 +bW9yZQ== 1954 +IGxpdmluZw== 1955 +IHJlY2VpdmVk 1956 +IGxpdmVk 1957 +IG1lYW50 1958 +IGNvd2FyZA== 1959 +cG9zaXRpb24= 1960 +IGxvYw== 1961 +aWxlZA== 1962 +IHRlbmRlcg== 1963 +IGNo 1964 +IEFmdGVy 1965 +Y2Vy 1966 +IGZhdm9y 1967 +d2hv 1968 +IGxpa2Vk 1969 +cmFuY2U= 1970 +IHByaQ== 1971 +a2lzaGE= 1972 +IHN0dWR5 1973 +IG9yZGVy 1974 +IGFmdGVyd2FyZA== 1975 +IGdyZWF0bHk= 1976 +IHVuYWJsZQ== 1977 +Z28= 1978 +IHdhaXQ= 1979 +ZXBpbmc= 1980 +aWRpbmc= 1981 +IGZvcnR5 1982 +IHNreQ== 1983 +IG9mZmljZQ== 1984 +d2lsbA== 1985 +IkQ= 1986 +d2Vs 1987 +IHN0YXRpb24= 1988 +Ym8= 1989 +aG90 1990 +c3VjaA== 1991 +IGxvdWQ= 1992 +IGF3 1993 +bGFuZA== 1994 +Pwo= 1995 +IHJlc3BlY3Q= 1996 +YW5jZXM= 1997 +aWVudA== 1998 +IG91Z2h0 1999 diff --git a/training/tests/assets/tiny_bpe_merges.txt b/training/tests/assets/tiny_bpe_merges.txt new file mode 100644 index 0000000..aff1dc8 --- /dev/null +++ b/training/tests/assets/tiny_bpe_merges.txt @@ -0,0 +1,1904 @@ +Ġ Ġ +Ġ t +Ġ a +i n +h e +r e +o n +Ġt he +Ġ s +e r +a t +Ġ c +ĠĠ ĠĠ +e n +Ġ o +Ġ " +n d +e s +in g +ĠĠ Ġ +i t +Ġ p +o r +o u +Ġa nd +Ġ w +i s +Ġ f +a n +i on +a l +Ġ b +Ġt o +Ġ m +Ġ in +Ġo f +l e +c t +a r +u t +Ġ d +s t +e d +ĠĠĠĠ ĠĠĠ +i c +" : +, Ċ +r o +en t +\ n +Ġ e +p ut +o m +Ġ re +a s +v e +Ġ h +Ġt h +" ,Ċ +Ġ l +Ġ is +e t +c e +Ġ n +. \ +i m +i l +Ġ g +Ġ u +ct ion +r u +at ion +o l +c h +Ġ T +Ġf or +ou t +r a +o w +i d +l y +Ġs t +Ġb e +Ġ y +Ġp ro +i g +s e +at e +Ġth at +it h +i r +u r +o t +Ġo r +Ġ on +Ġy ou +er s +st ru +Ġa n +i f +u l +stru ction +Ġ { +Ġ } +Ġc an +in put +out put +in struction +Ġ{ Ċ +Ġ} ,Ċ +" Ċ +Ġ he +Ġc on +Ġ it +a y +es s +Ġw ith +v er +e l +Ġa s +a m +Ġ A +g e +Ġs u +i v +. ",Ċ +Ġc om +Ġ I +m ent +a k +Ġa l +\ " +. "Ċ +i ve +Ġa re +a b +a d +Ġm o +Ġe x +Ġ v +Ġ S +re s +p p +q u +Ġd e +Ġw h +it y +Ġ en +ĠT he +he r +l d +r i +t er +an t +Ġ C +is t +Ġ" ",Ċ +u m +Ġu s +Ġn e +a in +t h +e ct +Ġ le +o p +e m +i es +Ġc h +Ġ im +d u +o d +or t +n t +es t +ig h +e re +Ġh a +u s +u re +i al +o c +Ġw or +Ġthe ir +a c +en ce +i z +Ġyou r +o s +Ġim p +u d +Ġb y +Ġs e +in e +ou ld +l ow +il l +a ge +ro m +Ġs p +Ġ P +Ġs h +u st +T he +u n +' s +Ġin c +id e +p l +igh t +o g +Ġp l +p t +a re +Ġt e +Ġin t +Ġ \ +h is +Ġ r +ak e +p er +or m +a g +f f +Ġ E +ar t +Ġ k +en d +Ġ M +Ġw e +Ġ B +Ġa d +c ess +r ou +ic al +al l +ab le +Ġf rom +a nd +Ġ H +Ġa b +a ct +Ġcom p +om e +a ch +ĠT his +Ġha ve +f orm +Ġ \" +a st +Ġa t +Ġ W +Ġre s +Ġd at +: \ +t her +ion s +o re +Ġ ( +Ġcon t +ou r +e p +Ġ F +Ġa c +an ce +Ġ R +g h +Ġm e +c es +Ġw as +in d +ve l +ation s +Ġhe l +Ġmo re +ul t +Ġ D +re at +ig n +Ġhel p +im e +ar d +Ġc l +Ġa pp +an s +i e +Ġdat a +ic h +an g +ou s +el l +k s +as e +ic e +i p +it e +Ġsu ch +Ġf e +Ġw he +i b +Ġo ther +Ġth is +as s +u al +i le +n e +re d +Ġh as +o o +res s +if ic +n ing +Ġ = +Ġu p +Ġm an +Ġa r +on g +e c +Ġt ra +a v +Ġwh ich +Ġg o +Ġpro v +Ġd is +* * +s o +Ġ G +on e +Ġe m +Ġn ot +u e +Ġ O +Ġ j +a ce +Ġthe y +am e +Ġ qu +Ġ L +if f +Ġf ol +ar y +at ed +ust om +it ion +Ġit s +Ġs y +k e +ac k +r y +- - +Ġt ime +Ġd es +Ġne w +ent s +ou nt +Ġfol low +Ġal so +Ġcom m +Ġo ut +Ġe ff +Ġd iff +iv en +a p +Ġs ent +\ u +Ġs o +Ġpro du +Ġu se +Ġs c +Ġ - +Ġu n +l ud +ĠI t +en er +k ing +Ġe v +Ġab out +Ġthe m +Ġ U +Ġc ustom +Ġ ro +Ġinc lud +l es +et w +st em +x t +Ġint o +Ġp er +ĠI n +Ġ N +Ġw ill +Ġle ar +b er +Ġal l +Ġp e +d s +Ġt w +ak ing +ar k +f ul +Ġm ake +ch n +er v +o st +rou gh +Ġon e +Ġin ter +it ies +a il +i ke +re e +p le +al th +Ġus ed +or s +Ġo ver +il ity +ment s +an ge +Ġw ay +or y +Ġc ol +Ġp r +Ġc ould +Ġn um +re ate +in t +Ġre du +ers on +Ġre c +Ġhe r +Ġne ed +m s +at er +o y +Ġsy stem +Ġin form +Ġtw o +Ġte chn +Ġsent ence +i ence +iz e +g et +Ġdiff ere +o od +ri b +Ġb ut +Ġfollow ing +as ed +ol og +er g +le d +u res +I n +e ar +Ġp h +ow n +Ġp re +Ġw ould +Ġus ing +Ġcon s +Ġwor k +Ġmo d +at ing +i a +i re +Ġp os +i ent +o b +j ect +Ġin v +on s +Ġd o +ul ar +Ġde c +Ġhe alth +Ġimp ro +Ġan y +Ġth rough +y p +ro w +vel op +Ġpro cess +Ġt r +l ic +ver y +al s +if y +` ` +ar i +Ġst r +Ġimp ort +Ġl ike +Ġprodu ct +Ġs ome +p h +ent ial +Ġa m +at es +Ġac c +en s +n s +Ġs m +Ġin d +e en +Ġex per +le ct +Ġv al +Ġre l +it s +Ġinform ation +ing s +Ġ J +op le +in ess +Ġg iven +m m +ic es +Ġp art +il d +y s +Ġo ur +nd er +Ġp erson +al ly +Ġk e +etw een +f t +ot h +Ġsp ec +Ġb etween +erg y +ĠA I +Ġwh o +Ġm ay +e f +at ive +is e +Ġl ist +Ġk n +Ġad d +, \ +or d +ic s +Ġpe ople +ĠS t +Ġh is +Ġex p +ib le +Ġthe re +Ġs erv +Ġinc re +Ġde velop +ou nd +ow er +Ġtr ans +b s +Ġen ergy +Ġof f +Ġb us +Ġwh ile +o se +Ġa ct +Ġex am +Ġlear ning +ction s +c on +g or +g an +ut ion +rou nd +pp ort +Ġh ow +Ġb l +Ġm ed +an c +Ġt yp +Ġ ra +Ġc ar +if e +Ġwor ld +Ġv ari +Ġre p +a u +Ġs oc +Ġprov id +Ġs et +t en +Ġs ol +Ġe ach +Ġwhe n +Ġeff ect +Ġp o +Ġs he +ic k +Ġwhe re +Ġmod el +Ġimport ant +Ġu nder +Ġpro g +ener ate +ur al +t ain +Ġas s +olog y +Ġh ad +oo k +g g +Ġcustom er +t ing +v ing +Ġres p +l ine +Ġc reat +l l +il y +Ġre g +Ġd et +Ġ if +Ġ + +Ġbus iness +\n In +is h +Ġmo st +ĠĠĠĠ ĠĠĠĠ +he s +ang u +Ġprov ide +Ġad v +er m +u b +Ġs k +ir st +an y +Ġd ay +iv id +ar m +ra ct +n ce +Ġ | +Ġimpro ve +) \ +Ġc o +Ġcomm un +ark et +Ġm et +c y +Ġdiffere nt +iz ed +Ġar t +\n The +r it +Ġcom put +Ġfor m +c k +Ġh um +Ġch ar +b le +Ġle ad +ir on +Ġre m +Ġsh ould +t e +Ġal low +n ess +h at +Ġf un +Ġcomp le +Ġl angu +ag es +Ġbe c +Ġs ign +u es +at ure +Ġf ind +ri end +Ġst ud +Ġm ain +im ate +o ve +Ġres ult +Ġpl ay +Ġredu ce +Ġen g +w are +red i +Ġnum ber +Ġl ar +Ġp ol +Ġp at +Ġw ell +id ent +v iron +r ite +c rib +Ġb u +Ġh igh +Ġthe se +iv es +v es +Ġdes ign +ur n +Ġth an +d er +Ġan al +Ġw ater +Ġm arket +Ġexam ple +w ay +st and +n g +a x +it ive +Ġ ` +i qu +Ġs im +Ġe qu +gor ith +Ġte xt +res ent +Ġman y +ur ing +-- -- +\n A +Ġd i +Ġs a +viron ment +ar ch +Ġat t +Ġp ot +Ġt as +Ġc reate +ou gh +Ġf l +Ġm aking +i ous +Ġg ra +Ġl ife +\n O +Ġal gorith +al ity +en g +Ġf in +u c +? ",Ċ +Ġ Y +Ġre t +Ġbe en +Ġtechn ology +Ġprog ra +Ġha nd +h ip +w n +Ġc al +Ġwh at +ivid ual +is s +et y +Ġlangu age +our ces +Ġcl ass +Ġt ake +Ġe as +r ic +Ġv is +b ject +Ġre f +Ġen vironment +Ġf irst +e g +Ġind ividual +Ġpl an +Ġper form +Ġ ru +i en +Ġimp act +Ġa g +ad e +Ġc le +Ġre qu +d ition +_ _ +Ġc he +pt ion +Ġapp ro +Ġ ** +Ġg reat +v ed +Ġex pl +Ġg row +G enerate +Ġm y +Ġinclud ing +Ġac cess +Ġp op +Ġm in +f ore +Ġsoc ial +in es +Ġchar act +Ġb r +Ġst ep +Ġunder stand +Ġor gan +ĠA d +Ġdis c +Ġp ower +Ġl ong +he d +Ġcon c +w ard +it ed +Ġe le +c ing +Ġe very +Ġc a +Ġof ten +Ġus er +v ie +Ġ V +Ġf ood +Ġinclud e +Ġl oc +as es +ical ly +od e +ant s +Ġinv ol +Ġsm all +Ġs ur +ach ine +Ġbe ing +Ġpot ential +Ġn o +ĠC h +Ġde p +at her +Ġb oth +Ġen s +Ġpos s +Ġ ed +crib e +t s +or k +ĠThe y +Ġp ur +iv ity +Ġwor ds +Ġsign ific +Ġw ere +ĠH ow +Ġpro m +Ġexper ience +Ġ K +u p +Ġc ount +ere d +D es +Ġf am +`` ` +ak es +Ġg l +ĠH e +Ġfe el +Ġb ack +Ġf i +Ġpro ble +iz ation +l ing +Ġcommun ic +pl oy +Ġa ut +Ġf riend +Ġhum an +Ġsp e +e w +Ġperson al +Ġto p +Ġ ent +ot her +Ġch ang +Ġc or +Ġch ange +Ġdec is +ab ility +h ing +at ural +e ver +Ġc ost +Ġgo od +au se +Ġ ident +Ġso ft +in ed +Ġp ass +' t +at ures +Ġb en +Ġcomp any +Ġst art +Ġsignific ant +Ġsu mm +on d +ol d +b ers +se l +? \ +Ġc ur +Ġl ight +Ġcomm on +.\ " +Ġcustom ers +iv ing +con om +Ġfun ction +Ġ ve +Ġth ree +Ġev en +in ing +Ġg ener +ri es +Ġle vel +Ġspec ific +Ġwe bs +Ġthe n +Ġeffect ive +c ur +en se +Ġlar ge +Ġd ist +Ġeff ic +Ġsu pport +Ġg et +C reate +re ad +p ort +Ġin f +Ġ ' +Ġy ear +Ġst ate +Ġke y +c cess +: ** +Ġa v +Ġkn ow +Ġben ef +Ġ ess +ab les +re n +Ġo wn +ĠThe se +oc k +- t +Ġ ide +om m +re en +c ed +ct ure +Ġte am +Ġr is +Ġtas ks +Ġd own +Ġst ru +Ġcomput er +- b +Ġf act +Ġm em +et ter +\n S +Ġa round +Ġwor d +Ġb ased +Ġbe h +Ġr ight +Ġd el +Ġpo int +Ġn atural +s s +Ġe conom +Ġm ade +Ġin s +Ġin st +Ġm at +Ġval ue +Ġan im +Ġse ver +\n T +ation al +it al +z e +ot e +ill s +ter n +Ġre ad +Ġcont ent +Ġon line +Ġen d +ĠU n +v ent +Ġse e +end ing +Ġm on +Ġd r +Ġke ep +Ġsystem s +c ul +v en +Ġst ory +Ġmed ia +Ġsever al +he n +ate g +Ġcont in +Ġde v +Ġlear n +Ġl a +Ġst re +Ġpart ic +Ġa ir +ual ly +Ġsu ccess +ou se +Ġis s +i ed +Ġm achine +Ġo pt +Ġ x +Ġo p +Ġpro f +oc us +ch ie +Ġmet h +n er +om p +r on +Ġh ome +Ġb etter +ĠP ro +Ġm ult +om et +Ġincre ase +Ġanal y +ver t +Ġre le +Ġb ra +in k +Ġt em +Ġp redi +Ġt re +Ġserv ice +Ġwebs ite +Ġman age +Ġsoft ware +he re +Ġpro t +- s +Ġqu est +i er +Ġkn own +Ġor der +Ġph ys +ce pt +Ġa chie +Ġin put +Ġposs ible +ĠI f +Ġex t +f ter +Ġe lect +Ġmeth od +Ġb re +ĠA n +way s +er ing +et s +Ġj ust +Ġst ore +Ġdevelop ment +Ġc are +Ġo bject +Ġtyp e +ĠF or +Ġf ocus +gg est +Ġon ly +Ġcons id +ar s +Ġch all +Ġdet erm +Ġs al +in s +Ġfe atures +Ġt ru +od y +Ġto ol +> \ +Ġens ure +os s +ub lic +Ġit em +H ere +in ation +Ġde f +Des cribe +ion al +rou p +Ġcon f +Ġneed s +Ġcharact er +Ġvari ous +Ġle t +Ġapp lic +a ut +Ġj ob +ell ig +ĠC on +Ġb est +Ġf ore +Ġam ount +ro p +Ġbu ild +iqu e +ag ing +Ġem ploy +Ġre st +a ir +W hat +Ġto get +Ġway s +Ġident ify +Ġtoget her +Ġre al +Ġus ers +Ġme an +as ing +ĠA m +Ġed uc +Ġalgorith m +Ġn etw +Ġc ode +W rite +o v +- d +ou ra +ĠHow ever +ut ure +vie w +Ġin du +Ġproduct s +ect ed +er tain +; \ +ĠA s +p r +ast e +Ġo per +Ġ $ +av i +sel f +Ġ < +Ġindu st +Ġg u +Ġother s +E x +i an +Ġ" \" +- f +n ces +Ġf il +Ġresp ons +ro l +Ġc ap +Ġbe fore +ver n +Ġcomple x +l us +rib ut +at s +Ġpos itive +o h +Ġl o +Ġg roup +Ġf ound +e e +og n +Ġs w +Ġindividual s +Ġp ract +Ġen c +Ġsh are +ra ph +Ġr ange +Ġsu n +\ t +Ġprovid ing +ic le +Ġde m +Ġpl ace +Ġa ud +j oy +Ġm ust +el s +er y +O ne +Ġfam ily +Ġf uture +l ess +re nt +Ġproble m +Ġess ential +ro du +i red +Ġredu cing +is m +Ġw arm +ra y +Ġab ility +Ġstr ong +Ġal ways +Ġres ources +Ġbenef its +Ġstr ateg +Ġinvol ves +Ġass ist +ere st +n A +ress ion +Ġ [ +il ities +Ġstep s +ver all +Ġsh ow +ob al +\n F +Ġl and +ĠH ere +Ġbusiness es +ĠE n +pport un +Ġme as +Ġret urn +Ġd ig +Ġh ist +y th +Ġc ent +Ġab le +Ġwith out +y c +pl ain +Ġrel ations +Ġserv ices +- c +Ġt est +ar th +Ġcommunic ation +Ġinter n +ne w +Ġs it +Ġinv est +Ġca us +Ġu nt +Ġfriend s +Ġchang es +c ri +d it +ĠB y +ĠY ou +Ġme ans +Ġre se +o ol +t ed +ellig ence +ain s +pp ing +Ġbe l +Ġrep resent +Ġha pp +Ġs er +Ġperform ance +Ġo pportun +Ġtem per +ĠS he +Ġf u +i x +b ot +Ġw rit +Ġbeh avi +Ġpro ject +ĠW ith +iv ers +d ay +Ġphys ical +iz ing +Ġact iv +Ġwith in +Ġint erest +ol ution +ward s +ff ic +Ġqu ick +Ġp ublic +Ġgrow th +Ġch o +Ġrelations hip +Ġunt il +Ġhelp s +Ġstud ents +Ġfi el +im es +ul ation +ib ility +el f +Ġf ul +Ġsu b +an k +id es +Ġsk ills +Ġcl imate +G iven +Ġp ar +Ġcle ar +ir t +N ame +Ġp resent +Ġt ri +Ġchall eng +re am +Ġl ay +Ġmarket ing +Ġsumm ary +Ġch ild +Ġsa f +Ġsu re +Ġs ame +Ġm u +Ġem ail +b on +Ġs omet +``` \ +Ġcur rent +am p +en ces +ĠR e +Ġtrans port +m e +- p +a ction +ĠE x +Ġyear s +Ġcom b +h or +anc ed +t y +Ġl ove +Ġg reen +Ġpop ular +Ġl ess +Ġd ra +Ġcont rol +Ġa ff +Ġcons um +Ġg ame +ent al +ight s +ar get +om es +o x +ic ult +er c +Ġgo als +anc ial +t le +Ġgo vern +Ġnum bers +Ġf ive +Ġst and +Ġse arch +Ġeffic ient +Ġw al +Ġn ame +at h +Ġhe art +Ġd uring +re ct +Ġover all +yth on +Ġallow s +Ġc ity +a ve +v ant +ater ial +Ġw ide +Ġm us +ific ial +Ġh ard +ĠT h +oo se +Ġgl obal +a j +Ġt er +Ġdiff icult +Ġl ine +ĠA l +c are +iv ed +Ġreg ular +Ġg r +) , +le ment +Ġh im +Ġun ique +Ġen joy +Ġmean ing +Ġop en +Ġ i +ab or +Ġare a +Ġitem s +Ġcle an +dition ally +o id +ĠW e +Ġbe aut +Ġme et +ip le +Ġstate ment +Ġag ain +ys is +Ġf ac +Ġs ources +Ġb ody +Ġalgorith ms +Ġaud ience +Ġw ant +Ġl og +Ġmain tain +Ġactiv ities +Ġmo ve +Ġc ult +one y +Ġt arget +\n B +Ġm aterial +Ġcreat ing +Ġstru cture +at form +e xt +Ġexper ien +Ġval ues +e ad +oh n +Ġhealth y +ro ss +Ġint eg +Ġrese arch +at ch +oo king +Ġro le +Ġprovid es +i ety +ist s +Ġfin ancial +or ies +d ent +Ġ er +Ġart icle +Ġele ments +Ġadd ress +Ġcon n +ĠU se +m p +Ġeas y +Ġne g +Ġcol or +Ġcal cul +Ex plain +ĠP l +p ect +in ce +al e +Ġris k +cur ity +er t +Ġfe ed +Ġev ent +v ers +pl es +Ġlevel s +Ġb i +Ġst ay +Ġpl atform +Ġbre ak +b ack +Ġs at +\nO verall +Ġeduc ation +\n C +Ġcar bon +---- ---- +ap e +Ġpre vent +Ġadd ition +Ġst ress +r al +our ce +ru s +Ġcom e +Ġrec ogn +ĠUn ited +Ġpro per +Ġpol l +dent ify +Ġunderstand ing +Ġdecis ions +i ct +Ġd ire +Ġbehavi or +Ġ * +\n I +Ġm ess +Ġanim als +Ġs l +Ġw ind +Ġb as +Ġp ain +Ġlead ing +er n +g er +Ġp res +Ġth ough +Ġinter act +y le +Ġdo es +Ġhe ad +Ġint elligence +ort s +Ġbec ome +Ġru n +ar ing +Ġimp lement +Ġa ction +o ot +ter ns +Ġprot ect +er ic +Ġf low +Ġem ot +cess ary +ur ate +Ġsu ggest +Ġprogra m +Ġph r +Ġhealth care +ent ion +Ġsu st +Ġwh y +Ġacc urate +l u +Ġh ig +Ġre ach +Ġallow ing +Ġtra vel +Ġrequ ire +Ġare as +Ġde ep +H e +Ġfe w +Ġs elf +ou n +Ġ # +os p +st r +Ġmin ut +Ġdecis ion +ĠThe re +an ces +Ġqu ality +Ġav ail +Ġsp ace +Ġsomet hing +Ġwe b +Ġpat terns +Ġm ot +or ing +is f +Ġan other +Ġacc ount +\n W +us s +Ġm aj +u ation +Ġsust ain +Ġaut om +iqu es +iss ions +ver se +Ġcon cept +Ġse curity +Ġth ose +Ġprof ess +Ġsh ort +Ġn ight +eng th +a pt +e x +ĠAd ditionally +Ġt aking +Ġto o +ag n +Ġsim ple +lus ion +ien cy +as h +our s +Ġp a +Ġl it +ĠS p +it ing +Ġd on +Ġl im +l ish +m at +av es +led ge +dition al +in c +Ġev ents +Ġoff er +th ing +Ġwor king +Ġanal ysis +Ġachie ve +Ġp ie +Ġb ook +Ġf re +Ġmu ch +o on +Ġt ry +es p +Ġw aste +f ace +Ġe ar +Ġf ru +Ġtransport ation +ch ool +Ġtechn iques +Ġprogra mm +ĠE arth +Ġpredi ct +Ġne ver +w s +u ment +imate ly +are d +Ġpartic ular +Ġto wards +Ġeconom ic +Ġincre asing +Ġf ast +im ent +Ġnetw ork +Ġcor rect +Ġm ight +Ġo c +Ġbec ause +ĠW h +a z +pl ay +Ġresult s +Ġmanage ment +Ġpur ch +Ġs ound +Ġp ast +Ġtra ining +__ __ +op e +Ġeng age +oura ge +Ġs ense +Ġf ree +Ġpre f +e es +Ġcount ries +ne y +an ies +Ġa fter +Ġm ind +Ġex c +ĠO nce +ĠĠĠĠ ĠĠĠĠĠĠĠ +Ġcomple te +Ġim m +Ġ est +Ġg enerate +ver b +ĠD e +' m +Ġtool s +redi ents +Ġmaj or +ent ly +Ġcont ribut +le ep +Ġpoint s +dit ions +Ġfact ors +Ġe l +Ġne xt +i um +ou d +Ġc ru +Ġre as +ri ate +ĠI nd +Ġprom ot +Ġhist ory +Ġj our +Ġd ue +C on +Ġve get +en cy +ĠAm eric +Ġf ra +Ġdiffere nce +o ard +le x +Ġequ ation +irt ual +Ġc up +Ġfore st +Ġneg ative +Ġse con +on es +Ġn ature +Ġus es +a h +p or +Ġse c +ord ing +Ġl ast +ĠS ome +Ġiss ues +Ġsc ient +Ġpr int +ĠSt ates +o ver +Ġsat isf +Ġdev ices +Ġdis e +Ġtemper ature +Ġfeed back +Ġne cessary +Ġem issions +m b +Ġl ow +f or +t al +Ġchalleng es +Ġar ray +Ġs ide +Ġeng ine +Ġb oo +at a +Ġbel ie +- m +Ġmult iple +Ġs ing +Ġgovern ment +am es +if ied +Ġminut es +Ġsuccess ful +Ġm oney +Ġquick ly +Ġb ir +Ġtyp ically +Ġp ost +Ġpre p +Ġknow ledge +pp ed +a ctions +Ġmethod s +Ġopt im +\n P +Ġout put +Ġfiel d +Ġt able +Ġb al +Ġcol l +Ġcharact ers +v olution +or ds +il ar +ific ation +an e +Ġc ell +Ġm il +ĠW hat +Ġs qu +Ġl ives +ĠA r +Ġphr ase +Ġn ut +Ġdig ital +Ġintern et +l ass +u ra +omm end +Ġt reat +Ġappro p +res h +ur ther +ĠO ne +Ġvis ual +ate gor +Ġappro ach +Ġc ertain +Ġsh o +v al +Ġtas k +i res +Ġapprop riate +Ġv ie +Ġdesign ed +p ose +** : +f ort +Ġ| \ +Ġapplic ations +Ġp ay +Ġn ow +Ġhe at +Ġindust ry +p re +Ġeffective ly +Ġpop ulation +Ġopportun ities +< / +ĠT o +Ġup d +Ġinclud es +ĠE ng +Ġtyp es +Ġup on +Ġconsid er +le t +Ġg en +og raph +pl ace +Ġt imes +Ġar g +C omp +ĠG o +Ġre ce +Ġchild ren +Ġtra ck +Ġsome one +w ord +Ġyou ng +Ġcon ditions +Ġtra ditional +Ġmodel s +I dentify +Ġc amp +Ġm akes +ist ic +Ġar r +Ġc ard +ut ions +l t +Ġo ld +Ġide as +Ġe y +Ġt ree +Ġiss ue +Ġh arm +Ġavail able +Ġc r +Ġpower ful +n ov +Ġmo vie +Ġwe ather +Ġsk y +Ġquest ions +e et +Ġact ivity +Ġbra nd +is hed +Ġanaly ze +ĠS h +Ġen h +av or +Ġbe g +Ġs chool +i ate +Ġeas ier +Ġinf lu +Ġn on +Ġstud y +Ġl ook +Ġsol ution +Ġle g +Ġcon st +H ow +Ġcomp et diff --git a/training/tests/assets/tiny_bpe_tokenizer.json b/training/tests/assets/tiny_bpe_tokenizer.json new file mode 100644 index 0000000..b8c5258 --- /dev/null +++ b/training/tests/assets/tiny_bpe_tokenizer.json @@ -0,0 +1 @@ +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":2000,"content":"<|endoftext|>","single_word":false,"lstrip":false,"rstrip":false,"normalized":false,"special":true},{"id":2001,"content":"<|im_start|>","single_word":false,"lstrip":false,"rstrip":false,"normalized":false,"special":true},{"id":2002,"content":"<|im_end|>","single_word":false,"lstrip":false,"rstrip":false,"normalized":false,"special":true}],"normalizer":{"type":"NFC"},"pre_tokenizer":{"type":"Sequence","pretokenizers":[{"type":"Split","pattern":{"Regex":"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"},"behavior":"Isolated","invert":false},{"type":"ByteLevel","add_prefix_space":false,"trim_offsets":false,"use_regex":false}]},"post_processor":{"type":"ByteLevel","add_prefix_space":false,"trim_offsets":false,"use_regex":false},"decoder":{"type":"ByteLevel","add_prefix_space":false,"trim_offsets":false,"use_regex":false},"model":{"type":"BPE","dropout":null,"unk_token":null,"continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":false,"vocab":{"0":15,"1":16,"2":17,"3":18,"4":19,"5":20,"6":21,"7":22,"8":23,"9":24,"!":0,"\"":1,"#":2,"$":3,"%":4,"&":5,"'":6,"(":7,")":8,"*":9,"+":10,",":11,"-":12,".":13,"/":14,":":25,";":26,"<":27,"=":28,">":29,"?":30,"@":31,"A":32,"B":33,"C":34,"D":35,"E":36,"F":37,"G":38,"H":39,"I":40,"J":41,"K":42,"L":43,"M":44,"N":45,"O":46,"P":47,"Q":48,"R":49,"S":50,"T":51,"U":52,"V":53,"W":54,"X":55,"Y":56,"Z":57,"[":58,"\\":59,"]":60,"^":61,"_":62,"`":63,"a":64,"b":65,"c":66,"d":67,"e":68,"f":69,"g":70,"h":71,"i":72,"j":73,"k":74,"l":75,"m":76,"n":77,"o":78,"p":79,"q":80,"r":81,"s":82,"t":83,"u":84,"v":85,"w":86,"x":87,"y":88,"z":89,"{":90,"|":91,"}":92,"~":93,"Ċ":94,"Ġ":95,"ĠĠ":96,"Ġt":97,"Ġa":98,"in":99,"he":100,"re":101,"on":102,"Ġthe":103,"Ġs":104,"er":105,"at":106,"Ġc":107,"ĠĠĠĠ":108,"en":109,"Ġo":110,"Ġ\"":111,"nd":112,"es":113,"ing":114,"ĠĠĠ":115,"it":116,"Ġp":117,"or":118,"ou":119,"Ġand":120,"Ġw":121,"is":122,"Ġf":123,"an":124,"ion":125,"al":126,"Ġb":127,"Ġto":128,"Ġm":129,"Ġin":130,"Ġof":131,"le":132,"ct":133,"ar":134,"ut":135,"Ġd":136,"st":137,"ed":138,"ĠĠĠĠĠĠĠ":139,"ic":140,"\":":141,",Ċ":142,"ro":143,"ent":144,"\\n":145,"Ġe":146,"put":147,"om":148,"Ġre":149,"as":150,"ve":151,"Ġh":152,"Ġth":153,"\",Ċ":154,"Ġl":155,"Ġis":156,"et":157,"ce":158,"Ġn":159,".\\":160,"im":161,"il":162,"Ġg":163,"Ġu":164,"ction":165,"ru":166,"ation":167,"ol":168,"ch":169,"ĠT":170,"Ġfor":171,"out":172,"ra":173,"ow":174,"id":175,"ly":176,"Ġst":177,"Ġbe":178,"Ġy":179,"Ġpro":180,"ig":181,"se":182,"ate":183,"Ġthat":184,"ith":185,"ir":186,"ur":187,"ot":188,"Ġor":189,"Ġon":190,"Ġyou":191,"ers":192,"stru":193,"Ġan":194,"if":195,"ul":196,"struction":197,"Ġ{":198,"Ġ}":199,"Ġcan":200,"input":201,"output":202,"instruction":203,"Ġ{Ċ":204,"Ġ},Ċ":205,"\"Ċ":206,"Ġhe":207,"Ġcon":208,"Ġit":209,"ay":210,"ess":211,"Ġwith":212,"ver":213,"el":214,"Ġas":215,"am":216,"ĠA":217,"ge":218,"Ġsu":219,"iv":220,".\",Ċ":221,"Ġcom":222,"ĠI":223,"ment":224,"ak":225,"Ġal":226,"\\\"":227,".\"Ċ":228,"ive":229,"Ġare":230,"ab":231,"ad":232,"Ġmo":233,"Ġex":234,"Ġv":235,"ĠS":236,"res":237,"pp":238,"qu":239,"Ġde":240,"Ġwh":241,"ity":242,"Ġen":243,"ĠThe":244,"her":245,"ld":246,"ri":247,"ter":248,"ant":249,"ĠC":250,"ist":251,"Ġ\"\",Ċ":252,"um":253,"Ġus":254,"Ġne":255,"ain":256,"th":257,"ect":258,"Ġle":259,"op":260,"em":261,"ies":262,"Ġch":263,"Ġim":264,"du":265,"od":266,"ort":267,"nt":268,"est":269,"igh":270,"ere":271,"Ġha":272,"us":273,"ure":274,"ial":275,"oc":276,"Ġwor":277,"Ġtheir":278,"ac":279,"ence":280,"iz":281,"Ġyour":282,"os":283,"Ġimp":284,"ud":285,"Ġby":286,"Ġse":287,"ine":288,"ould":289,"low":290,"ill":291,"age":292,"rom":293,"Ġsp":294,"ĠP":295,"Ġsh":296,"ust":297,"The":298,"un":299,"'s":300,"Ġinc":301,"ide":302,"pl":303,"ight":304,"og":305,"Ġpl":306,"pt":307,"are":308,"Ġte":309,"Ġint":310,"Ġ\\":311,"his":312,"Ġr":313,"ake":314,"per":315,"orm":316,"ag":317,"ff":318,"ĠE":319,"art":320,"Ġk":321,"end":322,"ĠM":323,"Ġwe":324,"ĠB":325,"Ġad":326,"cess":327,"rou":328,"ical":329,"all":330,"able":331,"Ġfrom":332,"and":333,"ĠH":334,"Ġab":335,"act":336,"Ġcomp":337,"ome":338,"ach":339,"ĠThis":340,"Ġhave":341,"form":342,"Ġ\\\"":343,"ast":344,"Ġat":345,"ĠW":346,"Ġres":347,"Ġdat":348,":\\":349,"ther":350,"ions":351,"ore":352,"Ġ(":353,"Ġcont":354,"our":355,"ep":356,"ĠF":357,"Ġac":358,"ance":359,"ĠR":360,"gh":361,"Ġme":362,"ces":363,"Ġwas":364,"ind":365,"vel":366,"ations":367,"Ġhel":368,"Ġmore":369,"ult":370,"ĠD":371,"reat":372,"ign":373,"Ġhelp":374,"ime":375,"ard":376,"Ġcl":377,"Ġapp":378,"ans":379,"ie":380,"Ġdata":381,"ich":382,"ang":383,"ous":384,"ell":385,"ks":386,"ase":387,"ice":388,"ip":389,"ite":390,"Ġsuch":391,"Ġfe":392,"Ġwhe":393,"ib":394,"Ġother":395,"Ġthis":396,"ass":397,"ual":398,"ile":399,"ne":400,"red":401,"Ġhas":402,"oo":403,"ress":404,"ific":405,"ning":406,"Ġ=":407,"Ġup":408,"Ġman":409,"Ġar":410,"ong":411,"ec":412,"Ġtra":413,"av":414,"Ġwhich":415,"Ġgo":416,"Ġprov":417,"Ġdis":418,"**":419,"so":420,"ĠG":421,"one":422,"Ġem":423,"Ġnot":424,"ue":425,"ĠO":426,"Ġj":427,"ace":428,"Ġthey":429,"ame":430,"Ġqu":431,"ĠL":432,"iff":433,"Ġfol":434,"ary":435,"ated":436,"ustom":437,"ition":438,"Ġits":439,"Ġsy":440,"ke":441,"ack":442,"ry":443,"--":444,"Ġtime":445,"Ġdes":446,"Ġnew":447,"ents":448,"ount":449,"Ġfollow":450,"Ġalso":451,"Ġcomm":452,"Ġout":453,"Ġeff":454,"Ġdiff":455,"iven":456,"ap":457,"Ġsent":458,"\\u":459,"Ġso":460,"Ġprodu":461,"Ġuse":462,"Ġsc":463,"Ġ-":464,"Ġun":465,"lud":466,"ĠIt":467,"ener":468,"king":469,"Ġev":470,"Ġabout":471,"Ġthem":472,"ĠU":473,"Ġcustom":474,"Ġro":475,"Ġinclud":476,"les":477,"etw":478,"stem":479,"xt":480,"Ġinto":481,"Ġper":482,"ĠIn":483,"ĠN":484,"Ġwill":485,"Ġlear":486,"ber":487,"Ġall":488,"Ġpe":489,"ds":490,"Ġtw":491,"aking":492,"ark":493,"ful":494,"Ġmake":495,"chn":496,"erv":497,"ost":498,"rough":499,"Ġone":500,"Ġinter":501,"ities":502,"ail":503,"ike":504,"ree":505,"ple":506,"alth":507,"Ġused":508,"ors":509,"Ġover":510,"ility":511,"ments":512,"ange":513,"Ġway":514,"ory":515,"Ġcol":516,"Ġpr":517,"Ġcould":518,"Ġnum":519,"reate":520,"int":521,"Ġredu":522,"erson":523,"Ġrec":524,"Ġher":525,"Ġneed":526,"ms":527,"ater":528,"oy":529,"Ġsystem":530,"Ġinform":531,"Ġtwo":532,"Ġtechn":533,"Ġsentence":534,"ience":535,"ize":536,"get":537,"Ġdiffere":538,"ood":539,"rib":540,"Ġbut":541,"Ġfollowing":542,"ased":543,"olog":544,"erg":545,"led":546,"ures":547,"In":548,"ear":549,"Ġph":550,"own":551,"Ġpre":552,"Ġwould":553,"Ġusing":554,"Ġcons":555,"Ġwork":556,"Ġmod":557,"ating":558,"ia":559,"ire":560,"Ġpos":561,"ient":562,"ob":563,"ject":564,"Ġinv":565,"ons":566,"Ġdo":567,"ular":568,"Ġdec":569,"Ġhealth":570,"Ġimpro":571,"Ġany":572,"Ġthrough":573,"yp":574,"row":575,"velop":576,"Ġprocess":577,"Ġtr":578,"lic":579,"very":580,"als":581,"ify":582,"``":583,"ari":584,"Ġstr":585,"Ġimport":586,"Ġlike":587,"Ġproduct":588,"Ġsome":589,"ph":590,"ential":591,"Ġam":592,"ates":593,"Ġacc":594,"ens":595,"ns":596,"Ġsm":597,"Ġind":598,"een":599,"Ġexper":600,"lect":601,"Ġval":602,"Ġrel":603,"its":604,"Ġinformation":605,"ings":606,"ĠJ":607,"ople":608,"iness":609,"Ġgiven":610,"mm":611,"ices":612,"Ġpart":613,"ild":614,"ys":615,"Ġour":616,"nder":617,"Ġperson":618,"ally":619,"Ġke":620,"etween":621,"ft":622,"oth":623,"Ġspec":624,"Ġbetween":625,"ergy":626,"ĠAI":627,"Ġwho":628,"Ġmay":629,"ef":630,"ative":631,"ise":632,"Ġlist":633,"Ġkn":634,"Ġadd":635,",\\":636,"ord":637,"ics":638,"Ġpeople":639,"ĠSt":640,"Ġhis":641,"Ġexp":642,"ible":643,"Ġthere":644,"Ġserv":645,"Ġincre":646,"Ġdevelop":647,"ound":648,"ower":649,"Ġtrans":650,"bs":651,"Ġenergy":652,"Ġoff":653,"Ġbus":654,"Ġwhile":655,"ose":656,"Ġact":657,"Ġexam":658,"Ġlearning":659,"ctions":660,"con":661,"gor":662,"gan":663,"ution":664,"round":665,"pport":666,"Ġhow":667,"Ġbl":668,"Ġmed":669,"anc":670,"Ġtyp":671,"Ġra":672,"Ġcar":673,"ife":674,"Ġworld":675,"Ġvari":676,"Ġrep":677,"au":678,"Ġsoc":679,"Ġprovid":680,"Ġset":681,"ten":682,"Ġsol":683,"Ġeach":684,"Ġwhen":685,"Ġeffect":686,"Ġpo":687,"Ġshe":688,"ick":689,"Ġwhere":690,"Ġmodel":691,"Ġimportant":692,"Ġunder":693,"Ġprog":694,"enerate":695,"ural":696,"tain":697,"Ġass":698,"ology":699,"Ġhad":700,"ook":701,"gg":702,"Ġcustomer":703,"ting":704,"ving":705,"Ġresp":706,"line":707,"Ġcreat":708,"ll":709,"ily":710,"Ġreg":711,"Ġdet":712,"Ġif":713,"Ġ+":714,"Ġbusiness":715,"\\nIn":716,"ish":717,"Ġmost":718,"ĠĠĠĠĠĠĠĠ":719,"hes":720,"angu":721,"Ġprovide":722,"Ġadv":723,"erm":724,"ub":725,"Ġsk":726,"irst":727,"any":728,"Ġday":729,"ivid":730,"arm":731,"ract":732,"nce":733,"Ġ|":734,"Ġimprove":735,")\\":736,"Ġco":737,"Ġcommun":738,"arket":739,"Ġmet":740,"cy":741,"Ġdifferent":742,"ized":743,"Ġart":744,"\\nThe":745,"rit":746,"Ġcomput":747,"Ġform":748,"ck":749,"Ġhum":750,"Ġchar":751,"ble":752,"Ġlead":753,"iron":754,"Ġrem":755,"Ġshould":756,"te":757,"Ġallow":758,"ness":759,"hat":760,"Ġfun":761,"Ġcomple":762,"Ġlangu":763,"ages":764,"Ġbec":765,"Ġsign":766,"ues":767,"ature":768,"Ġfind":769,"riend":770,"Ġstud":771,"Ġmain":772,"imate":773,"ove":774,"Ġresult":775,"Ġplay":776,"Ġreduce":777,"Ġeng":778,"ware":779,"redi":780,"Ġnumber":781,"Ġlar":782,"Ġpol":783,"Ġpat":784,"Ġwell":785,"ident":786,"viron":787,"rite":788,"crib":789,"Ġbu":790,"Ġhigh":791,"Ġthese":792,"ives":793,"ves":794,"Ġdesign":795,"urn":796,"Ġthan":797,"der":798,"Ġanal":799,"Ġwater":800,"Ġmarket":801,"Ġexample":802,"way":803,"stand":804,"ng":805,"ax":806,"itive":807,"Ġ`":808,"iqu":809,"Ġsim":810,"Ġequ":811,"gorith":812,"Ġtext":813,"resent":814,"Ġmany":815,"uring":816,"----":817,"\\nA":818,"Ġdi":819,"Ġsa":820,"vironment":821,"arch":822,"Ġatt":823,"Ġpot":824,"Ġtas":825,"Ġcreate":826,"ough":827,"Ġfl":828,"Ġmaking":829,"ious":830,"Ġgra":831,"Ġlife":832,"\\nO":833,"Ġalgorith":834,"ality":835,"eng":836,"Ġfin":837,"uc":838,"?\",Ċ":839,"ĠY":840,"Ġret":841,"Ġbeen":842,"Ġtechnology":843,"Ġprogra":844,"Ġhand":845,"hip":846,"wn":847,"Ġcal":848,"Ġwhat":849,"ividual":850,"iss":851,"ety":852,"Ġlanguage":853,"ources":854,"Ġclass":855,"Ġtake":856,"Ġeas":857,"ric":858,"Ġvis":859,"bject":860,"Ġref":861,"Ġenvironment":862,"Ġfirst":863,"eg":864,"Ġindividual":865,"Ġplan":866,"Ġperform":867,"Ġru":868,"ien":869,"Ġimpact":870,"Ġag":871,"ade":872,"Ġcle":873,"Ġrequ":874,"dition":875,"__":876,"Ġche":877,"ption":878,"Ġappro":879,"Ġ**":880,"Ġgreat":881,"ved":882,"Ġexpl":883,"Ġgrow":884,"Generate":885,"Ġmy":886,"Ġincluding":887,"Ġaccess":888,"Ġpop":889,"Ġmin":890,"fore":891,"Ġsocial":892,"ines":893,"Ġcharact":894,"Ġbr":895,"Ġstep":896,"Ġunderstand":897,"Ġorgan":898,"ĠAd":899,"Ġdisc":900,"Ġpower":901,"Ġlong":902,"hed":903,"Ġconc":904,"ward":905,"ited":906,"Ġele":907,"cing":908,"Ġevery":909,"Ġca":910,"Ġoften":911,"Ġuser":912,"vie":913,"ĠV":914,"Ġfood":915,"Ġinclude":916,"Ġloc":917,"ases":918,"ically":919,"ode":920,"ants":921,"Ġinvol":922,"Ġsmall":923,"Ġsur":924,"achine":925,"Ġbeing":926,"Ġpotential":927,"Ġno":928,"ĠCh":929,"Ġdep":930,"ather":931,"Ġboth":932,"Ġens":933,"Ġposs":934,"Ġed":935,"cribe":936,"ts":937,"ork":938,"ĠThey":939,"Ġpur":940,"ivity":941,"Ġwords":942,"Ġsignific":943,"Ġwere":944,"ĠHow":945,"Ġprom":946,"Ġexperience":947,"ĠK":948,"up":949,"Ġcount":950,"ered":951,"Des":952,"Ġfam":953,"```":954,"akes":955,"Ġgl":956,"ĠHe":957,"Ġfeel":958,"Ġback":959,"Ġfi":960,"Ġproble":961,"ization":962,"ling":963,"Ġcommunic":964,"ploy":965,"Ġaut":966,"Ġfriend":967,"Ġhuman":968,"Ġspe":969,"ew":970,"Ġpersonal":971,"Ġtop":972,"Ġent":973,"other":974,"Ġchang":975,"Ġcor":976,"Ġchange":977,"Ġdecis":978,"ability":979,"hing":980,"atural":981,"ever":982,"Ġcost":983,"Ġgood":984,"ause":985,"Ġident":986,"Ġsoft":987,"ined":988,"Ġpass":989,"'t":990,"atures":991,"Ġben":992,"Ġcompany":993,"Ġstart":994,"Ġsignificant":995,"Ġsumm":996,"ond":997,"old":998,"bers":999,"sel":1000,"?\\":1001,"Ġcur":1002,"Ġlight":1003,"Ġcommon":1004,".\\\"":1005,"Ġcustomers":1006,"iving":1007,"conom":1008,"Ġfunction":1009,"Ġve":1010,"Ġthree":1011,"Ġeven":1012,"ining":1013,"Ġgener":1014,"ries":1015,"Ġlevel":1016,"Ġspecific":1017,"Ġwebs":1018,"Ġthen":1019,"Ġeffective":1020,"cur":1021,"ense":1022,"Ġlarge":1023,"Ġdist":1024,"Ġeffic":1025,"Ġsupport":1026,"Ġget":1027,"Create":1028,"read":1029,"port":1030,"Ġinf":1031,"Ġ'":1032,"Ġyear":1033,"Ġstate":1034,"Ġkey":1035,"ccess":1036,":**":1037,"Ġav":1038,"Ġknow":1039,"Ġbenef":1040,"Ġess":1041,"ables":1042,"ren":1043,"Ġown":1044,"ĠThese":1045,"ock":1046,"-t":1047,"Ġide":1048,"omm":1049,"reen":1050,"ced":1051,"cture":1052,"Ġteam":1053,"Ġris":1054,"Ġtasks":1055,"Ġdown":1056,"Ġstru":1057,"Ġcomputer":1058,"-b":1059,"Ġfact":1060,"Ġmem":1061,"etter":1062,"\\nS":1063,"Ġaround":1064,"Ġword":1065,"Ġbased":1066,"Ġbeh":1067,"Ġright":1068,"Ġdel":1069,"Ġpoint":1070,"Ġnatural":1071,"ss":1072,"Ġeconom":1073,"Ġmade":1074,"Ġins":1075,"Ġinst":1076,"Ġmat":1077,"Ġvalue":1078,"Ġanim":1079,"Ġsever":1080,"\\nT":1081,"ational":1082,"ital":1083,"ze":1084,"ote":1085,"ills":1086,"tern":1087,"Ġread":1088,"Ġcontent":1089,"Ġonline":1090,"Ġend":1091,"ĠUn":1092,"vent":1093,"Ġsee":1094,"ending":1095,"Ġmon":1096,"Ġdr":1097,"Ġkeep":1098,"Ġsystems":1099,"cul":1100,"ven":1101,"Ġstory":1102,"Ġmedia":1103,"Ġseveral":1104,"hen":1105,"ateg":1106,"Ġcontin":1107,"Ġdev":1108,"Ġlearn":1109,"Ġla":1110,"Ġstre":1111,"Ġpartic":1112,"Ġair":1113,"ually":1114,"Ġsuccess":1115,"ouse":1116,"Ġiss":1117,"ied":1118,"Ġmachine":1119,"Ġopt":1120,"Ġx":1121,"Ġop":1122,"Ġprof":1123,"ocus":1124,"chie":1125,"Ġmeth":1126,"ner":1127,"omp":1128,"ron":1129,"Ġhome":1130,"Ġbetter":1131,"ĠPro":1132,"Ġmult":1133,"omet":1134,"Ġincrease":1135,"Ġanaly":1136,"vert":1137,"Ġrele":1138,"Ġbra":1139,"ink":1140,"Ġtem":1141,"Ġpredi":1142,"Ġtre":1143,"Ġservice":1144,"Ġwebsite":1145,"Ġmanage":1146,"Ġsoftware":1147,"here":1148,"Ġprot":1149,"-s":1150,"Ġquest":1151,"ier":1152,"Ġknown":1153,"Ġorder":1154,"Ġphys":1155,"cept":1156,"Ġachie":1157,"Ġinput":1158,"Ġpossible":1159,"ĠIf":1160,"Ġext":1161,"fter":1162,"Ġelect":1163,"Ġmethod":1164,"Ġbre":1165,"ĠAn":1166,"ways":1167,"ering":1168,"ets":1169,"Ġjust":1170,"Ġstore":1171,"Ġdevelopment":1172,"Ġcare":1173,"Ġobject":1174,"Ġtype":1175,"ĠFor":1176,"Ġfocus":1177,"ggest":1178,"Ġonly":1179,"Ġconsid":1180,"ars":1181,"Ġchall":1182,"Ġdeterm":1183,"Ġsal":1184,"ins":1185,"Ġfeatures":1186,"Ġtru":1187,"ody":1188,"Ġtool":1189,">\\":1190,"Ġensure":1191,"oss":1192,"ublic":1193,"Ġitem":1194,"Here":1195,"ination":1196,"Ġdef":1197,"Describe":1198,"ional":1199,"roup":1200,"Ġconf":1201,"Ġneeds":1202,"Ġcharacter":1203,"Ġvarious":1204,"Ġlet":1205,"Ġapplic":1206,"aut":1207,"Ġjob":1208,"ellig":1209,"ĠCon":1210,"Ġbest":1211,"Ġfore":1212,"Ġamount":1213,"rop":1214,"Ġbuild":1215,"ique":1216,"aging":1217,"Ġemploy":1218,"Ġrest":1219,"air":1220,"What":1221,"Ġtoget":1222,"Ġways":1223,"Ġidentify":1224,"Ġtogether":1225,"Ġreal":1226,"Ġusers":1227,"Ġmean":1228,"asing":1229,"ĠAm":1230,"Ġeduc":1231,"Ġalgorithm":1232,"Ġnetw":1233,"Ġcode":1234,"Write":1235,"ov":1236,"-d":1237,"oura":1238,"ĠHowever":1239,"uture":1240,"view":1241,"Ġindu":1242,"Ġproducts":1243,"ected":1244,"ertain":1245,";\\":1246,"ĠAs":1247,"pr":1248,"aste":1249,"Ġoper":1250,"Ġ$":1251,"avi":1252,"self":1253,"Ġ<":1254,"Ġindust":1255,"Ġgu":1256,"Ġothers":1257,"Ex":1258,"ian":1259,"Ġ\"\\\"":1260,"-f":1261,"nces":1262,"Ġfil":1263,"Ġrespons":1264,"rol":1265,"Ġcap":1266,"Ġbefore":1267,"vern":1268,"Ġcomplex":1269,"lus":1270,"ribut":1271,"ats":1272,"Ġpositive":1273,"oh":1274,"Ġlo":1275,"Ġgroup":1276,"Ġfound":1277,"ee":1278,"ogn":1279,"Ġsw":1280,"Ġindividuals":1281,"Ġpract":1282,"Ġenc":1283,"Ġshare":1284,"raph":1285,"Ġrange":1286,"Ġsun":1287,"\\t":1288,"Ġproviding":1289,"icle":1290,"Ġdem":1291,"Ġplace":1292,"Ġaud":1293,"joy":1294,"Ġmust":1295,"els":1296,"ery":1297,"One":1298,"Ġfamily":1299,"Ġfuture":1300,"less":1301,"rent":1302,"Ġproblem":1303,"Ġessential":1304,"rodu":1305,"ired":1306,"Ġreducing":1307,"ism":1308,"Ġwarm":1309,"ray":1310,"Ġability":1311,"Ġstrong":1312,"Ġalways":1313,"Ġresources":1314,"Ġbenefits":1315,"Ġstrateg":1316,"Ġinvolves":1317,"Ġassist":1318,"erest":1319,"nA":1320,"ression":1321,"Ġ[":1322,"ilities":1323,"Ġsteps":1324,"verall":1325,"Ġshow":1326,"obal":1327,"\\nF":1328,"Ġland":1329,"ĠHere":1330,"Ġbusinesses":1331,"ĠEn":1332,"pportun":1333,"Ġmeas":1334,"Ġreturn":1335,"Ġdig":1336,"Ġhist":1337,"yth":1338,"Ġcent":1339,"Ġable":1340,"Ġwithout":1341,"yc":1342,"plain":1343,"Ġrelations":1344,"Ġservices":1345,"-c":1346,"Ġtest":1347,"arth":1348,"Ġcommunication":1349,"Ġintern":1350,"new":1351,"Ġsit":1352,"Ġinvest":1353,"Ġcaus":1354,"Ġunt":1355,"Ġfriends":1356,"Ġchanges":1357,"cri":1358,"dit":1359,"ĠBy":1360,"ĠYou":1361,"Ġmeans":1362,"Ġrese":1363,"ool":1364,"ted":1365,"elligence":1366,"ains":1367,"pping":1368,"Ġbel":1369,"Ġrepresent":1370,"Ġhapp":1371,"Ġser":1372,"Ġperformance":1373,"Ġopportun":1374,"Ġtemper":1375,"ĠShe":1376,"Ġfu":1377,"ix":1378,"bot":1379,"Ġwrit":1380,"Ġbehavi":1381,"Ġproject":1382,"ĠWith":1383,"ivers":1384,"day":1385,"Ġphysical":1386,"izing":1387,"Ġactiv":1388,"Ġwithin":1389,"Ġinterest":1390,"olution":1391,"wards":1392,"ffic":1393,"Ġquick":1394,"Ġpublic":1395,"Ġgrowth":1396,"Ġcho":1397,"Ġrelationship":1398,"Ġuntil":1399,"Ġhelps":1400,"Ġstudents":1401,"Ġfiel":1402,"imes":1403,"ulation":1404,"ibility":1405,"elf":1406,"Ġful":1407,"Ġsub":1408,"ank":1409,"ides":1410,"Ġskills":1411,"Ġclimate":1412,"Given":1413,"Ġpar":1414,"Ġclear":1415,"irt":1416,"Name":1417,"Ġpresent":1418,"Ġtri":1419,"Ġchalleng":1420,"ream":1421,"Ġlay":1422,"Ġmarketing":1423,"Ġsummary":1424,"Ġchild":1425,"Ġsaf":1426,"Ġsure":1427,"Ġsame":1428,"Ġmu":1429,"Ġemail":1430,"bon":1431,"Ġsomet":1432,"```\\":1433,"Ġcurrent":1434,"amp":1435,"ences":1436,"ĠRe":1437,"Ġtransport":1438,"me":1439,"-p":1440,"action":1441,"ĠEx":1442,"Ġyears":1443,"Ġcomb":1444,"hor":1445,"anced":1446,"ty":1447,"Ġlove":1448,"Ġgreen":1449,"Ġpopular":1450,"Ġless":1451,"Ġdra":1452,"Ġcontrol":1453,"Ġaff":1454,"Ġconsum":1455,"Ġgame":1456,"ental":1457,"ights":1458,"arget":1459,"omes":1460,"ox":1461,"icult":1462,"erc":1463,"Ġgoals":1464,"ancial":1465,"tle":1466,"Ġgovern":1467,"Ġnumbers":1468,"Ġfive":1469,"Ġstand":1470,"Ġsearch":1471,"Ġefficient":1472,"Ġwal":1473,"Ġname":1474,"ath":1475,"Ġheart":1476,"Ġduring":1477,"rect":1478,"Ġoverall":1479,"ython":1480,"Ġallows":1481,"Ġcity":1482,"ave":1483,"vant":1484,"aterial":1485,"Ġwide":1486,"Ġmus":1487,"ificial":1488,"Ġhard":1489,"ĠTh":1490,"oose":1491,"Ġglobal":1492,"aj":1493,"Ġter":1494,"Ġdifficult":1495,"Ġline":1496,"ĠAl":1497,"care":1498,"ived":1499,"Ġregular":1500,"Ġgr":1501,"),":1502,"lement":1503,"Ġhim":1504,"Ġunique":1505,"Ġenjoy":1506,"Ġmeaning":1507,"Ġopen":1508,"Ġi":1509,"abor":1510,"Ġarea":1511,"Ġitems":1512,"Ġclean":1513,"ditionally":1514,"oid":1515,"ĠWe":1516,"Ġbeaut":1517,"Ġmeet":1518,"iple":1519,"Ġstatement":1520,"Ġagain":1521,"ysis":1522,"Ġfac":1523,"Ġsources":1524,"Ġbody":1525,"Ġalgorithms":1526,"Ġaudience":1527,"Ġwant":1528,"Ġlog":1529,"Ġmaintain":1530,"Ġactivities":1531,"Ġmove":1532,"Ġcult":1533,"oney":1534,"Ġtarget":1535,"\\nB":1536,"Ġmaterial":1537,"Ġcreating":1538,"Ġstructure":1539,"atform":1540,"ext":1541,"Ġexperien":1542,"Ġvalues":1543,"ead":1544,"ohn":1545,"Ġhealthy":1546,"ross":1547,"Ġinteg":1548,"Ġresearch":1549,"atch":1550,"ooking":1551,"Ġrole":1552,"Ġprovides":1553,"iety":1554,"ists":1555,"Ġfinancial":1556,"ories":1557,"dent":1558,"Ġer":1559,"Ġarticle":1560,"Ġelements":1561,"Ġaddress":1562,"Ġconn":1563,"ĠUse":1564,"mp":1565,"Ġeasy":1566,"Ġneg":1567,"Ġcolor":1568,"Ġcalcul":1569,"Explain":1570,"ĠPl":1571,"pect":1572,"ince":1573,"ale":1574,"Ġrisk":1575,"curity":1576,"ert":1577,"Ġfeed":1578,"Ġevent":1579,"vers":1580,"ples":1581,"Ġlevels":1582,"Ġbi":1583,"Ġstay":1584,"Ġplatform":1585,"Ġbreak":1586,"back":1587,"Ġsat":1588,"\\nOverall":1589,"Ġeducation":1590,"\\nC":1591,"Ġcarbon":1592,"--------":1593,"ape":1594,"Ġprevent":1595,"Ġaddition":1596,"Ġstress":1597,"ral":1598,"ource":1599,"rus":1600,"Ġcome":1601,"Ġrecogn":1602,"ĠUnited":1603,"Ġproper":1604,"Ġpoll":1605,"dentify":1606,"Ġunderstanding":1607,"Ġdecisions":1608,"ict":1609,"Ġdire":1610,"Ġbehavior":1611,"Ġ*":1612,"\\nI":1613,"Ġmess":1614,"Ġanimals":1615,"Ġsl":1616,"Ġwind":1617,"Ġbas":1618,"Ġpain":1619,"Ġleading":1620,"ern":1621,"ger":1622,"Ġpres":1623,"Ġthough":1624,"Ġinteract":1625,"yle":1626,"Ġdoes":1627,"Ġhead":1628,"Ġintelligence":1629,"orts":1630,"Ġbecome":1631,"Ġrun":1632,"aring":1633,"Ġimplement":1634,"Ġaction":1635,"oot":1636,"terns":1637,"Ġprotect":1638,"eric":1639,"Ġflow":1640,"Ġemot":1641,"cessary":1642,"urate":1643,"Ġsuggest":1644,"Ġprogram":1645,"Ġphr":1646,"Ġhealthcare":1647,"ention":1648,"Ġsust":1649,"Ġwhy":1650,"Ġaccurate":1651,"lu":1652,"Ġhig":1653,"Ġreach":1654,"Ġallowing":1655,"Ġtravel":1656,"Ġrequire":1657,"Ġareas":1658,"Ġdeep":1659,"He":1660,"Ġfew":1661,"Ġself":1662,"oun":1663,"Ġ#":1664,"osp":1665,"str":1666,"Ġminut":1667,"Ġdecision":1668,"ĠThere":1669,"ances":1670,"Ġquality":1671,"Ġavail":1672,"Ġspace":1673,"Ġsomething":1674,"Ġweb":1675,"Ġpatterns":1676,"Ġmot":1677,"oring":1678,"isf":1679,"Ġanother":1680,"Ġaccount":1681,"\\nW":1682,"uss":1683,"Ġmaj":1684,"uation":1685,"Ġsustain":1686,"Ġautom":1687,"iques":1688,"issions":1689,"verse":1690,"Ġconcept":1691,"Ġsecurity":1692,"Ġthose":1693,"Ġprofess":1694,"Ġshort":1695,"Ġnight":1696,"ength":1697,"apt":1698,"ex":1699,"ĠAdditionally":1700,"Ġtaking":1701,"Ġtoo":1702,"agn":1703,"Ġsimple":1704,"lusion":1705,"iency":1706,"ash":1707,"ours":1708,"Ġpa":1709,"Ġlit":1710,"ĠSp":1711,"iting":1712,"Ġdon":1713,"Ġlim":1714,"lish":1715,"mat":1716,"aves":1717,"ledge":1718,"ditional":1719,"inc":1720,"Ġevents":1721,"Ġoffer":1722,"thing":1723,"Ġworking":1724,"Ġanalysis":1725,"Ġachieve":1726,"Ġpie":1727,"Ġbook":1728,"Ġfre":1729,"Ġmuch":1730,"oon":1731,"Ġtry":1732,"esp":1733,"Ġwaste":1734,"face":1735,"Ġear":1736,"Ġfru":1737,"Ġtransportation":1738,"chool":1739,"Ġtechniques":1740,"Ġprogramm":1741,"ĠEarth":1742,"Ġpredict":1743,"Ġnever":1744,"ws":1745,"ument":1746,"imately":1747,"ared":1748,"Ġparticular":1749,"Ġtowards":1750,"Ġeconomic":1751,"Ġincreasing":1752,"Ġfast":1753,"iment":1754,"Ġnetwork":1755,"Ġcorrect":1756,"Ġmight":1757,"Ġoc":1758,"Ġbecause":1759,"ĠWh":1760,"az":1761,"play":1762,"Ġresults":1763,"Ġmanagement":1764,"Ġpurch":1765,"Ġsound":1766,"Ġpast":1767,"Ġtraining":1768,"____":1769,"ope":1770,"Ġengage":1771,"ourage":1772,"Ġsense":1773,"Ġfree":1774,"Ġpref":1775,"ees":1776,"Ġcountries":1777,"ney":1778,"anies":1779,"Ġafter":1780,"Ġmind":1781,"Ġexc":1782,"ĠOnce":1783,"ĠĠĠĠĠĠĠĠĠĠĠ":1784,"Ġcomplete":1785,"Ġimm":1786,"Ġest":1787,"Ġgenerate":1788,"verb":1789,"ĠDe":1790,"'m":1791,"Ġtools":1792,"redients":1793,"Ġmajor":1794,"ently":1795,"Ġcontribut":1796,"leep":1797,"Ġpoints":1798,"ditions":1799,"Ġfactors":1800,"Ġel":1801,"Ġnext":1802,"ium":1803,"oud":1804,"Ġcru":1805,"Ġreas":1806,"riate":1807,"ĠInd":1808,"Ġpromot":1809,"Ġhistory":1810,"Ġjour":1811,"Ġdue":1812,"Con":1813,"Ġveget":1814,"ency":1815,"ĠAmeric":1816,"Ġfra":1817,"Ġdifference":1818,"oard":1819,"lex":1820,"Ġequation":1821,"irtual":1822,"Ġcup":1823,"Ġforest":1824,"Ġnegative":1825,"Ġsecon":1826,"ones":1827,"Ġnature":1828,"Ġuses":1829,"ah":1830,"por":1831,"Ġsec":1832,"ording":1833,"Ġlast":1834,"ĠSome":1835,"Ġissues":1836,"Ġscient":1837,"Ġprint":1838,"ĠStates":1839,"over":1840,"Ġsatisf":1841,"Ġdevices":1842,"Ġdise":1843,"Ġtemperature":1844,"Ġfeedback":1845,"Ġnecessary":1846,"Ġemissions":1847,"mb":1848,"Ġlow":1849,"for":1850,"tal":1851,"Ġchallenges":1852,"Ġarray":1853,"Ġside":1854,"Ġengine":1855,"Ġboo":1856,"ata":1857,"Ġbelie":1858,"-m":1859,"Ġmultiple":1860,"Ġsing":1861,"Ġgovernment":1862,"ames":1863,"ified":1864,"Ġminutes":1865,"Ġsuccessful":1866,"Ġmoney":1867,"Ġquickly":1868,"Ġbir":1869,"Ġtypically":1870,"Ġpost":1871,"Ġprep":1872,"Ġknowledge":1873,"pped":1874,"actions":1875,"Ġmethods":1876,"Ġoptim":1877,"\\nP":1878,"Ġoutput":1879,"Ġfield":1880,"Ġtable":1881,"Ġbal":1882,"Ġcoll":1883,"Ġcharacters":1884,"volution":1885,"ords":1886,"ilar":1887,"ification":1888,"ane":1889,"Ġcell":1890,"Ġmil":1891,"ĠWhat":1892,"Ġsqu":1893,"Ġlives":1894,"ĠAr":1895,"Ġphrase":1896,"Ġnut":1897,"Ġdigital":1898,"Ġinternet":1899,"lass":1900,"ura":1901,"ommend":1902,"Ġtreat":1903,"Ġapprop":1904,"resh":1905,"urther":1906,"ĠOne":1907,"Ġvisual":1908,"ategor":1909,"Ġapproach":1910,"Ġcertain":1911,"Ġsho":1912,"val":1913,"Ġtask":1914,"ires":1915,"Ġappropriate":1916,"Ġvie":1917,"Ġdesigned":1918,"pose":1919,"**:":1920,"fort":1921,"Ġ|\\":1922,"Ġapplications":1923,"Ġpay":1924,"Ġnow":1925,"Ġheat":1926,"Ġindustry":1927,"pre":1928,"Ġeffectively":1929,"Ġpopulation":1930,"Ġopportunities":1931," \\","Ġens ure","os s","ub lic","Ġit em","H ere","in ation","Ġde f","Des cribe","ion al","rou p","Ġcon f","Ġneed s","Ġcharact er","Ġvari ous","Ġle t","Ġapp lic","a ut","Ġj ob","ell ig","ĠC on","Ġb est","Ġf ore","Ġam ount","ro p","Ġbu ild","iqu e","ag ing","Ġem ploy","Ġre st","a ir","W hat","Ġto get","Ġway s","Ġident ify","Ġtoget her","Ġre al","Ġus ers","Ġme an","as ing","ĠA m","Ġed uc","Ġalgorith m","Ġn etw","Ġc ode","W rite","o v","- d","ou ra","ĠHow ever","ut ure","vie w","Ġin du","Ġproduct s","ect ed","er tain","; \\","ĠA s","p r","ast e","Ġo per","Ġ $","av i","sel f","Ġ <","Ġindu st","Ġg u","Ġother s","E x","i an","Ġ\" \\\"","- f","n ces","Ġf il","Ġresp ons","ro l","Ġc ap","Ġbe fore","ver n","Ġcomple x","l us","rib ut","at s","Ġpos itive","o h","Ġl o","Ġg roup","Ġf ound","e e","og n","Ġs w","Ġindividual s","Ġp ract","Ġen c","Ġsh are","ra ph","Ġr ange","Ġsu n","\\ t","Ġprovid ing","ic le","Ġde m","Ġpl ace","Ġa ud","j oy","Ġm ust","el s","er y","O ne","Ġfam ily","Ġf uture","l ess","re nt","Ġproble m","Ġess ential","ro du","i red","Ġredu cing","is m","Ġw arm","ra y","Ġab ility","Ġstr ong","Ġal ways","Ġres ources","Ġbenef its","Ġstr ateg","Ġinvol ves","Ġass ist","ere st","n A","ress ion","Ġ [","il ities","Ġstep s","ver all","Ġsh ow","ob al","\\n F","Ġl and","ĠH ere","Ġbusiness es","ĠE n","pport un","Ġme as","Ġret urn","Ġd ig","Ġh ist","y th","Ġc ent","Ġab le","Ġwith out","y c","pl ain","Ġrel ations","Ġserv ices","- c","Ġt est","ar th","Ġcommunic ation","Ġinter n","ne w","Ġs it","Ġinv est","Ġca us","Ġu nt","Ġfriend s","Ġchang es","c ri","d it","ĠB y","ĠY ou","Ġme ans","Ġre se","o ol","t ed","ellig ence","ain s","pp ing","Ġbe l","Ġrep resent","Ġha pp","Ġs er","Ġperform ance","Ġo pportun","Ġtem per","ĠS he","Ġf u","i x","b ot","Ġw rit","Ġbeh avi","Ġpro ject","ĠW ith","iv ers","d ay","Ġphys ical","iz ing","Ġact iv","Ġwith in","Ġint erest","ol ution","ward s","ff ic","Ġqu ick","Ġp ublic","Ġgrow th","Ġch o","Ġrelations hip","Ġunt il","Ġhelp s","Ġstud ents","Ġfi el","im es","ul ation","ib ility","el f","Ġf ul","Ġsu b","an k","id es","Ġsk ills","Ġcl imate","G iven","Ġp ar","Ġcle ar","ir t","N ame","Ġp resent","Ġt ri","Ġchall eng","re am","Ġl ay","Ġmarket ing","Ġsumm ary","Ġch ild","Ġsa f","Ġsu re","Ġs ame","Ġm u","Ġem ail","b on","Ġs omet","``` \\","Ġcur rent","am p","en ces","ĠR e","Ġtrans port","m e","- p","a ction","ĠE x","Ġyear s","Ġcom b","h or","anc ed","t y","Ġl ove","Ġg reen","Ġpop ular","Ġl ess","Ġd ra","Ġcont rol","Ġa ff","Ġcons um","Ġg ame","ent al","ight s","ar get","om es","o x","ic ult","er c","Ġgo als","anc ial","t le","Ġgo vern","Ġnum bers","Ġf ive","Ġst and","Ġse arch","Ġeffic ient","Ġw al","Ġn ame","at h","Ġhe art","Ġd uring","re ct","Ġover all","yth on","Ġallow s","Ġc ity","a ve","v ant","ater ial","Ġw ide","Ġm us","ific ial","Ġh ard","ĠT h","oo se","Ġgl obal","a j","Ġt er","Ġdiff icult","Ġl ine","ĠA l","c are","iv ed","Ġreg ular","Ġg r",") ,","le ment","Ġh im","Ġun ique","Ġen joy","Ġmean ing","Ġop en","Ġ i","ab or","Ġare a","Ġitem s","Ġcle an","dition ally","o id","ĠW e","Ġbe aut","Ġme et","ip le","Ġstate ment","Ġag ain","ys is","Ġf ac","Ġs ources","Ġb ody","Ġalgorith ms","Ġaud ience","Ġw ant","Ġl og","Ġmain tain","Ġactiv ities","Ġmo ve","Ġc ult","one y","Ġt arget","\\n B","Ġm aterial","Ġcreat ing","Ġstru cture","at form","e xt","Ġexper ien","Ġval ues","e ad","oh n","Ġhealth y","ro ss","Ġint eg","Ġrese arch","at ch","oo king","Ġro le","Ġprovid es","i ety","ist s","Ġfin ancial","or ies","d ent","Ġ er","Ġart icle","Ġele ments","Ġadd ress","Ġcon n","ĠU se","m p","Ġeas y","Ġne g","Ġcol or","Ġcal cul","Ex plain","ĠP l","p ect","in ce","al e","Ġris k","cur ity","er t","Ġfe ed","Ġev ent","v ers","pl es","Ġlevel s","Ġb i","Ġst ay","Ġpl atform","Ġbre ak","b ack","Ġs at","\\nO verall","Ġeduc ation","\\n C","Ġcar bon","---- ----","ap e","Ġpre vent","Ġadd ition","Ġst ress","r al","our ce","ru s","Ġcom e","Ġrec ogn","ĠUn ited","Ġpro per","Ġpol l","dent ify","Ġunderstand ing","Ġdecis ions","i ct","Ġd ire","Ġbehavi or","Ġ *","\\n I","Ġm ess","Ġanim als","Ġs l","Ġw ind","Ġb as","Ġp ain","Ġlead ing","er n","g er","Ġp res","Ġth ough","Ġinter act","y le","Ġdo es","Ġhe ad","Ġint elligence","ort s","Ġbec ome","Ġru n","ar ing","Ġimp lement","Ġa ction","o ot","ter ns","Ġprot ect","er ic","Ġf low","Ġem ot","cess ary","ur ate","Ġsu ggest","Ġprogra m","Ġph r","Ġhealth care","ent ion","Ġsu st","Ġwh y","Ġacc urate","l u","Ġh ig","Ġre ach","Ġallow ing","Ġtra vel","Ġrequ ire","Ġare as","Ġde ep","H e","Ġfe w","Ġs elf","ou n","Ġ #","os p","st r","Ġmin ut","Ġdecis ion","ĠThe re","an ces","Ġqu ality","Ġav ail","Ġsp ace","Ġsomet hing","Ġwe b","Ġpat terns","Ġm ot","or ing","is f","Ġan other","Ġacc ount","\\n W","us s","Ġm aj","u ation","Ġsust ain","Ġaut om","iqu es","iss ions","ver se","Ġcon cept","Ġse curity","Ġth ose","Ġprof ess","Ġsh ort","Ġn ight","eng th","a pt","e x","ĠAd ditionally","Ġt aking","Ġto o","ag n","Ġsim ple","lus ion","ien cy","as h","our s","Ġp a","Ġl it","ĠS p","it ing","Ġd on","Ġl im","l ish","m at","av es","led ge","dition al","in c","Ġev ents","Ġoff er","th ing","Ġwor king","Ġanal ysis","Ġachie ve","Ġp ie","Ġb ook","Ġf re","Ġmu ch","o on","Ġt ry","es p","Ġw aste","f ace","Ġe ar","Ġf ru","Ġtransport ation","ch ool","Ġtechn iques","Ġprogra mm","ĠE arth","Ġpredi ct","Ġne ver","w s","u ment","imate ly","are d","Ġpartic ular","Ġto wards","Ġeconom ic","Ġincre asing","Ġf ast","im ent","Ġnetw ork","Ġcor rect","Ġm ight","Ġo c","Ġbec ause","ĠW h","a z","pl ay","Ġresult s","Ġmanage ment","Ġpur ch","Ġs ound","Ġp ast","Ġtra ining","__ __","op e","Ġeng age","oura ge","Ġs ense","Ġf ree","Ġpre f","e es","Ġcount ries","ne y","an ies","Ġa fter","Ġm ind","Ġex c","ĠO nce","ĠĠĠĠ ĠĠĠĠĠĠĠ","Ġcomple te","Ġim m","Ġ est","Ġg enerate","ver b","ĠD e","' m","Ġtool s","redi ents","Ġmaj or","ent ly","Ġcont ribut","le ep","Ġpoint s","dit ions","Ġfact ors","Ġe l","Ġne xt","i um","ou d","Ġc ru","Ġre as","ri ate","ĠI nd","Ġprom ot","Ġhist ory","Ġj our","Ġd ue","C on","Ġve get","en cy","ĠAm eric","Ġf ra","Ġdiffere nce","o ard","le x","Ġequ ation","irt ual","Ġc up","Ġfore st","Ġneg ative","Ġse con","on es","Ġn ature","Ġus es","a h","p or","Ġse c","ord ing","Ġl ast","ĠS ome","Ġiss ues","Ġsc ient","Ġpr int","ĠSt ates","o ver","Ġsat isf","Ġdev ices","Ġdis e","Ġtemper ature","Ġfeed back","Ġne cessary","Ġem issions","m b","Ġl ow","f or","t al","Ġchalleng es","Ġar ray","Ġs ide","Ġeng ine","Ġb oo","at a","Ġbel ie","- m","Ġmult iple","Ġs ing","Ġgovern ment","am es","if ied","Ġminut es","Ġsuccess ful","Ġm oney","Ġquick ly","Ġb ir","Ġtyp ically","Ġp ost","Ġpre p","Ġknow ledge","pp ed","a ctions","Ġmethod s","Ġopt im","\\n P","Ġout put","Ġfiel d","Ġt able","Ġb al","Ġcol l","Ġcharact ers","v olution","or ds","il ar","ific ation","an e","Ġc ell","Ġm il","ĠW hat","Ġs qu","Ġl ives","ĠA r","Ġphr ase","Ġn ut","Ġdig ital","Ġintern et","l ass","u ra","omm end","Ġt reat","Ġappro p","res h","ur ther","ĠO ne","Ġvis ual","ate gor","Ġappro ach","Ġc ertain","Ġsh o","v al","Ġtas k","i res","Ġapprop riate","Ġv ie","Ġdesign ed","p ose","** :","f ort","Ġ| \\","Ġapplic ations","Ġp ay","Ġn ow","Ġhe at","Ġindust ry","p re","Ġeffective ly","Ġpop ulation","Ġopportun ities","< /","ĠT o","Ġup d","Ġinclud es","ĠE ng","Ġtyp es","Ġup on","Ġconsid er","le t","Ġg en","og raph","pl ace","Ġt imes","Ġar g","C omp","ĠG o","Ġre ce","Ġchild ren","Ġtra ck","Ġsome one","w ord","Ġyou ng","Ġcon ditions","Ġtra ditional","Ġmodel s","I dentify","Ġc amp","Ġm akes","ist ic","Ġar r","Ġc ard","ut ions","l t","Ġo ld","Ġide as","Ġe y","Ġt ree","Ġiss ue","Ġh arm","Ġavail able","Ġc r","Ġpower ful","n ov","Ġmo vie","Ġwe ather","Ġsk y","Ġquest ions","e et","Ġact ivity","Ġbra nd","is hed","Ġanaly ze","ĠS h","Ġen h","av or","Ġbe g","Ġs chool","i ate","Ġeas ier","Ġinf lu","Ġn on","Ġstud y","Ġl ook","Ġsol ution","Ġle g","Ġcon st","H ow","Ġcomp et"]}} diff --git a/training/tests/assets/tiny_bpe_vocab.json b/training/tests/assets/tiny_bpe_vocab.json new file mode 100644 index 0000000..a29ae94 --- /dev/null +++ b/training/tests/assets/tiny_bpe_vocab.json @@ -0,0 +1 @@ +{"0":15,"1":16,"2":17,"3":18,"4":19,"5":20,"6":21,"7":22,"8":23,"9":24,"!":0,"\"":1,"#":2,"$":3,"%":4,"&":5,"'":6,"(":7,")":8,"*":9,"+":10,",":11,"-":12,".":13,"/":14,":":25,";":26,"<":27,"=":28,">":29,"?":30,"@":31,"A":32,"B":33,"C":34,"D":35,"E":36,"F":37,"G":38,"H":39,"I":40,"J":41,"K":42,"L":43,"M":44,"N":45,"O":46,"P":47,"Q":48,"R":49,"S":50,"T":51,"U":52,"V":53,"W":54,"X":55,"Y":56,"Z":57,"[":58,"\\":59,"]":60,"^":61,"_":62,"`":63,"a":64,"b":65,"c":66,"d":67,"e":68,"f":69,"g":70,"h":71,"i":72,"j":73,"k":74,"l":75,"m":76,"n":77,"o":78,"p":79,"q":80,"r":81,"s":82,"t":83,"u":84,"v":85,"w":86,"x":87,"y":88,"z":89,"{":90,"|":91,"}":92,"~":93,"Ċ":94,"Ġ":95,"ĠĠ":96,"Ġt":97,"Ġa":98,"in":99,"he":100,"re":101,"on":102,"Ġthe":103,"Ġs":104,"er":105,"at":106,"Ġc":107,"ĠĠĠĠ":108,"en":109,"Ġo":110,"Ġ\"":111,"nd":112,"es":113,"ing":114,"ĠĠĠ":115,"it":116,"Ġp":117,"or":118,"ou":119,"Ġand":120,"Ġw":121,"is":122,"Ġf":123,"an":124,"ion":125,"al":126,"Ġb":127,"Ġto":128,"Ġm":129,"Ġin":130,"Ġof":131,"le":132,"ct":133,"ar":134,"ut":135,"Ġd":136,"st":137,"ed":138,"ĠĠĠĠĠĠĠ":139,"ic":140,"\":":141,",Ċ":142,"ro":143,"ent":144,"\\n":145,"Ġe":146,"put":147,"om":148,"Ġre":149,"as":150,"ve":151,"Ġh":152,"Ġth":153,"\",Ċ":154,"Ġl":155,"Ġis":156,"et":157,"ce":158,"Ġn":159,".\\":160,"im":161,"il":162,"Ġg":163,"Ġu":164,"ction":165,"ru":166,"ation":167,"ol":168,"ch":169,"ĠT":170,"Ġfor":171,"out":172,"ra":173,"ow":174,"id":175,"ly":176,"Ġst":177,"Ġbe":178,"Ġy":179,"Ġpro":180,"ig":181,"se":182,"ate":183,"Ġthat":184,"ith":185,"ir":186,"ur":187,"ot":188,"Ġor":189,"Ġon":190,"Ġyou":191,"ers":192,"stru":193,"Ġan":194,"if":195,"ul":196,"struction":197,"Ġ{":198,"Ġ}":199,"Ġcan":200,"input":201,"output":202,"instruction":203,"Ġ{Ċ":204,"Ġ},Ċ":205,"\"Ċ":206,"Ġhe":207,"Ġcon":208,"Ġit":209,"ay":210,"ess":211,"Ġwith":212,"ver":213,"el":214,"Ġas":215,"am":216,"ĠA":217,"ge":218,"Ġsu":219,"iv":220,".\",Ċ":221,"Ġcom":222,"ĠI":223,"ment":224,"ak":225,"Ġal":226,"\\\"":227,".\"Ċ":228,"ive":229,"Ġare":230,"ab":231,"ad":232,"Ġmo":233,"Ġex":234,"Ġv":235,"ĠS":236,"res":237,"pp":238,"qu":239,"Ġde":240,"Ġwh":241,"ity":242,"Ġen":243,"ĠThe":244,"her":245,"ld":246,"ri":247,"ter":248,"ant":249,"ĠC":250,"ist":251,"Ġ\"\",Ċ":252,"um":253,"Ġus":254,"Ġne":255,"ain":256,"th":257,"ect":258,"Ġle":259,"op":260,"em":261,"ies":262,"Ġch":263,"Ġim":264,"du":265,"od":266,"ort":267,"nt":268,"est":269,"igh":270,"ere":271,"Ġha":272,"us":273,"ure":274,"ial":275,"oc":276,"Ġwor":277,"Ġtheir":278,"ac":279,"ence":280,"iz":281,"Ġyour":282,"os":283,"Ġimp":284,"ud":285,"Ġby":286,"Ġse":287,"ine":288,"ould":289,"low":290,"ill":291,"age":292,"rom":293,"Ġsp":294,"ĠP":295,"Ġsh":296,"ust":297,"The":298,"un":299,"'s":300,"Ġinc":301,"ide":302,"pl":303,"ight":304,"og":305,"Ġpl":306,"pt":307,"are":308,"Ġte":309,"Ġint":310,"Ġ\\":311,"his":312,"Ġr":313,"ake":314,"per":315,"orm":316,"ag":317,"ff":318,"ĠE":319,"art":320,"Ġk":321,"end":322,"ĠM":323,"Ġwe":324,"ĠB":325,"Ġad":326,"cess":327,"rou":328,"ical":329,"all":330,"able":331,"Ġfrom":332,"and":333,"ĠH":334,"Ġab":335,"act":336,"Ġcomp":337,"ome":338,"ach":339,"ĠThis":340,"Ġhave":341,"form":342,"Ġ\\\"":343,"ast":344,"Ġat":345,"ĠW":346,"Ġres":347,"Ġdat":348,":\\":349,"ther":350,"ions":351,"ore":352,"Ġ(":353,"Ġcont":354,"our":355,"ep":356,"ĠF":357,"Ġac":358,"ance":359,"ĠR":360,"gh":361,"Ġme":362,"ces":363,"Ġwas":364,"ind":365,"vel":366,"ations":367,"Ġhel":368,"Ġmore":369,"ult":370,"ĠD":371,"reat":372,"ign":373,"Ġhelp":374,"ime":375,"ard":376,"Ġcl":377,"Ġapp":378,"ans":379,"ie":380,"Ġdata":381,"ich":382,"ang":383,"ous":384,"ell":385,"ks":386,"ase":387,"ice":388,"ip":389,"ite":390,"Ġsuch":391,"Ġfe":392,"Ġwhe":393,"ib":394,"Ġother":395,"Ġthis":396,"ass":397,"ual":398,"ile":399,"ne":400,"red":401,"Ġhas":402,"oo":403,"ress":404,"ific":405,"ning":406,"Ġ=":407,"Ġup":408,"Ġman":409,"Ġar":410,"ong":411,"ec":412,"Ġtra":413,"av":414,"Ġwhich":415,"Ġgo":416,"Ġprov":417,"Ġdis":418,"**":419,"so":420,"ĠG":421,"one":422,"Ġem":423,"Ġnot":424,"ue":425,"ĠO":426,"Ġj":427,"ace":428,"Ġthey":429,"ame":430,"Ġqu":431,"ĠL":432,"iff":433,"Ġfol":434,"ary":435,"ated":436,"ustom":437,"ition":438,"Ġits":439,"Ġsy":440,"ke":441,"ack":442,"ry":443,"--":444,"Ġtime":445,"Ġdes":446,"Ġnew":447,"ents":448,"ount":449,"Ġfollow":450,"Ġalso":451,"Ġcomm":452,"Ġout":453,"Ġeff":454,"Ġdiff":455,"iven":456,"ap":457,"Ġsent":458,"\\u":459,"Ġso":460,"Ġprodu":461,"Ġuse":462,"Ġsc":463,"Ġ-":464,"Ġun":465,"lud":466,"ĠIt":467,"ener":468,"king":469,"Ġev":470,"Ġabout":471,"Ġthem":472,"ĠU":473,"Ġcustom":474,"Ġro":475,"Ġinclud":476,"les":477,"etw":478,"stem":479,"xt":480,"Ġinto":481,"Ġper":482,"ĠIn":483,"ĠN":484,"Ġwill":485,"Ġlear":486,"ber":487,"Ġall":488,"Ġpe":489,"ds":490,"Ġtw":491,"aking":492,"ark":493,"ful":494,"Ġmake":495,"chn":496,"erv":497,"ost":498,"rough":499,"Ġone":500,"Ġinter":501,"ities":502,"ail":503,"ike":504,"ree":505,"ple":506,"alth":507,"Ġused":508,"ors":509,"Ġover":510,"ility":511,"ments":512,"ange":513,"Ġway":514,"ory":515,"Ġcol":516,"Ġpr":517,"Ġcould":518,"Ġnum":519,"reate":520,"int":521,"Ġredu":522,"erson":523,"Ġrec":524,"Ġher":525,"Ġneed":526,"ms":527,"ater":528,"oy":529,"Ġsystem":530,"Ġinform":531,"Ġtwo":532,"Ġtechn":533,"Ġsentence":534,"ience":535,"ize":536,"get":537,"Ġdiffere":538,"ood":539,"rib":540,"Ġbut":541,"Ġfollowing":542,"ased":543,"olog":544,"erg":545,"led":546,"ures":547,"In":548,"ear":549,"Ġph":550,"own":551,"Ġpre":552,"Ġwould":553,"Ġusing":554,"Ġcons":555,"Ġwork":556,"Ġmod":557,"ating":558,"ia":559,"ire":560,"Ġpos":561,"ient":562,"ob":563,"ject":564,"Ġinv":565,"ons":566,"Ġdo":567,"ular":568,"Ġdec":569,"Ġhealth":570,"Ġimpro":571,"Ġany":572,"Ġthrough":573,"yp":574,"row":575,"velop":576,"Ġprocess":577,"Ġtr":578,"lic":579,"very":580,"als":581,"ify":582,"``":583,"ari":584,"Ġstr":585,"Ġimport":586,"Ġlike":587,"Ġproduct":588,"Ġsome":589,"ph":590,"ential":591,"Ġam":592,"ates":593,"Ġacc":594,"ens":595,"ns":596,"Ġsm":597,"Ġind":598,"een":599,"Ġexper":600,"lect":601,"Ġval":602,"Ġrel":603,"its":604,"Ġinformation":605,"ings":606,"ĠJ":607,"ople":608,"iness":609,"Ġgiven":610,"mm":611,"ices":612,"Ġpart":613,"ild":614,"ys":615,"Ġour":616,"nder":617,"Ġperson":618,"ally":619,"Ġke":620,"etween":621,"ft":622,"oth":623,"Ġspec":624,"Ġbetween":625,"ergy":626,"ĠAI":627,"Ġwho":628,"Ġmay":629,"ef":630,"ative":631,"ise":632,"Ġlist":633,"Ġkn":634,"Ġadd":635,",\\":636,"ord":637,"ics":638,"Ġpeople":639,"ĠSt":640,"Ġhis":641,"Ġexp":642,"ible":643,"Ġthere":644,"Ġserv":645,"Ġincre":646,"Ġdevelop":647,"ound":648,"ower":649,"Ġtrans":650,"bs":651,"Ġenergy":652,"Ġoff":653,"Ġbus":654,"Ġwhile":655,"ose":656,"Ġact":657,"Ġexam":658,"Ġlearning":659,"ctions":660,"con":661,"gor":662,"gan":663,"ution":664,"round":665,"pport":666,"Ġhow":667,"Ġbl":668,"Ġmed":669,"anc":670,"Ġtyp":671,"Ġra":672,"Ġcar":673,"ife":674,"Ġworld":675,"Ġvari":676,"Ġrep":677,"au":678,"Ġsoc":679,"Ġprovid":680,"Ġset":681,"ten":682,"Ġsol":683,"Ġeach":684,"Ġwhen":685,"Ġeffect":686,"Ġpo":687,"Ġshe":688,"ick":689,"Ġwhere":690,"Ġmodel":691,"Ġimportant":692,"Ġunder":693,"Ġprog":694,"enerate":695,"ural":696,"tain":697,"Ġass":698,"ology":699,"Ġhad":700,"ook":701,"gg":702,"Ġcustomer":703,"ting":704,"ving":705,"Ġresp":706,"line":707,"Ġcreat":708,"ll":709,"ily":710,"Ġreg":711,"Ġdet":712,"Ġif":713,"Ġ+":714,"Ġbusiness":715,"\\nIn":716,"ish":717,"Ġmost":718,"ĠĠĠĠĠĠĠĠ":719,"hes":720,"angu":721,"Ġprovide":722,"Ġadv":723,"erm":724,"ub":725,"Ġsk":726,"irst":727,"any":728,"Ġday":729,"ivid":730,"arm":731,"ract":732,"nce":733,"Ġ|":734,"Ġimprove":735,")\\":736,"Ġco":737,"Ġcommun":738,"arket":739,"Ġmet":740,"cy":741,"Ġdifferent":742,"ized":743,"Ġart":744,"\\nThe":745,"rit":746,"Ġcomput":747,"Ġform":748,"ck":749,"Ġhum":750,"Ġchar":751,"ble":752,"Ġlead":753,"iron":754,"Ġrem":755,"Ġshould":756,"te":757,"Ġallow":758,"ness":759,"hat":760,"Ġfun":761,"Ġcomple":762,"Ġlangu":763,"ages":764,"Ġbec":765,"Ġsign":766,"ues":767,"ature":768,"Ġfind":769,"riend":770,"Ġstud":771,"Ġmain":772,"imate":773,"ove":774,"Ġresult":775,"Ġplay":776,"Ġreduce":777,"Ġeng":778,"ware":779,"redi":780,"Ġnumber":781,"Ġlar":782,"Ġpol":783,"Ġpat":784,"Ġwell":785,"ident":786,"viron":787,"rite":788,"crib":789,"Ġbu":790,"Ġhigh":791,"Ġthese":792,"ives":793,"ves":794,"Ġdesign":795,"urn":796,"Ġthan":797,"der":798,"Ġanal":799,"Ġwater":800,"Ġmarket":801,"Ġexample":802,"way":803,"stand":804,"ng":805,"ax":806,"itive":807,"Ġ`":808,"iqu":809,"Ġsim":810,"Ġequ":811,"gorith":812,"Ġtext":813,"resent":814,"Ġmany":815,"uring":816,"----":817,"\\nA":818,"Ġdi":819,"Ġsa":820,"vironment":821,"arch":822,"Ġatt":823,"Ġpot":824,"Ġtas":825,"Ġcreate":826,"ough":827,"Ġfl":828,"Ġmaking":829,"ious":830,"Ġgra":831,"Ġlife":832,"\\nO":833,"Ġalgorith":834,"ality":835,"eng":836,"Ġfin":837,"uc":838,"?\",Ċ":839,"ĠY":840,"Ġret":841,"Ġbeen":842,"Ġtechnology":843,"Ġprogra":844,"Ġhand":845,"hip":846,"wn":847,"Ġcal":848,"Ġwhat":849,"ividual":850,"iss":851,"ety":852,"Ġlanguage":853,"ources":854,"Ġclass":855,"Ġtake":856,"Ġeas":857,"ric":858,"Ġvis":859,"bject":860,"Ġref":861,"Ġenvironment":862,"Ġfirst":863,"eg":864,"Ġindividual":865,"Ġplan":866,"Ġperform":867,"Ġru":868,"ien":869,"Ġimpact":870,"Ġag":871,"ade":872,"Ġcle":873,"Ġrequ":874,"dition":875,"__":876,"Ġche":877,"ption":878,"Ġappro":879,"Ġ**":880,"Ġgreat":881,"ved":882,"Ġexpl":883,"Ġgrow":884,"Generate":885,"Ġmy":886,"Ġincluding":887,"Ġaccess":888,"Ġpop":889,"Ġmin":890,"fore":891,"Ġsocial":892,"ines":893,"Ġcharact":894,"Ġbr":895,"Ġstep":896,"Ġunderstand":897,"Ġorgan":898,"ĠAd":899,"Ġdisc":900,"Ġpower":901,"Ġlong":902,"hed":903,"Ġconc":904,"ward":905,"ited":906,"Ġele":907,"cing":908,"Ġevery":909,"Ġca":910,"Ġoften":911,"Ġuser":912,"vie":913,"ĠV":914,"Ġfood":915,"Ġinclude":916,"Ġloc":917,"ases":918,"ically":919,"ode":920,"ants":921,"Ġinvol":922,"Ġsmall":923,"Ġsur":924,"achine":925,"Ġbeing":926,"Ġpotential":927,"Ġno":928,"ĠCh":929,"Ġdep":930,"ather":931,"Ġboth":932,"Ġens":933,"Ġposs":934,"Ġed":935,"cribe":936,"ts":937,"ork":938,"ĠThey":939,"Ġpur":940,"ivity":941,"Ġwords":942,"Ġsignific":943,"Ġwere":944,"ĠHow":945,"Ġprom":946,"Ġexperience":947,"ĠK":948,"up":949,"Ġcount":950,"ered":951,"Des":952,"Ġfam":953,"```":954,"akes":955,"Ġgl":956,"ĠHe":957,"Ġfeel":958,"Ġback":959,"Ġfi":960,"Ġproble":961,"ization":962,"ling":963,"Ġcommunic":964,"ploy":965,"Ġaut":966,"Ġfriend":967,"Ġhuman":968,"Ġspe":969,"ew":970,"Ġpersonal":971,"Ġtop":972,"Ġent":973,"other":974,"Ġchang":975,"Ġcor":976,"Ġchange":977,"Ġdecis":978,"ability":979,"hing":980,"atural":981,"ever":982,"Ġcost":983,"Ġgood":984,"ause":985,"Ġident":986,"Ġsoft":987,"ined":988,"Ġpass":989,"'t":990,"atures":991,"Ġben":992,"Ġcompany":993,"Ġstart":994,"Ġsignificant":995,"Ġsumm":996,"ond":997,"old":998,"bers":999,"sel":1000,"?\\":1001,"Ġcur":1002,"Ġlight":1003,"Ġcommon":1004,".\\\"":1005,"Ġcustomers":1006,"iving":1007,"conom":1008,"Ġfunction":1009,"Ġve":1010,"Ġthree":1011,"Ġeven":1012,"ining":1013,"Ġgener":1014,"ries":1015,"Ġlevel":1016,"Ġspecific":1017,"Ġwebs":1018,"Ġthen":1019,"Ġeffective":1020,"cur":1021,"ense":1022,"Ġlarge":1023,"Ġdist":1024,"Ġeffic":1025,"Ġsupport":1026,"Ġget":1027,"Create":1028,"read":1029,"port":1030,"Ġinf":1031,"Ġ'":1032,"Ġyear":1033,"Ġstate":1034,"Ġkey":1035,"ccess":1036,":**":1037,"Ġav":1038,"Ġknow":1039,"Ġbenef":1040,"Ġess":1041,"ables":1042,"ren":1043,"Ġown":1044,"ĠThese":1045,"ock":1046,"-t":1047,"Ġide":1048,"omm":1049,"reen":1050,"ced":1051,"cture":1052,"Ġteam":1053,"Ġris":1054,"Ġtasks":1055,"Ġdown":1056,"Ġstru":1057,"Ġcomputer":1058,"-b":1059,"Ġfact":1060,"Ġmem":1061,"etter":1062,"\\nS":1063,"Ġaround":1064,"Ġword":1065,"Ġbased":1066,"Ġbeh":1067,"Ġright":1068,"Ġdel":1069,"Ġpoint":1070,"Ġnatural":1071,"ss":1072,"Ġeconom":1073,"Ġmade":1074,"Ġins":1075,"Ġinst":1076,"Ġmat":1077,"Ġvalue":1078,"Ġanim":1079,"Ġsever":1080,"\\nT":1081,"ational":1082,"ital":1083,"ze":1084,"ote":1085,"ills":1086,"tern":1087,"Ġread":1088,"Ġcontent":1089,"Ġonline":1090,"Ġend":1091,"ĠUn":1092,"vent":1093,"Ġsee":1094,"ending":1095,"Ġmon":1096,"Ġdr":1097,"Ġkeep":1098,"Ġsystems":1099,"cul":1100,"ven":1101,"Ġstory":1102,"Ġmedia":1103,"Ġseveral":1104,"hen":1105,"ateg":1106,"Ġcontin":1107,"Ġdev":1108,"Ġlearn":1109,"Ġla":1110,"Ġstre":1111,"Ġpartic":1112,"Ġair":1113,"ually":1114,"Ġsuccess":1115,"ouse":1116,"Ġiss":1117,"ied":1118,"Ġmachine":1119,"Ġopt":1120,"Ġx":1121,"Ġop":1122,"Ġprof":1123,"ocus":1124,"chie":1125,"Ġmeth":1126,"ner":1127,"omp":1128,"ron":1129,"Ġhome":1130,"Ġbetter":1131,"ĠPro":1132,"Ġmult":1133,"omet":1134,"Ġincrease":1135,"Ġanaly":1136,"vert":1137,"Ġrele":1138,"Ġbra":1139,"ink":1140,"Ġtem":1141,"Ġpredi":1142,"Ġtre":1143,"Ġservice":1144,"Ġwebsite":1145,"Ġmanage":1146,"Ġsoftware":1147,"here":1148,"Ġprot":1149,"-s":1150,"Ġquest":1151,"ier":1152,"Ġknown":1153,"Ġorder":1154,"Ġphys":1155,"cept":1156,"Ġachie":1157,"Ġinput":1158,"Ġpossible":1159,"ĠIf":1160,"Ġext":1161,"fter":1162,"Ġelect":1163,"Ġmethod":1164,"Ġbre":1165,"ĠAn":1166,"ways":1167,"ering":1168,"ets":1169,"Ġjust":1170,"Ġstore":1171,"Ġdevelopment":1172,"Ġcare":1173,"Ġobject":1174,"Ġtype":1175,"ĠFor":1176,"Ġfocus":1177,"ggest":1178,"Ġonly":1179,"Ġconsid":1180,"ars":1181,"Ġchall":1182,"Ġdeterm":1183,"Ġsal":1184,"ins":1185,"Ġfeatures":1186,"Ġtru":1187,"ody":1188,"Ġtool":1189,">\\":1190,"Ġensure":1191,"oss":1192,"ublic":1193,"Ġitem":1194,"Here":1195,"ination":1196,"Ġdef":1197,"Describe":1198,"ional":1199,"roup":1200,"Ġconf":1201,"Ġneeds":1202,"Ġcharacter":1203,"Ġvarious":1204,"Ġlet":1205,"Ġapplic":1206,"aut":1207,"Ġjob":1208,"ellig":1209,"ĠCon":1210,"Ġbest":1211,"Ġfore":1212,"Ġamount":1213,"rop":1214,"Ġbuild":1215,"ique":1216,"aging":1217,"Ġemploy":1218,"Ġrest":1219,"air":1220,"What":1221,"Ġtoget":1222,"Ġways":1223,"Ġidentify":1224,"Ġtogether":1225,"Ġreal":1226,"Ġusers":1227,"Ġmean":1228,"asing":1229,"ĠAm":1230,"Ġeduc":1231,"Ġalgorithm":1232,"Ġnetw":1233,"Ġcode":1234,"Write":1235,"ov":1236,"-d":1237,"oura":1238,"ĠHowever":1239,"uture":1240,"view":1241,"Ġindu":1242,"Ġproducts":1243,"ected":1244,"ertain":1245,";\\":1246,"ĠAs":1247,"pr":1248,"aste":1249,"Ġoper":1250,"Ġ$":1251,"avi":1252,"self":1253,"Ġ<":1254,"Ġindust":1255,"Ġgu":1256,"Ġothers":1257,"Ex":1258,"ian":1259,"Ġ\"\\\"":1260,"-f":1261,"nces":1262,"Ġfil":1263,"Ġrespons":1264,"rol":1265,"Ġcap":1266,"Ġbefore":1267,"vern":1268,"Ġcomplex":1269,"lus":1270,"ribut":1271,"ats":1272,"Ġpositive":1273,"oh":1274,"Ġlo":1275,"Ġgroup":1276,"Ġfound":1277,"ee":1278,"ogn":1279,"Ġsw":1280,"Ġindividuals":1281,"Ġpract":1282,"Ġenc":1283,"Ġshare":1284,"raph":1285,"Ġrange":1286,"Ġsun":1287,"\\t":1288,"Ġproviding":1289,"icle":1290,"Ġdem":1291,"Ġplace":1292,"Ġaud":1293,"joy":1294,"Ġmust":1295,"els":1296,"ery":1297,"One":1298,"Ġfamily":1299,"Ġfuture":1300,"less":1301,"rent":1302,"Ġproblem":1303,"Ġessential":1304,"rodu":1305,"ired":1306,"Ġreducing":1307,"ism":1308,"Ġwarm":1309,"ray":1310,"Ġability":1311,"Ġstrong":1312,"Ġalways":1313,"Ġresources":1314,"Ġbenefits":1315,"Ġstrateg":1316,"Ġinvolves":1317,"Ġassist":1318,"erest":1319,"nA":1320,"ression":1321,"Ġ[":1322,"ilities":1323,"Ġsteps":1324,"verall":1325,"Ġshow":1326,"obal":1327,"\\nF":1328,"Ġland":1329,"ĠHere":1330,"Ġbusinesses":1331,"ĠEn":1332,"pportun":1333,"Ġmeas":1334,"Ġreturn":1335,"Ġdig":1336,"Ġhist":1337,"yth":1338,"Ġcent":1339,"Ġable":1340,"Ġwithout":1341,"yc":1342,"plain":1343,"Ġrelations":1344,"Ġservices":1345,"-c":1346,"Ġtest":1347,"arth":1348,"Ġcommunication":1349,"Ġintern":1350,"new":1351,"Ġsit":1352,"Ġinvest":1353,"Ġcaus":1354,"Ġunt":1355,"Ġfriends":1356,"Ġchanges":1357,"cri":1358,"dit":1359,"ĠBy":1360,"ĠYou":1361,"Ġmeans":1362,"Ġrese":1363,"ool":1364,"ted":1365,"elligence":1366,"ains":1367,"pping":1368,"Ġbel":1369,"Ġrepresent":1370,"Ġhapp":1371,"Ġser":1372,"Ġperformance":1373,"Ġopportun":1374,"Ġtemper":1375,"ĠShe":1376,"Ġfu":1377,"ix":1378,"bot":1379,"Ġwrit":1380,"Ġbehavi":1381,"Ġproject":1382,"ĠWith":1383,"ivers":1384,"day":1385,"Ġphysical":1386,"izing":1387,"Ġactiv":1388,"Ġwithin":1389,"Ġinterest":1390,"olution":1391,"wards":1392,"ffic":1393,"Ġquick":1394,"Ġpublic":1395,"Ġgrowth":1396,"Ġcho":1397,"Ġrelationship":1398,"Ġuntil":1399,"Ġhelps":1400,"Ġstudents":1401,"Ġfiel":1402,"imes":1403,"ulation":1404,"ibility":1405,"elf":1406,"Ġful":1407,"Ġsub":1408,"ank":1409,"ides":1410,"Ġskills":1411,"Ġclimate":1412,"Given":1413,"Ġpar":1414,"Ġclear":1415,"irt":1416,"Name":1417,"Ġpresent":1418,"Ġtri":1419,"Ġchalleng":1420,"ream":1421,"Ġlay":1422,"Ġmarketing":1423,"Ġsummary":1424,"Ġchild":1425,"Ġsaf":1426,"Ġsure":1427,"Ġsame":1428,"Ġmu":1429,"Ġemail":1430,"bon":1431,"Ġsomet":1432,"```\\":1433,"Ġcurrent":1434,"amp":1435,"ences":1436,"ĠRe":1437,"Ġtransport":1438,"me":1439,"-p":1440,"action":1441,"ĠEx":1442,"Ġyears":1443,"Ġcomb":1444,"hor":1445,"anced":1446,"ty":1447,"Ġlove":1448,"Ġgreen":1449,"Ġpopular":1450,"Ġless":1451,"Ġdra":1452,"Ġcontrol":1453,"Ġaff":1454,"Ġconsum":1455,"Ġgame":1456,"ental":1457,"ights":1458,"arget":1459,"omes":1460,"ox":1461,"icult":1462,"erc":1463,"Ġgoals":1464,"ancial":1465,"tle":1466,"Ġgovern":1467,"Ġnumbers":1468,"Ġfive":1469,"Ġstand":1470,"Ġsearch":1471,"Ġefficient":1472,"Ġwal":1473,"Ġname":1474,"ath":1475,"Ġheart":1476,"Ġduring":1477,"rect":1478,"Ġoverall":1479,"ython":1480,"Ġallows":1481,"Ġcity":1482,"ave":1483,"vant":1484,"aterial":1485,"Ġwide":1486,"Ġmus":1487,"ificial":1488,"Ġhard":1489,"ĠTh":1490,"oose":1491,"Ġglobal":1492,"aj":1493,"Ġter":1494,"Ġdifficult":1495,"Ġline":1496,"ĠAl":1497,"care":1498,"ived":1499,"Ġregular":1500,"Ġgr":1501,"),":1502,"lement":1503,"Ġhim":1504,"Ġunique":1505,"Ġenjoy":1506,"Ġmeaning":1507,"Ġopen":1508,"Ġi":1509,"abor":1510,"Ġarea":1511,"Ġitems":1512,"Ġclean":1513,"ditionally":1514,"oid":1515,"ĠWe":1516,"Ġbeaut":1517,"Ġmeet":1518,"iple":1519,"Ġstatement":1520,"Ġagain":1521,"ysis":1522,"Ġfac":1523,"Ġsources":1524,"Ġbody":1525,"Ġalgorithms":1526,"Ġaudience":1527,"Ġwant":1528,"Ġlog":1529,"Ġmaintain":1530,"Ġactivities":1531,"Ġmove":1532,"Ġcult":1533,"oney":1534,"Ġtarget":1535,"\\nB":1536,"Ġmaterial":1537,"Ġcreating":1538,"Ġstructure":1539,"atform":1540,"ext":1541,"Ġexperien":1542,"Ġvalues":1543,"ead":1544,"ohn":1545,"Ġhealthy":1546,"ross":1547,"Ġinteg":1548,"Ġresearch":1549,"atch":1550,"ooking":1551,"Ġrole":1552,"Ġprovides":1553,"iety":1554,"ists":1555,"Ġfinancial":1556,"ories":1557,"dent":1558,"Ġer":1559,"Ġarticle":1560,"Ġelements":1561,"Ġaddress":1562,"Ġconn":1563,"ĠUse":1564,"mp":1565,"Ġeasy":1566,"Ġneg":1567,"Ġcolor":1568,"Ġcalcul":1569,"Explain":1570,"ĠPl":1571,"pect":1572,"ince":1573,"ale":1574,"Ġrisk":1575,"curity":1576,"ert":1577,"Ġfeed":1578,"Ġevent":1579,"vers":1580,"ples":1581,"Ġlevels":1582,"Ġbi":1583,"Ġstay":1584,"Ġplatform":1585,"Ġbreak":1586,"back":1587,"Ġsat":1588,"\\nOverall":1589,"Ġeducation":1590,"\\nC":1591,"Ġcarbon":1592,"--------":1593,"ape":1594,"Ġprevent":1595,"Ġaddition":1596,"Ġstress":1597,"ral":1598,"ource":1599,"rus":1600,"Ġcome":1601,"Ġrecogn":1602,"ĠUnited":1603,"Ġproper":1604,"Ġpoll":1605,"dentify":1606,"Ġunderstanding":1607,"Ġdecisions":1608,"ict":1609,"Ġdire":1610,"Ġbehavior":1611,"Ġ*":1612,"\\nI":1613,"Ġmess":1614,"Ġanimals":1615,"Ġsl":1616,"Ġwind":1617,"Ġbas":1618,"Ġpain":1619,"Ġleading":1620,"ern":1621,"ger":1622,"Ġpres":1623,"Ġthough":1624,"Ġinteract":1625,"yle":1626,"Ġdoes":1627,"Ġhead":1628,"Ġintelligence":1629,"orts":1630,"Ġbecome":1631,"Ġrun":1632,"aring":1633,"Ġimplement":1634,"Ġaction":1635,"oot":1636,"terns":1637,"Ġprotect":1638,"eric":1639,"Ġflow":1640,"Ġemot":1641,"cessary":1642,"urate":1643,"Ġsuggest":1644,"Ġprogram":1645,"Ġphr":1646,"Ġhealthcare":1647,"ention":1648,"Ġsust":1649,"Ġwhy":1650,"Ġaccurate":1651,"lu":1652,"Ġhig":1653,"Ġreach":1654,"Ġallowing":1655,"Ġtravel":1656,"Ġrequire":1657,"Ġareas":1658,"Ġdeep":1659,"He":1660,"Ġfew":1661,"Ġself":1662,"oun":1663,"Ġ#":1664,"osp":1665,"str":1666,"Ġminut":1667,"Ġdecision":1668,"ĠThere":1669,"ances":1670,"Ġquality":1671,"Ġavail":1672,"Ġspace":1673,"Ġsomething":1674,"Ġweb":1675,"Ġpatterns":1676,"Ġmot":1677,"oring":1678,"isf":1679,"Ġanother":1680,"Ġaccount":1681,"\\nW":1682,"uss":1683,"Ġmaj":1684,"uation":1685,"Ġsustain":1686,"Ġautom":1687,"iques":1688,"issions":1689,"verse":1690,"Ġconcept":1691,"Ġsecurity":1692,"Ġthose":1693,"Ġprofess":1694,"Ġshort":1695,"Ġnight":1696,"ength":1697,"apt":1698,"ex":1699,"ĠAdditionally":1700,"Ġtaking":1701,"Ġtoo":1702,"agn":1703,"Ġsimple":1704,"lusion":1705,"iency":1706,"ash":1707,"ours":1708,"Ġpa":1709,"Ġlit":1710,"ĠSp":1711,"iting":1712,"Ġdon":1713,"Ġlim":1714,"lish":1715,"mat":1716,"aves":1717,"ledge":1718,"ditional":1719,"inc":1720,"Ġevents":1721,"Ġoffer":1722,"thing":1723,"Ġworking":1724,"Ġanalysis":1725,"Ġachieve":1726,"Ġpie":1727,"Ġbook":1728,"Ġfre":1729,"Ġmuch":1730,"oon":1731,"Ġtry":1732,"esp":1733,"Ġwaste":1734,"face":1735,"Ġear":1736,"Ġfru":1737,"Ġtransportation":1738,"chool":1739,"Ġtechniques":1740,"Ġprogramm":1741,"ĠEarth":1742,"Ġpredict":1743,"Ġnever":1744,"ws":1745,"ument":1746,"imately":1747,"ared":1748,"Ġparticular":1749,"Ġtowards":1750,"Ġeconomic":1751,"Ġincreasing":1752,"Ġfast":1753,"iment":1754,"Ġnetwork":1755,"Ġcorrect":1756,"Ġmight":1757,"Ġoc":1758,"Ġbecause":1759,"ĠWh":1760,"az":1761,"play":1762,"Ġresults":1763,"Ġmanagement":1764,"Ġpurch":1765,"Ġsound":1766,"Ġpast":1767,"Ġtraining":1768,"____":1769,"ope":1770,"Ġengage":1771,"ourage":1772,"Ġsense":1773,"Ġfree":1774,"Ġpref":1775,"ees":1776,"Ġcountries":1777,"ney":1778,"anies":1779,"Ġafter":1780,"Ġmind":1781,"Ġexc":1782,"ĠOnce":1783,"ĠĠĠĠĠĠĠĠĠĠĠ":1784,"Ġcomplete":1785,"Ġimm":1786,"Ġest":1787,"Ġgenerate":1788,"verb":1789,"ĠDe":1790,"'m":1791,"Ġtools":1792,"redients":1793,"Ġmajor":1794,"ently":1795,"Ġcontribut":1796,"leep":1797,"Ġpoints":1798,"ditions":1799,"Ġfactors":1800,"Ġel":1801,"Ġnext":1802,"ium":1803,"oud":1804,"Ġcru":1805,"Ġreas":1806,"riate":1807,"ĠInd":1808,"Ġpromot":1809,"Ġhistory":1810,"Ġjour":1811,"Ġdue":1812,"Con":1813,"Ġveget":1814,"ency":1815,"ĠAmeric":1816,"Ġfra":1817,"Ġdifference":1818,"oard":1819,"lex":1820,"Ġequation":1821,"irtual":1822,"Ġcup":1823,"Ġforest":1824,"Ġnegative":1825,"Ġsecon":1826,"ones":1827,"Ġnature":1828,"Ġuses":1829,"ah":1830,"por":1831,"Ġsec":1832,"ording":1833,"Ġlast":1834,"ĠSome":1835,"Ġissues":1836,"Ġscient":1837,"Ġprint":1838,"ĠStates":1839,"over":1840,"Ġsatisf":1841,"Ġdevices":1842,"Ġdise":1843,"Ġtemperature":1844,"Ġfeedback":1845,"Ġnecessary":1846,"Ġemissions":1847,"mb":1848,"Ġlow":1849,"for":1850,"tal":1851,"Ġchallenges":1852,"Ġarray":1853,"Ġside":1854,"Ġengine":1855,"Ġboo":1856,"ata":1857,"Ġbelie":1858,"-m":1859,"Ġmultiple":1860,"Ġsing":1861,"Ġgovernment":1862,"ames":1863,"ified":1864,"Ġminutes":1865,"Ġsuccessful":1866,"Ġmoney":1867,"Ġquickly":1868,"Ġbir":1869,"Ġtypically":1870,"Ġpost":1871,"Ġprep":1872,"Ġknowledge":1873,"pped":1874,"actions":1875,"Ġmethods":1876,"Ġoptim":1877,"\\nP":1878,"Ġoutput":1879,"Ġfield":1880,"Ġtable":1881,"Ġbal":1882,"Ġcoll":1883,"Ġcharacters":1884,"volution":1885,"ords":1886,"ilar":1887,"ification":1888,"ane":1889,"Ġcell":1890,"Ġmil":1891,"ĠWhat":1892,"Ġsqu":1893,"Ġlives":1894,"ĠAr":1895,"Ġphrase":1896,"Ġnut":1897,"Ġdigital":1898,"Ġinternet":1899,"lass":1900,"ura":1901,"ommend":1902,"Ġtreat":1903,"Ġapprop":1904,"resh":1905,"urther":1906,"ĠOne":1907,"Ġvisual":1908,"ategor":1909,"Ġapproach":1910,"Ġcertain":1911,"Ġsho":1912,"val":1913,"Ġtask":1914,"ires":1915,"Ġappropriate":1916,"Ġvie":1917,"Ġdesigned":1918,"pose":1919,"**:":1920,"fort":1921,"Ġ|\\":1922,"Ġapplications":1923,"Ġpay":1924,"Ġnow":1925,"Ġheat":1926,"Ġindustry":1927,"pre":1928,"Ġeffectively":1929,"Ġpopulation":1930,"Ġopportunities":1931," pet.LaunchConfig: + """ + Initialize pet.LaunchConfig for single-node, multi-rank functions. + + Args: + nproc (int): The number of processes to launch. + + Returns: + An instance of pet.LaunchConfig for single-node, multi-rank functions. + + Example: + >>> from torch.distributed import launcher + >>> launch_config = get_pet_launch_config(nproc=8) + >>> launcher.elastic_launch(config=launch_config, entrypoint=train)() + """ + return pet.LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=nproc, + run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint="localhost:0", + max_restarts=0, + monitor_interval=1, + ) + + return get_pet_launch_config_fn + + +def pytest_addoption(parser: argparse.ArgumentParser) -> None: + parser.addoption( + "--large-scale", + type=bool, + default=False, + help="Run a larger scale integration test", + ) + parser.addoption( + "--silence-s3-logs", + action="store_true", + help="Silence progress bar when fetching assets from S3", + ) diff --git a/training/tests/recipes/__init__.py b/training/tests/recipes/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/recipes/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/recipes/common.py b/training/tests/recipes/common.py new file mode 100644 index 0000000..63eaa47 --- /dev/null +++ b/training/tests/recipes/common.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +RECIPE_TESTS_DIR = Path(__file__).parent diff --git a/training/tests/recipes/test_configs.py b/training/tests/recipes/test_configs.py new file mode 100644 index 0000000..f083543 --- /dev/null +++ b/training/tests/recipes/test_configs.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +from pathlib import Path + +import torchtune + +from omegaconf import OmegaConf +from torchao.utils import TORCH_VERSION_AFTER_2_4 +from torchtune import config + +CONFIG_DIR = Path(torchtune.__file__).parent.parent / "recipes" / "configs" + + +class TestConfigs: + def test_instantiate(self) -> None: + all_configs = [ + os.path.join(CONFIG_DIR, f) + for f in os.listdir(CONFIG_DIR) + if f.endswith(".yaml") + ] + for config_path in all_configs: + # QAT config is only compatible with PyTorch 2.4+ + if config_path.endswith("qat_full.yaml") and not TORCH_VERSION_AFTER_2_4: + continue + cfg = OmegaConf.load(config_path) + config.validate(cfg) diff --git a/training/tests/recipes/test_eleuther_eval.py b/training/tests/recipes/test_eleuther_eval.py new file mode 100644 index 0000000..4c1838d --- /dev/null +++ b/training/tests/recipes/test_eleuther_eval.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import builtins +import math +import re +import runpy +import sys +from pathlib import Path + +import pytest + +from tests.common import TUNE_PATH +from tests.recipes.utils import llama2_test_config +from tests.test_utils import CKPT_MODEL_PATHS + + +class TestEleutherEval: + @pytest.mark.integration_test + def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir): + ckpt = "llama2_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + + cmd = f""" + tune run eleuther_eval \ + --config eleuther_evaluation \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + limit=10 \ + dtype=fp32 \ + device=cpu \ + """.split() + + model_config = llama2_test_config() + cmd = cmd + model_config + + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out = capsys.readouterr().out + + # v0.4.2 format + # | Tasks |Version|Filter|n-shot|Metric|Value | |Stderr| + # |--------------|------:|------|-----:|------|-----:|---|-----:| + # |truthfulqa_mc2| 2|none | 0|acc |0.3469|± |0.1444| + + # v0.4.3 format + # | Tasks |Version|Filter|n-shot|Metric| |Value | |Stderr| + # |--------------|------:|------|-----:|------|---|-----:|---|-----:| + # |truthfulqa_mc2| 2|none | 0|acc |↑ |0.3469|± |0.1444| + + # The below RegEx command will pick up both formats + search_results = re.search( + r"acc(?:_norm)?\s*\|?\s*(?:\↑\s*\|?)?([\d.]+)", out.strip() + ) + assert search_results is not None + acc_result = float(search_results.group(1)) + assert math.isclose(acc_result, 0.3, abs_tol=0.05) + + @pytest.fixture + def hide_available_pkg(self, monkeypatch): + import_orig = builtins.__import__ + + def mocked_import(name, *args, **kwargs): + if name == "lm_eval": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + @pytest.mark.integration_test + @pytest.mark.usefixtures("hide_available_pkg") + def test_eval_recipe_errors_without_lm_eval(self, caplog, monkeypatch, tmpdir): + ckpt = "llama2_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + + cmd = f""" + tune run eleuther_eval \ + --config eleuther_evalation \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + limit=10 \ + dtype=fp32 \ + device=cpu \ + """.split() + + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match="1"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + err_log = caplog.messages[-1] + assert "Recipe requires EleutherAI Eval Harness v0.4" in err_log diff --git a/training/tests/recipes/test_full_finetune_distributed.py b/training/tests/recipes/test_full_finetune_distributed.py new file mode 100644 index 0000000..848c9c9 --- /dev/null +++ b/training/tests/recipes/test_full_finetune_distributed.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy + +import sys +from pathlib import Path + +import pytest +import torch +from tests.common import TUNE_PATH + +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) + + +class TestFullFinetuneDistributedRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=4", + "dtype=fp32", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + "epochs=2", + "max_steps_per_epoch=2", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama2": [10.5136, 10.4813, 10.5088, 10.5250], + "llama3": [12.0673, 11.9072, 11.9302, 11.9355], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama2/7B_full", "llama2", "hf"), + ("llama3/8B_full", "llama3", "tune"), + ], + ) + @gpu_test(gpu_count=2) + def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + write_hf_ckpt_config(ckpt_dir) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + metric_logger.filename={log_file} \ + """.split() + model_config = MODEL_TEST_CONFIGS[model_type] + cmd = cmd + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values(model_type) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) diff --git a/training/tests/recipes/test_full_finetune_single_device.py b/training/tests/recipes/test_full_finetune_single_device.py new file mode 100644 index 0000000..7b88c84 --- /dev/null +++ b/training/tests/recipes/test_full_finetune_single_device.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import runpy + +import sys +from pathlib import Path + +import numpy as np + +import pytest + +import torch +from tests.common import TUNE_PATH + +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + TOKENIZER_PATHS, +) + + +class TestFullFinetuneSingleDeviceRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=8", + "device=cpu", + "dtype=fp32", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + "epochs=2", + "max_steps_per_epoch=2", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama2": [10.5201, 10.5217, 10.4945, 10.5136], + "llama3": [11.9839, 11.9684, 11.9596, 11.9366], + } + + return loss_values_map[model_type] + + @pytest.mark.integration_test + @pytest.mark.parametrize("compile", [True, False]) + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama2/7B_full_low_memory", "llama2", "meta"), + ("llama3/8B_full_single_device", "llama3", "tune"), + ], + ) + def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # To workaround https://github.com/pytorch/torchtune/issues/676 + if compile: + os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager" + cmd = f""" + tune run full_finetune_single_device \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + metric_logger.filename={log_file} \ + compile={compile} \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type] + cmd = cmd + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values(model_type) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) + + @pytest.mark.integration_test + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt = "llama2_hf" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run full_finetune_single_device \ + --config llama2/7B_full_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2"] + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run full_finetune_single_device \ + --config llama2/7B_full_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = self._fetch_expected_loss_values("llama2")[2:] + + loss_values = get_loss_values_from_metric_logger(log_file) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) + + +class TestFullFinetuneSingleDeviceGradientAccumulation: + def _get_test_config_overrides(self): + return [ + "device=cpu", + "dtype=fp32", + "enable_activation_checkpointing=False", + "tokenizer.path=/tmp/test-artifacts/tokenizer.model", + "dataset=tests.recipes.utils.DummyDataset", + "dataset.train_on_input=False", + "seed=9", + "epochs=1", + "max_steps_per_epoch=1", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "optimizer_in_bwd=False", + ] + + @pytest.mark.integration_test + def test_gradient_accumulation(self, tmpdir, monkeypatch): + """Test whether gradient accumulation runs properly in the recipe. In general + the sum of loss across minibatches should equal the loss over the full batch, + but since our loss is normalized by the number of unmasked tokens, this does not + hold in for our case. We use a dummy dataset where all tokens are unmasked, and + in this test check that a single batch size of N yields the same loss as N accumulated + microbatches of size 1. + """ + full_batch_size = 4 + micro_batch_size = 1 + gradient_accumulation_steps = full_batch_size // micro_batch_size + ckpt = "llama2_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + no_grad_accum_log_file = gen_log_file_name(tmpdir, suffix="no_grad_accum") + grad_accum_log_file = gen_log_file_name(tmpdir, suffix="grad_accum") + + cmd_1 = f""" + tune run full_finetune_single_device \ + --config llama2/7B_full_low_memory \ + checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir={ckpt_dir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + batch_size={full_batch_size} \ + output_dir={tmpdir} \ + log_every_n_steps=1 \ + metric_logger.filename={no_grad_accum_log_file} \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2"] + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + no_accum_loss = get_loss_values_from_metric_logger(no_grad_accum_log_file)[ + 0 + ] # List of a single element + + # Update the cmd with new values for gradient accumulation + cmd_2 = f""" + tune run full_finetune_single_device \ + --config llama2/7B_full_low_memory \ + checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir={ckpt_dir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=llama2 \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + metric_logger.filename={grad_accum_log_file} \ + """.split() + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + accum_loss = np.mean(get_loss_values_from_metric_logger(grad_accum_log_file)) + torch.testing.assert_close(no_accum_loss, accum_loss, atol=1e-5, rtol=1e-5) diff --git a/training/tests/recipes/test_lora_finetune_distributed.py b/training/tests/recipes/test_lora_finetune_distributed.py new file mode 100644 index 0000000..2b7afd6 --- /dev/null +++ b/training/tests/recipes/test_lora_finetune_distributed.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from omegaconf import OmegaConf +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) +from torchtune import config + + +class TestLoRAFinetuneDistributedRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=4", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + "epochs=2", + "dtype=fp32", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "gradient_accumulation_steps=1", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + # These values have been validated against single device recipe test via + # https://gist.github.com/ebsmothers/f1c3db7c66655a23a91e0290360960c4 + loss_values_map = { + "llama2": [10.5136, 10.4856, 10.5292, 10.5345], + "llama3": [11.9325, 11.9325, 11.9325, 11.9369], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_loss(self, tmpdir, monkeypatch): + ckpt = "llama2_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed + --config llama2/7B_lora \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + metric_logger.filename={log_file} \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values("llama2") + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama2/7B_lora", "llama2", "hf"), + ("llama3/8B_lora", "llama3", "tune"), + ], + ) + def test_training_state_on_resume( + self, config, model_type, ckpt_type, tmpdir, monkeypatch + ): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + expected_loss_values = self._fetch_expected_loss_values(model_type) + + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = self._fetch_expected_loss_values(model_type)[2:] + + loss_values = get_loss_values_from_metric_logger(log_file) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "recipe_config, model_type, ckpt_type", + [ + ("llama2/7B_lora", "llama2", "tune"), + ("llama3/8B_lora", "llama3", "tune"), + ], + ) + @gpu_test(gpu_count=2) + def test_save_and_load_merged_weights( + self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \ + --config {recipe_config} \ + output_dir={tmpdir} \ + model=torchtune.models.lora_small_test_model \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a base model + # and the base model weights + trained adapter weights in the LoRA model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base model for loading merged weights + base_config = MODEL_TEST_CONFIGS[model_type] + model = config.instantiate(OmegaConf.from_dotlist(base_config).model) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into model and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + model.load_state_dict(sd) + merged_ckpt_out = model(inputs) + + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) diff --git a/training/tests/recipes/test_lora_finetune_fsdp2.py b/training/tests/recipes/test_lora_finetune_fsdp2.py new file mode 100644 index 0000000..04cec6c --- /dev/null +++ b/training/tests/recipes/test_lora_finetune_fsdp2.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from omegaconf import OmegaConf +from packaging import version +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) +from torchtune import config + + +class TestLoRAFinetuneDistributedRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=4", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + "epochs=2", + "dtype=fp32", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "gradient_accumulation_steps=1", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + # These values have been validated against single device recipe test via + # https://gist.github.com/ebsmothers/f1c3db7c66655a23a91e0290360960c4 + loss_values_map = { + "llama2": [10.5136, 10.4856, 10.5292, 10.5345], + "llama3": [11.9325, 11.9325, 11.9325, 11.9369], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", reason="" + ) + @gpu_test(gpu_count=2) + def test_loss(self, tmpdir, monkeypatch): + ckpt = "llama2_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 + --config llama2/7B_lora \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + metric_logger.filename={log_file} \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values("llama2") + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", reason="" + ) + @gpu_test(gpu_count=2) + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama2/7B_lora", "llama2", "hf"), + ("llama3/8B_lora", "llama3", "tune"), + ], + ) + def test_training_state_on_resume( + self, config, model_type, ckpt_type, tmpdir, monkeypatch + ): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + expected_loss_values = self._fetch_expected_loss_values(model_type) + + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = self._fetch_expected_loss_values(model_type)[2:] + + loss_values = get_loss_values_from_metric_logger(log_file) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", reason="" + ) + @pytest.mark.parametrize( + "recipe_config, model_type, ckpt_type", + [ + ("llama2/7B_lora", "llama2", "tune"), + ("llama3/8B_lora", "llama3", "tune"), + ], + ) + @gpu_test(gpu_count=2) + def test_save_and_load_merged_weights( + self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 \ + --config {recipe_config} \ + output_dir={tmpdir} \ + model=torchtune.models.lora_small_test_model \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a base model + # and the base model weights + trained adapter weights in the LoRA model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base model for loading merged weights + base_config = MODEL_TEST_CONFIGS[model_type] + model = config.instantiate(OmegaConf.from_dotlist(base_config).model) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into model and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + model.load_state_dict(sd) + merged_ckpt_out = model(inputs) + + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) diff --git a/training/tests/recipes/test_lora_finetune_single_device.py b/training/tests/recipes/test_lora_finetune_single_device.py new file mode 100644 index 0000000..1527530 --- /dev/null +++ b/training/tests/recipes/test_lora_finetune_single_device.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from omegaconf import OmegaConf +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + TOKENIZER_PATHS, +) +from torchtune import config +from torchtune.utils import torch_version_ge + + +class TestLoRAFinetuneSingleDeviceRecipe: + def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): + return [ + "batch_size=8", + "device=cpu", + f"dtype={dtype_str}", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + f"epochs={epochs}", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "gradient_accumulation_steps=1", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama2": [10.5209, 10.5269, 10.5130, 10.5242], + "llama3": [11.9838, 11.9691, 11.9616, 11.9383], + } + return loss_values_map[model_type] + + def _fetch_qlora_expected_loss_values(self, dtype): + if dtype == "bf16": + return [10.5197, 10.5272, 10.5129, 10.5243] + return [10.5198, 10.5271, 10.5131, 10.5244] + + @pytest.mark.integration_test + @pytest.mark.parametrize("compile", [True, False]) + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama2/7B_lora_single_device", "llama2", "meta"), + ("llama3/8B_lora_single_device", "llama3", "tune"), + ], + ) + def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # To workaround https://github.com/pytorch/torchtune/issues/676 + if compile: + os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager" + + cmd = f""" + tune run lora_finetune_single_device \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + metric_logger.filename={log_file} \ + compile={compile} \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd = cmd + self._get_test_config_overrides(dtype_str="fp32") + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values(model_type) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @pytest.mark.parametrize("dtype", ["fp32", "bf16"]) + @pytest.mark.parametrize("compile", [True, False]) + @pytest.mark.skipif( + not torch_version_ge("2.4.0"), + reason="Please install a nightly build of torch to run this test.", + ) + def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): + ckpt = "llama2_meta" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # To workaround https://github.com/pytorch/torchtune/issues/676 + if compile: + os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager" + + cmd = f""" + tune run lora_finetune_single_device + --config llama2/7B_qlora_single_device \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelMetaCheckpointer + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + metric_logger.filename={log_file} \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + compile={compile} \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2_qlora"] + + cmd = cmd + self._get_test_config_overrides(dtype_str=dtype) + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_qlora_expected_loss_values(dtype=dtype) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) + + @pytest.mark.integration_test + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt = "llama2_hf" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2_lora"] + + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + """.split() + cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Second epoch only + expected_loss_values = self._fetch_expected_loss_values("llama2")[2:] + loss_values = get_loss_values_from_metric_logger(log_file)[:2] + + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): + ckpt = "llama2_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + + cmd = f""" + tune run lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama2_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a Llama2 base model + # and the base model weights + trained adapter weights in the LoRA Llama 2 model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base llama2 model for loading merged weights + base_llama2_config = MODEL_TEST_CONFIGS["llama2"] + llama2_model = config.instantiate( + OmegaConf.from_dotlist(base_llama2_config).model + ) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into llama2 and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + llama2_model.load_state_dict(sd) + merged_ckpt_out = llama2_model(inputs) + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) diff --git a/training/tests/recipes/test_ppo_full_tunetune_single_device.py b/training/tests/recipes/test_ppo_full_tunetune_single_device.py new file mode 100644 index 0000000..7113a12 --- /dev/null +++ b/training/tests/recipes/test_ppo_full_tunetune_single_device.py @@ -0,0 +1,373 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from tests.common import TUNE_PATH + +from tests.recipes.utils import ( + dummy_text_completion_alpaca_dataset_config, + llama2_classifier_test_config, + llama2_test_config, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, +) + + +class TestPPOFullFinetuneSingleDeviceRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=4", + "forward_batch_size=4", + "ppo_batch_size=4", + "ppo_epochs=1", + "num_steps=16", + "temperature=1.0", + "gradient_accumulation_steps=1", + "device=cpu", + "dtype=fp32", + "enable_activation_checkpointing=False", + "tokenizer.path=/tmp/test-artifacts/tokenizer.model", + "tokenizer._component_=torchtune.models.llama2.llama2_tokenizer", + "seed=9", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + ] + dummy_text_completion_alpaca_dataset_config() + + @pytest.mark.integration_test + def test_loss(self, tmpdir, monkeypatch): + + reward_ckpt = "llama2_reward_hf" + policy_ckpt = "llama2_hf" + reward_ckpt_path = Path(CKPT_MODEL_PATHS[reward_ckpt]) + policy_ckpt_path = Path(CKPT_MODEL_PATHS[policy_ckpt]) + + ckpt_dir = policy_ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + policy_tmpdir = (tmpdir / "policy").mkdir() + value_tmpdir = (tmpdir / "value").mkdir() + + write_hf_ckpt_config(ckpt_dir) + cmd_1 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{ckpt_dir}' \ + value_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + """.split() + + model_config = llama2_test_config() + model_config = [k.replace("model.", "policy_model.") for k in model_config] + model_config += ["policy_model.intermediate_dim=null"] + + reward_and_value_model_config = llama2_classifier_test_config() + reward_and_value_model_config = [ + k.replace("model.", "reward_and_value_model.") + for k in reward_and_value_model_config + ] + reward_and_value_model_config += [ + "reward_and_value_model.intermediate_dim=null" + ] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = [ + 1.0403, + 0.9495, + 0.9084, + 1.0494, + 0.9609, + 0.8846, + 1.0282, + 0.9390, + 0.8915, + 1.0166, + 0.9231, + 0.9352, + ] + torch.testing.assert_close( + loss_values, expected_loss_values, atol=1e-4, rtol=1e-5 + ) + + @pytest.mark.integration_test + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state correctly saved and restored after training.""" + + reward_ckpt = "llama2_reward_hf" + policy_ckpt = "llama2_hf" + reward_ckpt_path = Path(CKPT_MODEL_PATHS[reward_ckpt]) + policy_ckpt_path = Path(CKPT_MODEL_PATHS[policy_ckpt]) + + ckpt_dir = policy_ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + policy_tmpdir = (tmpdir / "policy").mkdir() + value_tmpdir = (tmpdir / "value").mkdir() + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(policy_tmpdir) + write_hf_ckpt_config(value_tmpdir) + # There are 4 steps in total (num_steps / batch size) + # and the dataset has 8 samples, so each epoch will be 2 batches + # a single step is a single batch update, and we checkpoint at every epoch (2 steps) + # so we're expecting an intermediate checkpoint at step 2. The idea here is to train for 4 steps, + # resume after 2, and ensure the losses for the final two steps after resuming are identical + cmd_1 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{ckpt_dir}' \ + value_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + """.split() + + model_config = llama2_test_config() + model_config = [k.replace("model.", "policy_model.") for k in model_config] + model_config += ["policy_model.intermediate_dim=null"] + + reward_and_value_model_config = llama2_classifier_test_config() + reward_and_value_model_config = [ + k.replace("model.", "reward_and_value_model.") + for k in reward_and_value_model_config + ] + reward_and_value_model_config += [ + "reward_and_value_model.intermediate_dim=null" + ] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + + # Resume training at step 2 + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + cmd_2 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{policy_tmpdir}' \ + checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{value_tmpdir}' \ + value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, "hf_model_0001_0.pt")}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + resume_from_checkpoint=True \ + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={resumed_log_file} \ + """.split() + + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + + # losses at each step are (loss, policy_loss, value_loss) + torch.testing.assert_close( + loss_values[6:], resumed_loss_values, rtol=1e-4, atol=1e-4 + ) + + @pytest.mark.integration_test + def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatch): + """Test whether the recipe state correctly saves and restores optimizer state + when using ``optimizer_in_bwd``, since the optimizer checkpoint dict will include + parameters for two models. + + This is identical to ``test_training_state_on_resume``, but adds optimizer_in_bwd. + """ + + reward_ckpt = "llama2_reward_hf" + policy_ckpt = "llama2_hf" + reward_ckpt_path = Path(CKPT_MODEL_PATHS[reward_ckpt]) + policy_ckpt_path = Path(CKPT_MODEL_PATHS[policy_ckpt]) + + ckpt_dir = policy_ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + policy_tmpdir = (tmpdir / "policy").mkdir() + value_tmpdir = (tmpdir / "value").mkdir() + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(policy_tmpdir) + write_hf_ckpt_config(value_tmpdir) + cmd_1 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{ckpt_dir}' \ + value_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + + optimizer_in_bwd=True + """.split() + + model_config = llama2_test_config() + model_config = [k.replace("model.", "policy_model.") for k in model_config] + model_config += ["policy_model.intermediate_dim=null"] + + reward_and_value_model_config = llama2_classifier_test_config() + reward_and_value_model_config = [ + k.replace("model.", "reward_and_value_model.") + for k in reward_and_value_model_config + ] + reward_and_value_model_config += [ + "reward_and_value_model.intermediate_dim=null" + ] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + + # Resume training at step 2 + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + cmd_2 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{policy_tmpdir}' \ + checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{value_tmpdir}' \ + value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, "hf_model_0001_0.pt")}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + resume_from_checkpoint=True \ + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={resumed_log_file} \ + + optimizer_in_bwd=True + """.split() + + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + + # losses at each step are (loss, policy_loss, value_loss) + torch.testing.assert_close( + loss_values[6:], resumed_loss_values, rtol=1e-4, atol=1e-4 + ) diff --git a/training/tests/recipes/test_qat_distributed.py b/training/tests/recipes/test_qat_distributed.py new file mode 100644 index 0000000..e9357fb --- /dev/null +++ b/training/tests/recipes/test_qat_distributed.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy + +import sys +from pathlib import Path + +import pytest +import torch +from tests.common import TUNE_PATH + +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) +from torchao.utils import TORCH_VERSION_AFTER_2_4 + + +class TestQATDistributedRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=4", + "dtype=fp32", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + "epochs=2", + "max_steps_per_epoch=2", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama2": [10.5164, 10.4830, 10.5138, 10.5199], + "llama3": [12.0672, 11.9067, 11.9304, 11.9351], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama2/7B_qat_full", "llama2", "hf"), + ("llama3/8B_qat_full", "llama3", "tune"), + ], + ) + @gpu_test(gpu_count=2) + @pytest.mark.skipif( + not TORCH_VERSION_AFTER_2_4, reason="QAT only supported for PyTorch 2.4+" + ) + def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + write_hf_ckpt_config(ckpt_dir) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 qat_distributed \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + metric_logger.filename={log_file} \ + """.split() + model_config = MODEL_TEST_CONFIGS[model_type] + cmd = cmd + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values(model_type) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) diff --git a/training/tests/recipes/utils.py b/training/tests/recipes/utils.py new file mode 100644 index 0000000..0e4fb64 --- /dev/null +++ b/training/tests/recipes/utils.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +from pathlib import Path +from typing import List + +import torch +from torch.utils.data import Dataset + +CKPT_COMPONENT_MAP = { + "tune": "torchtune.utils.FullModelTorchTuneCheckpointer", + "meta": "torchtune.utils.FullModelMetaCheckpointer", + "hf": "torchtune.utils.FullModelHFCheckpointer", +} + + +class DummyDataset(Dataset): + def __init__(self, *args, **kwargs): + self._data = torch.LongTensor( + [ + [0, 2, 4, 2, 5, 6, 7, 8, 9, 1, 2, 4, 3, 3, 5, 6, 8, 2, 1, 1], + [1, 2, 5, 6, 7, 8, 2, 3, 1, 9, 9, 9, 5, 6, 7, 0, 0, 0, 1, 2], + [5, 6, 8, 2, 1, 0, 3, 4, 0, 0, 0, 2, 4, 7, 8, 8, 2, 2, 1, 0], + [4, 6, 7, 1, 0, 2, 0, 2, 0, 2, 3, 9, 9, 9, 7, 5, 1, 8, 4, 1], + ] + ) + self._labels = torch.LongTensor( + [ + [2, 6, 7, 8, 2, 2, 1, 0, 0, 1], + [1, 2, 5, 6, 7, 8, 2, 3, 1, 9], + [6, 1, 1, 2, 5, 0, 9, 0, 2, 1], + [5, 8, 6, 0, 2, 0, 0, 3, 2, 1], + ] + ) + + def __getitem__(self, index): + return {"tokens": self._data[index], "labels": self._labels[index]} + + def __len__(self): + return len(self._data) + + +def get_assets_path(): + return Path(__file__).parent.parent / "assets" + + +def dummy_alpaca_dataset_config(): + data_files = os.path.join(get_assets_path(), "alpaca_tiny.json") + out = [ + "dataset._component_=torchtune.datasets.instruct_dataset", + "dataset.source='json'", + f"dataset.data_files={data_files}", + "dataset.template=torchtune.data.AlpacaInstructTemplate", + "dataset.split='train'", + ] + return out + + +def dummy_text_completion_alpaca_dataset_config(): + """ + Constructs a minimal text-completion-style dataset from ``alpaca_tiny.json``. + This is used for testing PPO fine-tuning. + """ + data_files = os.path.join(get_assets_path(), "alpaca_tiny.json") + out = [ + "dataset._component_=torchtune.datasets.text_completion_dataset", + "dataset.source='json'", + f"dataset.data_files={data_files}", + "dataset.column='instruction'", + "dataset.split='train[:10%]'", # 10% of the dataset gets us 8 batches + "dataset.max_seq_len=64", + "dataset.add_eos=False", + ] + return out + + +def llama2_test_config() -> List[str]: + return [ + "model._component_=torchtune.models.llama2.llama2", + "model.vocab_size=32_000", + "model.num_layers=4", + "model.num_heads=16", + "model.embed_dim=256", + "model.max_seq_len=2048", + "model.norm_eps=1e-5", + "model.num_kv_heads=8", + ] + + +def llama2_classifier_test_config() -> List[str]: + return [ + "model._component_=torchtune.models.llama2.llama2_classifier", + "model.num_classes=1", + "model.vocab_size=32_000", + "model.num_layers=4", + "model.num_heads=16", + "model.embed_dim=256", + "model.max_seq_len=2048", + "model.norm_eps=1e-5", + "model.num_kv_heads=8", + ] + + +def llama3_test_config() -> List[str]: + return [ + "model._component_=torchtune.models.llama3.llama3", + "model.vocab_size=128_256", + "model.num_layers=2", + "model.num_heads=8", + "model.embed_dim=64", + "model.max_seq_len=1024", + "model.norm_eps=1e-5", + "model.num_kv_heads=4", + ] + + +def lora_llama2_test_config( + lora_attn_modules, + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> List[str]: + lora_attn_modules_str = "['" + "','".join([x for x in lora_attn_modules]) + "']" + return [ + # Note: we explicitly use _component_ so that we can also call + # config.instantiate directly for easier comparison + "model._component_=torchtune.models.llama2.lora_llama2", + f"model.lora_attn_modules={lora_attn_modules}", + f"model.apply_lora_to_mlp={apply_lora_to_mlp}", + f"model.apply_lora_to_output={apply_lora_to_output}", + "model.vocab_size=32000", + "model.num_layers=4", + "model.num_heads=16", + "model.embed_dim=256", + "model.max_seq_len=2048", + "model.norm_eps=1e-5", + "model.num_kv_heads=8", + f"model.lora_rank={lora_rank}", + f"model.lora_alpha={lora_alpha}", + "model.lora_dropout=0.0", + f"model.quantize_base={quantize_base}", + ] + + +def lora_llama3_test_config( + lora_attn_modules, + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> List[str]: + lora_attn_modules_str = "['" + "','".join([x for x in lora_attn_modules]) + "']" + return [ + # Note: we explicitly use _component_ so that we can also call + # config.instantiate directly for easier comparison + "model._component_=torchtune.models.llama3.lora_llama3", + f"model.lora_attn_modules={lora_attn_modules}", + f"model.apply_lora_to_mlp={apply_lora_to_mlp}", + f"model.apply_lora_to_output={apply_lora_to_output}", + "model.vocab_size=128_256", + "model.num_layers=2", + "model.num_heads=8", + "model.embed_dim=64", + "model.max_seq_len=1024", + "model.norm_eps=1e-5", + "model.num_kv_heads=4", + f"model.lora_rank={lora_rank}", + f"model.lora_alpha={lora_alpha}", + "model.lora_dropout=0.0", + f"model.quantize_base={quantize_base}", + ] + + +def write_hf_ckpt_config(ckpt_dir: str): + config = { + "hidden_size": 256, + "num_attention_heads": 16, + "num_key_value_heads": 8, + } + config_file = Path.joinpath(Path(ckpt_dir), "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + +MODEL_TEST_CONFIGS = { + "llama2": llama2_test_config(), + "llama3": llama3_test_config(), + "llama2_lora": lora_llama2_test_config( + lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"], + apply_lora_to_mlp=False, + apply_lora_to_output=False, + lora_rank=8, + lora_alpha=16, + ), + "llama2_qlora": lora_llama2_test_config( + lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + lora_rank=8, + lora_alpha=16, + quantize_base=True, + ), + "llama3_lora": lora_llama3_test_config( + lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"], + apply_lora_to_mlp=False, + apply_lora_to_output=False, + lora_rank=8, + lora_alpha=16, + ), +} diff --git a/training/tests/regression_tests/test_llama2_7b.py b/training/tests/regression_tests/test_llama2_7b.py new file mode 100644 index 0000000..758897d --- /dev/null +++ b/training/tests/regression_tests/test_llama2_7b.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +import runpy +import sys +from pathlib import Path + +import pytest +import torchtune +from tests.common import TUNE_PATH +from tests.test_utils import CKPT_MODEL_PATHS, gpu_test + + +CKPT = "llama2_7b" + +# TODO: remove this once we have eval configs exposed properly +pkg_path = Path(torchtune.__file__).parent.absolute() +EVAL_CONFIG_PATH = Path.joinpath( + pkg_path, "_cli", "eval_configs", "default_eval_config.yaml" +) + + +@gpu_test(gpu_count=2) +class TestLoRA7BDistributedFinetuneEval: + @pytest.mark.slow_integration_test + def test_finetune_and_eval(self, tmpdir, caplog, monkeypatch): + + ckpt_path = Path(CKPT_MODEL_PATHS[CKPT]) + ckpt_dir = ckpt_path.parent + + # Run on prod LoRA FT config but with only 10 steps for now + ft_cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed + --config llama2/7B_lora \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA2 \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + max_steps_per_epoch=10 \ + """.split() + + monkeypatch.setattr(sys, "argv", ft_cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + eval_cmd = f""" + tune run eleuther_eval \ + --config eleuther_eval \ + output_dir={tmpdir} \ + checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{tmpdir}' \ + checkpointer.checkpoint_files=[torchtune_model_0.pt] \ + checkpointer.output_dir={tmpdir} \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + tasks=['truthfulqa_mc2'] + limit=10 \ + device=cuda \ + """.split() + monkeypatch.setattr(sys, "argv", eval_cmd) + with pytest.raises(SystemExit): + runpy.run_path(TUNE_PATH, run_name="__main__") + + err_log = caplog.messages[-1] + log_search_results = re.search(r"'acc,none': (\d+\.\d+)", err_log) + assert log_search_results is not None + acc_result = float(log_search_results.group(1)) + assert acc_result >= 0.4 diff --git a/training/tests/test_import_recipes.py b/training/tests/test_import_recipes.py new file mode 100644 index 0000000..83ee76d --- /dev/null +++ b/training/tests/test_import_recipes.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + + +def test_import_recipes(): + with pytest.raises( + ModuleNotFoundError, match="The torchtune recipes directory isn't a package" + ): + import recipes # noqa diff --git a/training/tests/test_profiler.py b/training/tests/test_profiler.py new file mode 100644 index 0000000..96423d3 --- /dev/null +++ b/training/tests/test_profiler.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import pytest +import torch +from omegaconf import DictConfig, OmegaConf +from torch._C._profiler import _ExperimentalConfig +from torchtune import config +from torchtune.utils import ( + DEFAULT_PROFILE_DIR, + DEFAULT_PROFILER_ACTIVITIES, + DEFAULT_SCHEDULE, + DEFAULT_TRACE_OPTS, + DummyProfiler, + PROFILER_KEY, +) + +# Disable logging otherwise output will be very verbose +logging.basicConfig(level=logging.ERROR) + +PROFILER_ATTRS = [ + "activities", + "profile_memory", + "with_stack", + "record_shapes", + "with_flops", +] + + +@pytest.fixture +def profiler_cfg(): + return """ +profiler: + enabled: True + cpu: True + cuda: True + profile_memory: False + with_stack: False + record_shapes: True + with_flops: True + wait_steps: 3 + warmup_steps: 1 + active_steps: 1 + num_cycles: 0 +""" + + +# This is a reference implementation of a profiler setup method to be defined within a `recipe`. +# A version of this lives in `torch.utils._profiler` but is not exported as the public API. +# Rather, the user is expected to define their own high-level setup function that parses the `cfg` +# and call a user-facing profiler setup function (e.g. `setup_torch_profiler`). +def _setup_profiler( + cfg_profiler: DictConfig, return_cfg: bool = False +) -> torch.profiler.profile: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (DictConfig): `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`) + return_cfg (bool): Doesn't seem to be used. Default False. + + Returns: + profiler: torch.profiler.profile | DummyProfiler - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + return profiler, profiler_cfg + + +@pytest.fixture +def reference_profiler_basic(): + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=0), + profile_memory=False, + with_stack=False, + record_shapes=True, + with_flops=True, + ) + + +@pytest.fixture +def reference_profiler_full(): + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=0), + profile_memory=True, + with_stack=True, + record_shapes=True, + with_flops=True, + experimental_config=_ExperimentalConfig(verbose=True), + ) + + +def check_profiler_attrs(profiler, ref_profiler): + for attr in PROFILER_ATTRS: + assert getattr(profiler, attr) == getattr(ref_profiler, attr) + + +def check_schedule(schedule, ref_schedule, num_steps=10): + ref_steps = [ref_schedule(i) for i in range(num_steps)] + test_steps = [schedule(i) for i in range(num_steps)] + assert ref_steps == test_steps + + +def test_instantiate_basic(profiler_cfg, reference_profiler_basic): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + profiler, updated_cfg = _setup_profiler(cfg) + + check_profiler_attrs(profiler, reference_profiler_basic) + + ref_schedule = torch.profiler.schedule( + wait=updated_cfg["wait_steps"], + warmup=updated_cfg["warmup_steps"], + active=updated_cfg["active_steps"], + repeat=updated_cfg["num_cycles"], + ) + check_schedule(profiler.schedule, ref_schedule) + + +def test_instantiate_full(profiler_cfg, reference_profiler_full): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Check `setup` automatically overrides `with_stack` and `record_shapes` when profile_memory is True and adds + # experimental_config, which is needed for stack exporting (see comments in `setup_torch_profiler`) + cfg.profile_memory = True + cfg.with_stack = False + cfg.record_shapes = False + profiler, updated_cfg = _setup_profiler(cfg) + + check_profiler_attrs(profiler, reference_profiler_full) + assert profiler.experimental_config is not None + assert updated_cfg.with_stack is True + assert updated_cfg.record_shapes is True + + +def test_schedule_setup(profiler_cfg, reference_profiler_basic): + + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test that after removing schedule, setup method will implement default schedule + _ = [cfg.pop(k) for k in DEFAULT_SCHEDULE.keys()] + profiler, updated_cfg = _setup_profiler(cfg) + test_schedule = profiler.schedule + ref_schedule = torch.profiler.schedule( + wait=DEFAULT_SCHEDULE["wait_steps"], + warmup=DEFAULT_SCHEDULE["warmup_steps"], + active=DEFAULT_SCHEDULE["active_steps"], + repeat=DEFAULT_SCHEDULE["num_cycles"], + ) + check_schedule(ref_schedule, test_schedule) + + # Check cfg is updated correctly + for k in DEFAULT_SCHEDULE.keys(): + assert updated_cfg[k] == DEFAULT_SCHEDULE[k] + + # Test missing key is automatically set to default + for k in DEFAULT_SCHEDULE.keys(): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + cfg.pop(k) + profiler, updated_cfg = _setup_profiler(cfg) + assert updated_cfg[k] == DEFAULT_SCHEDULE[k] + + +def test_default_activities(profiler_cfg): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test setup automatically adds CPU + CUDA tracing if neither CPU nor CUDA is specified + cfg.pop("cpu") + cfg.pop("cuda") + profiler, updated_cfg = _setup_profiler(cfg) + assert profiler.activities == DEFAULT_PROFILER_ACTIVITIES + assert updated_cfg.cpu is True + assert updated_cfg.cuda is True + + +def test_default_output_dir(profiler_cfg): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test cfg output_dir is set correctly + if cfg.get("output_dir", None) is not None: + cfg.pop("output_dir") + _, updated_cfg = _setup_profiler(cfg, return_cfg=True) + assert updated_cfg.output_dir == DEFAULT_PROFILE_DIR + + +def test_default_trace_opts(profiler_cfg): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test missing profiler options are set to defaults + cfg.pop("profile_memory") + cfg.pop("with_stack") + cfg.pop("record_shapes") + cfg.pop("with_flops") + profiler, updated_cfg = _setup_profiler(cfg) + check_profiler_attrs( + profiler, + torch.profiler.profile( + activities=DEFAULT_PROFILER_ACTIVITIES, **DEFAULT_TRACE_OPTS + ), + ) + for k in ["profile_memory", "with_stack", "record_shapes", "with_flops"]: + assert updated_cfg[k] == DEFAULT_TRACE_OPTS[k] + + +def test_dummy_profiler(profiler_cfg): + + # Test missing `profile` key returns fake profiler + cfg = OmegaConf.create(profiler_cfg) + cfg.pop(PROFILER_KEY) + profiler, _ = _setup_profiler(cfg) + assert isinstance(profiler, DummyProfiler) + + # Test that disabled profiler creates fake profiler + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + cfg.enabled = False + profiler, _ = _setup_profiler(cfg) + assert isinstance(profiler, DummyProfiler) + + # Test that fake_profiler.step() does nothing both when used as context manager and as standalone object + with profiler as prof: + prof.step() + + # Additional DummyProfiler no-ops when used as object and not context + assert profiler.step() is None + assert profiler.start() is None + assert profiler.stop() is None diff --git a/training/tests/test_utils.py b/training/tests/test_utils.py new file mode 100644 index 0000000..ddbc88f --- /dev/null +++ b/training/tests/test_utils.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import re +import sys +import unittest +from contextlib import contextmanager +from functools import partial +from io import StringIO +from pathlib import Path +from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union + +import pytest + +import torch +from torch import nn +from torchtune.data import ChatFormat, Message, PromptTemplate, truncate +from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms import Transform + +skip_if_cuda_not_available = unittest.skipIf( + not torch.cuda.is_available(), "CUDA is not available" +) + +CKPT_MODEL_PATHS = { + "llama2_tune": "/tmp/test-artifacts/small-ckpt-tune-03082024.pt", + "llama2_meta": "/tmp/test-artifacts/small-ckpt-meta-03082024.pt", + "llama2_hf": "/tmp/test-artifacts/small-ckpt-hf-03082024.pt", + "llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt", + "llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt", + "llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt", +} + +TOKENIZER_PATHS = { + "llama2": "/tmp/test-artifacts/tokenizer.model", + "llama3": "/tmp/test-artifacts/tokenizer_llama3.model", +} + +# Taken from Open-Orca/SlimOrca-Dedup on Hugging Face: +# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup +CHAT_SAMPLE = { + "system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950 + "user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950 + "assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950 +} + +MESSAGE_SAMPLE_TRAIN_ON_INPUT = [ + Message( + role="system", + content=CHAT_SAMPLE["system"], + ), + Message( + role="user", + content=CHAT_SAMPLE["user"], + ), + Message( + role="assistant", + content=CHAT_SAMPLE["assistant"], + ), +] + +MESSAGE_SAMPLE = [ + Message(role="system", content=CHAT_SAMPLE["system"], masked=True), + Message(role="user", content=CHAT_SAMPLE["user"], masked=True), + Message( + role="assistant", + content=CHAT_SAMPLE["assistant"], + ), +] + + +class DummyTokenizer(ModelTokenizer, Transform): + def __init__(self, max_seq_len: Optional[int] = None): + self.max_seq_len = max_seq_len + + def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]: + words = text.split() + tokens = [len(word) for word in words] + if add_bos: + tokens = [self.bos_id] + tokens + if add_eos: + tokens = tokens + [self.eos_id] + return tokens + + def tokenize_messages( + self, + messages: List[Message], + ) -> Tuple[List[int], List[bool]]: + """ + A simplified version of Llama2Tokenizer's ``tokenize_messages`` for testing purposes. + """ + start_of_turn = True + end_of_turn = False + tokenized_messages = [] + mask = [] + for message in messages: + # If assistant message, this is the end of a turn + end_of_turn = message.role == "assistant" + + # Prepend BOS on start of new turns + if start_of_turn: + tokenized_messages.append(self.bos_id) + mask.append(message.masked) + + # Tokenize current message, append with masks + tokens = [] + for item in message.content: + if item["type"] == "text": + tokens = tokens + self.encode( + item["content"], + add_bos=False, + add_eos=False, + ) + elif item["type"] == "image": + tokens = tokens + [self.image_id] + + tokenized_messages.extend(tokens) + mask.extend([message.masked] * len(tokens)) + + # If assistant message, append EOS at end + if end_of_turn: + tokenized_messages.append(self.eos_id) + mask.append(message.masked) + end_of_turn = False + start_of_turn = True + else: + start_of_turn = False + + # Break out early if we reach max_seq_len + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: + break + + # Finally, truncate if necessary + if self.max_seq_len: + tokenized_messages = truncate( + tokenized_messages, self.max_seq_len, self.eos_id + ) + mask = truncate(mask, self.max_seq_len, message.masked) + + return tokenized_messages, mask + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample + + @property + def eos_id(self): + return -1 + + @property + def bos_id(self): + return 0 + + @property + def image_id(self): + return -2 + + +class DummyChatFormat(ChatFormat): + + B_SYS, E_SYS = "System:\n", "\n" + B_INST, E_INST = "User:\n", "\nAssistant:\n" + B_ASST, E_ASST = "", "" + system = f"{B_SYS}{{content}}{E_SYS}" + user = f"{B_INST}{{content}}{E_INST}" + assistant = f"{B_ASST}{{content}}{E_ASST}" + + @classmethod + def format( + cls, + messages, + ): + formats = {"system": cls.system, "user": cls.user, "assistant": cls.assistant} + formatted_dialogue = [] + for message in messages: + content = formats.get(message.role).format( + content=message.content[0]["content"] + ) + formatted_dialogue.append( + Message(role=message.role, content=content, masked=message.masked), + ) + return formatted_dialogue + + +DummyPromptTemplate = partial( + PromptTemplate, + template={ + "system": ("System:\n", "\n"), + "user": ("User:\n", "\n"), + "assistant": ("Assistant:\n", "\n"), + }, +) + + +def get_assets_path(): + return Path(__file__).parent / "assets" + + +def fixed_init_tensor( + shape: torch.Size, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: torch.dtype = torch.float, +): + """ + Utility for generating deterministic tensors of a given shape. In general stuff + like torch.ones, torch.eye, etc can result in trivial outputs. This utility + generates a range tensor [min_val, max_val) of a specified dtype, applies + a sine function if nonlinear=True, then reshapes to the appropriate shape. + """ + n_elements = math.prod(shape) + step_size = (max_val - min_val) / n_elements + x = torch.arange(min_val, max_val, step_size, dtype=dtype) + x = x.reshape(shape) + if nonlinear: + return torch.sin(x) + return x + + +@torch.no_grad +def fixed_init_model( + model: nn.Module, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: Optional[torch.dtype] = None, +): + """ + This utility initializes all parameters of a model deterministically using the + function fixed_init_tensor above. See that docstring for details of each parameter. + """ + for _, param in model.named_parameters(): + param.copy_( + fixed_init_tensor( + param.shape, + min_val=min_val, + max_val=max_val, + nonlinear=nonlinear, + dtype=param.dtype if dtype is None else dtype, + ) + ) + + +def assert_expected( + actual: Any, + expected: Any, + rtol: float = 1e-5, + atol: float = 1e-8, + check_device: bool = True, +): + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + check_device=check_device, + msg=f"actual: {actual}, expected: {expected}", + ) + + +@contextmanager +def single_box_init(init_pg: bool = True): + env_vars = ["MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK", "RANK", "WORLD_SIZE"] + initial_os = {k: os.environ.get(k, None) for k in env_vars} + os.environ.get("MASTER_ADDR", None) + os.environ["MASTER_ADDR"] = "localhost" + # TODO: Don't hardcode ports as this could cause flakiness if tests execute + # in parallel. + os.environ["MASTER_PORT"] = str(12345) + os.environ["LOCAL_RANK"] = str(0) + os.environ["RANK"] = str(0) + os.environ["WORLD_SIZE"] = str(1) + if init_pg: + torch.distributed.init_process_group( + backend="gloo", + world_size=int(os.environ["WORLD_SIZE"]), + rank=int(os.environ["RANK"]), + ) + try: + yield + finally: + if init_pg: + torch.distributed.destroy_process_group() + for k in env_vars: + if initial_os.get(k) is None: + del os.environ[k] + else: + os.environ[k] = initial_os[k] + + +@contextmanager +def set_dtype(dtype: torch.dtype) -> Generator[None, None, None]: + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(old_dtype) + + +@contextmanager +def captured_output() -> Generator[Tuple[TextIO, TextIO], None, None]: + new_out, new_err = StringIO(), StringIO() + old_out, old_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = new_out, new_err + yield sys.stdout, sys.stderr + finally: + sys.stdout, sys.stderr = old_out, old_err + + +def gpu_test(gpu_count: int = 1): + """ + Annotation for GPU tests, skipping the test if the + required amount of GPU is not available + """ + message = f"Not enough GPUs to run the test: requires {gpu_count}" + local_gpu_count: int = torch.cuda.device_count() + return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message) + + +def get_loss_values_from_metric_logger(log_file_path: str) -> Dict[str, float]: + """ + Given an output directory containing metric logger .txt file, + parse the .txt and return a list of losses from each logged iteration. + """ + with open(log_file_path, "r") as f: + logs = f.read() + losses = [float(x) for x in re.findall(r"loss:(\d+\.\d+)", logs)] + return losses + + +def gen_log_file_name(tmpdir, suffix: Optional[str] = None) -> str: + """ + Take the tmpdir and just append a non-path version of it as the + filename, optionally adding specified suffix. This is used to + write metric logs to a deterministic file per test run. + E.g. /tmp/my/dir -> /tmp/my/dir/tmpmydir.txt + """ + filename = str(tmpdir) + str(tmpdir).replace("/", "") + if suffix: + filename += suffix + filename += ".txt" + return filename + + +def assert_dialogue_equal(actual, expected): + assert len(actual) == len(expected) + for i in range(len(actual)): + assert actual[i].role == expected[i].role + assert actual[i].text_content == expected[i].text_content diff --git a/training/tests/torchtune/__init__.py b/training/tests/torchtune/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/_cli/__init__.py b/training/tests/torchtune/_cli/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/_cli/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/_cli/test_cp.py b/training/tests/torchtune/_cli/test_cp.py new file mode 100644 index 0000000..722a848 --- /dev/null +++ b/training/tests/torchtune/_cli/test_cp.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys +from pathlib import Path + +import pytest +from tests.common import TUNE_PATH + + +class TestTuneCLIWithCopyScript: + """This class tests the `tune cp` command.""" + + @pytest.mark.parametrize("already_exists", (True, False)) + def test_copy_successful(self, capsys, monkeypatch, tmpdir, already_exists): + tmpdir_path = Path(tmpdir) + dest = tmpdir_path / "my_custom_finetune.yaml" + + if already_exists: + dest.touch() + + args = f"tune cp llama2/7B_full {dest}".split() + + monkeypatch.setattr(sys, "argv", args) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + + assert dest.exists(), f"Expected {dest} to exist" + assert f"Copied file to {dest}" in out + + def test_copy_successful_with_cwd_as_path(self, capsys, monkeypatch, tmpdir): + tmpdir_path = Path(tmpdir) + + # Needed so we can run test from tmpdir + tune_path_as_absolute = Path(TUNE_PATH).absolute() + + # Change cwd to tmpdir + monkeypatch.chdir(tmpdir_path) + + args = "tune cp llama2/7B_full .".split() + monkeypatch.setattr(sys, "argv", args) + runpy.run_path(str(tune_path_as_absolute), run_name="__main__") + + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + + dest = tmpdir_path / "7B_full.yaml" + + assert dest.exists() + assert "Copied file to ./7B_full.yaml" in out + + def test_copy_skips_when_dest_already_exists_and_no_clobber_is_true( + self, capsys, monkeypatch, tmpdir + ): + tmpdir_path = Path(tmpdir) + existing_file = tmpdir_path / "existing_file.yaml" + existing_file.touch() + + args = f"tune cp llama2/7B_full_low_memory {existing_file} -n".split() + + monkeypatch.setattr(sys, "argv", args) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + err = captured.err.rstrip("\n") + + assert err == "" + assert "not overwriting" in out + + def test_adds_correct_suffix_to_dest_when_no_suffix_is_provided( + self, capsys, monkeypatch, tmpdir + ): + tmpdir_path = Path(tmpdir) + dest = tmpdir_path / "my_custom_finetune" + + args = f"tune cp llama2/7B_full_low_memory {dest}".split() + + monkeypatch.setattr(sys, "argv", args) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + + assert dest.with_suffix(".yaml").exists(), f"Expected {dest} to exist" + assert f"Copied file to {dest}.yaml" in out + + @pytest.mark.parametrize( + "tune_command,expected_error_message", + [ + ( + "tune cp non_existent_recipe .", + "error: Invalid file name: non_existent_recipe. Try `tune ls` to see all available files to copy.", + ), + ( + "tune cp non_existent_config .", + "error: Invalid file name: non_existent_config. Try `tune ls` to see all available files to copy.", + ), + ( + "tune cp full_finetune_single_device /home/mr_bean/full_finetune_single_device.py", + "error: Cannot create regular file: '/home/mr_bean/full_finetune_single_device.py'. No such file or directory.", + ), + ( + "tune cp", + "error: the following arguments are required: file, destination", + ), + ], + ) + def test_copy_fails_when_given_invalid_recipe( + self, capsys, monkeypatch, tune_command, expected_error_message + ): + args = tune_command.split() + + monkeypatch.setattr(sys, "argv", args) + with pytest.raises(SystemExit): + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + err = captured.err.rstrip("\n") + + assert expected_error_message in err diff --git a/training/tests/torchtune/_cli/test_download.py b/training/tests/torchtune/_cli/test_download.py new file mode 100644 index 0000000..5dbd695 --- /dev/null +++ b/training/tests/torchtune/_cli/test_download.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys + +import pytest +from tests.common import TUNE_PATH + + +class TestTuneDownloadCommand: + """This class tests the `tune download` command.""" + + @pytest.fixture + def snapshot_download(self, mocker, tmpdir): + + from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + + yield mocker.patch( + "torchtune._cli.download.snapshot_download", + return_value=tmpdir, + # Side effects are iterated through on each call + side_effect=[ + GatedRepoError("test"), + RepositoryNotFoundError("test"), + mocker.DEFAULT, + ], + ) + + def test_download_calls_snapshot(self, capsys, monkeypatch, snapshot_download): + model = "meta-llama/Llama-2-7b" + testargs = f"tune download {model}".split() + monkeypatch.setattr(sys, "argv", testargs) + + # Call the first time and get GatedRepoError + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + out_err = capsys.readouterr() + assert ( + "Ignoring files matching the following patterns: *.safetensors" + in out_err.out + ) + assert ( + "It looks like you are trying to access a gated repository." in out_err.err + ) + + # Call the second time and get RepositoryNotFoundError + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + out_err = capsys.readouterr() + assert ( + "Ignoring files matching the following patterns: *.safetensors" + in out_err.out + ) + assert "not found on the Hugging Face Hub" in out_err.err + + # Call the third time and get the expected output + runpy.run_path(TUNE_PATH, run_name="__main__") + output = capsys.readouterr().out + assert "Ignoring files matching the following patterns: *.safetensors" in output + assert "Successfully downloaded model repo" in output + + # Make sure it was called twice + assert snapshot_download.call_count == 3 diff --git a/training/tests/torchtune/_cli/test_ls.py b/training/tests/torchtune/_cli/test_ls.py new file mode 100644 index 0000000..8ab70c9 --- /dev/null +++ b/training/tests/torchtune/_cli/test_ls.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import runpy +import sys + +from tests.common import TUNE_PATH + +from torchtune._recipe_registry import get_all_recipes + + +class TestTuneListCommand: + """This class tests the `tune ls` command.""" + + def test_ls_lists_all_recipes_and_configs(self, capsys, monkeypatch): + testargs = "tune ls".split() + + monkeypatch.setattr(sys, "argv", testargs) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + output = captured.out.rstrip("\n") + + for recipe in get_all_recipes(): + assert recipe.name in output + for config in recipe.configs: + assert config.name in output diff --git a/training/tests/torchtune/_cli/test_run.py b/training/tests/torchtune/_cli/test_run.py new file mode 100644 index 0000000..076ff07 --- /dev/null +++ b/training/tests/torchtune/_cli/test_run.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys + +import pytest + +from tests.common import TUNE_PATH + + +class TestTuneRunCommand: + def test_run_calls_distributed_run_for_distributed_recipe( + self, capsys, monkeypatch, mocker + ): + testargs = "tune run --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full".split() + + monkeypatch.setattr(sys, "argv", testargs) + distributed_run = mocker.patch("torchtune._cli.tune.Run._run_distributed") + runpy.run_path(TUNE_PATH, run_name="__main__") + distributed_run.assert_called_once() + + def test_run_calls_single_device_run_for_single_device_recipe( + self, capsys, monkeypatch, mocker + ): + testargs = "tune run full_finetune_single_device --config llama2/7B_full_single_device".split() + + monkeypatch.setattr(sys, "argv", testargs) + single_device_run = mocker.patch("torchtune._cli.tune.Run._run_single_device") + runpy.run_path(TUNE_PATH, run_name="__main__") + single_device_run.assert_called_once() + + def test_run_fails_when_called_with_distributed_args_for_single_device_recipe( + self, capsys, monkeypatch + ): + testargs = "tune run --nproc_per_node 4 full_finetune_single_device --config llama2/7B_full_single_device".split() + + monkeypatch.setattr(sys, "argv", testargs) + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr() + assert "does not support distributed training" in output.err + + def test_run_fails_when_config_not_passed_in(self, capsys, monkeypatch): + testargs = "tune run full_finetune_single_device batch_size=3".split() + + monkeypatch.setattr(sys, "argv", testargs) + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr() + assert "The '--config' argument is required" in output.err + + def test_run_succeeds_with_local_recipe_file_and_default_config( + self, capsys, monkeypatch, mocker + ): + testargs = "tune run my_custom_recipe.py --config llama2/7B_full".split() + monkeypatch.setattr(sys, "argv", testargs) + local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device") + runpy.run_path(TUNE_PATH, run_name="__main__") + local_file_run.assert_called_once() + + def test_run_calls_local_file_run_for_local_file_recipe( + self, capsys, monkeypatch, mocker + ): + testargs = "tune run my_custom_recipe.py --config custom_config.yaml".split() + + monkeypatch.setattr(sys, "argv", testargs) + local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device") + runpy.run_path(TUNE_PATH, run_name="__main__") + local_file_run.assert_called_once() diff --git a/training/tests/torchtune/_cli/test_tune.py b/training/tests/torchtune/_cli/test_tune.py new file mode 100644 index 0000000..ecee6cb --- /dev/null +++ b/training/tests/torchtune/_cli/test_tune.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys + +from tests.common import TUNE_PATH + + +class TestTuneCLI: + def test_tune_without_args_returns_help(self, capsys, monkeypatch): + testargs = ["tune"] + + monkeypatch.setattr(sys, "argv", testargs) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + output = captured.out.rstrip("\n") + + assert "Welcome to the torchtune CLI!" in output diff --git a/training/tests/torchtune/_cli/test_validate.py b/training/tests/torchtune/_cli/test_validate.py new file mode 100644 index 0000000..c215a2b --- /dev/null +++ b/training/tests/torchtune/_cli/test_validate.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys + +from pathlib import Path + +import pytest +from tests.common import TUNE_PATH + +ASSETS = Path(__file__).parent.parent.parent / "assets" + + +class TestTuneValidateCommand: + """This class tests the `tune validate` command.""" + + VALID_CONFIG_PATH = ASSETS / "valid_dummy_config.yaml" + INVALID_CONFIG_PATH = ASSETS / "invalid_dummy_config.yaml" + + def test_validate_good_config(self, capsys, monkeypatch): + args = f"tune validate {self.VALID_CONFIG_PATH}".split() + + monkeypatch.setattr(sys, "argv", args) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + + assert out == "Config is well-formed!" + + def test_validate_bad_config(self, monkeypatch, capsys): + args = f"tune validate {self.INVALID_CONFIG_PATH}".split() + + monkeypatch.setattr(sys, "argv", args) + with pytest.raises(SystemExit): + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + err = captured.err.rstrip("\n") + + assert "got an unexpected keyword argument 'dummy'" in err diff --git a/training/tests/torchtune/config/test_config_utils.py b/training/tests/torchtune/config/test_config_utils.py new file mode 100644 index 0000000..52139da --- /dev/null +++ b/training/tests/torchtune/config/test_config_utils.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging +from io import StringIO +from unittest import mock + +import pytest +from omegaconf import OmegaConf +from torchtune.config._utils import ( + _get_component_from_path, + _merge_yaml_and_cli_args, + _remove_key_by_dotpath, + InstantiationError, + log_config, +) +from torchtune.utils.argparse import TuneRecipeArgumentParser + +_CONFIG = { + "a": 1, + "b": { + "_component_": 2, + "c": 3, + }, + "d": 4, + "f": 8, +} + + +class TestUtils: + def test_get_component_from_path(self): + good_paths = [ + "torchtune", # Test single module without dot + "torchtune.models", # Test dotpath for a module + "torchtune.models.llama2.llama2_7b", # Test dotpath for an object + ] + for path in good_paths: + _ = _get_component_from_path(path) + + # Test that a relative path fails + with pytest.raises(ValueError, match="Relative imports are not supported"): + _ = _get_component_from_path(".test") + # Test that a non-existent path fails + with pytest.raises( + InstantiationError, match="Error loading 'torchtune.models.dummy'" + ): + _ = _get_component_from_path("torchtune.models.dummy") + + @mock.patch("torchtune.utils.argparse.OmegaConf.load", return_value=_CONFIG) + def test_merge_yaml_and_cli_args(self, mock_load): + parser = TuneRecipeArgumentParser("test parser") + yaml_args, cli_args = parser.parse_known_args( + [ + "--config", + "test.yaml", + "b.c=4", # Test overriding a flat param in a component + "b=5", # Test overriding component path + "b.b.c=6", # Test nested dotpath + "d=6", # Test overriding a flat param + "e=7", # Test adding a new param + "~f", # Test removing a param + ] + ) + conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + assert conf.a == 1, f"a == {conf.a}, not 1 as set in the config." + assert ( + conf.b._component_ == 5 + ), f"b == {conf.b._component_}, not 5 as set in overrides." + assert conf.b.c == 4, f"b.c == {conf.b.c}, not 4 as set in overrides." + assert conf.b.b.c == 6, f"b.b.c == {conf.b.b.c}, not 6 as set in overrides." + assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides." + assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides." + assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides." + mock_load.assert_called_once() + + yaml_args, cli_args = parser.parse_known_args( + [ + "--config", + "test.yaml", + "b=5", # Test overriding component path but keeping other kwargs + ] + ) + conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + assert ( + conf.b._component_ == 5 + ), f"b == {conf.b._component_}, not 5 as set in overrides." + assert conf.b.c == 3, f"b.c == {conf.b.c}, not 3 as set in the config." + assert mock_load.call_count == 2 + + yaml_args, cli_args = parser.parse_known_args( + [ + "--config", + "test.yaml", + "b.c=5", # Test overriding kwarg but keeping component path + ] + ) + conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + assert ( + conf.b._component_ == 2 + ), f"b == {conf.b._component_}, not 2 as set in the config." + assert conf.b.c == 5, f"b.c == {conf.b.c}, not 5 as set in overrides." + assert mock_load.call_count == 3 + + yaml_args, cli_args = parser.parse_known_args( + [ + "--config", + "test.yaml", + "b", # Test invalid override + ] + ) + with pytest.raises( + ValueError, match="Command-line overrides must be in the form of key=value" + ): + _ = _merge_yaml_and_cli_args(yaml_args, cli_args) + + def test_log_config(self, capsys): + cfg = OmegaConf.create({"test": {"a": 1, "b": 2}}) + + # Create a logger and add a StreamHandler to it so we can patch the + # config logger and assert on logged strings + logger = logging.getLogger(__name__) + logger.setLevel("DEBUG") + stream = StringIO() + handler = logging.StreamHandler(stream) + logger.addHandler(handler) + + with mock.patch("torchtune.config._utils.get_logger", return_value=logger): + # Make sure rank 0 logs as expected + with mock.patch( + "torchtune.config._utils.get_world_size_and_rank", + return_value=(None, 0), + ): + log_config("test", cfg) + output = stream.getvalue().strip() + assert ( + "Running test with resolved config:\n\ntest:\n a: 1\n b: 2" + in output + ) + + # Clear the stream + stream.truncate(0) + stream.seek(0) + + # Make sure all other ranks do not log anything + with mock.patch( + "torchtune.config._utils.get_world_size_and_rank", + return_value=(None, 1), + ): + log_config("test", cfg) + output = stream.getvalue().strip() + assert not output + + def test_remove_key_by_dotpath(self): + # Test removing a component raises + cfg = copy.deepcopy(_CONFIG) + with pytest.raises( + ValueError, match="Removing components from CLI is not supported" + ): + _remove_key_by_dotpath(cfg, "b") + + # Test removing a top-level param + cfg = copy.deepcopy(_CONFIG) + _remove_key_by_dotpath(cfg, "a") + assert "a" not in cfg + + # Test removing a component param + cfg = copy.deepcopy(_CONFIG) + _remove_key_by_dotpath(cfg, "b.c") + assert "c" not in cfg["b"] + + # Test removing nested one level too deep fails + cfg = copy.deepcopy(_CONFIG) + with pytest.raises(TypeError, match="'int' object is not subscriptable"): + _remove_key_by_dotpath(cfg, "b.c.d") + + # Test removing non-existent param fails + cfg = copy.deepcopy(_CONFIG) + with pytest.raises(KeyError, match="'g'"): + _remove_key_by_dotpath(cfg, "g") diff --git a/training/tests/torchtune/config/test_instantiate.py b/training/tests/torchtune/config/test_instantiate.py new file mode 100644 index 0000000..d693126 --- /dev/null +++ b/training/tests/torchtune/config/test_instantiate.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from omegaconf import OmegaConf +from torchtune.config._errors import InstantiationError +from torchtune.config._instantiate import ( + _create_component, + _instantiate_node, + instantiate, +) +from torchtune.config._utils import _has_component +from torchtune.modules import RMSNorm + + +class TestInstantiate: + @pytest.fixture + def config(self): + s = """ + a: b + b: c + test: + _component_: torchtune.modules.RMSNorm + dim: 5 + """ + return OmegaConf.create(s) + + @pytest.fixture + def module(self): + return RMSNorm(dim=5, eps=1e-4) + + def get_dim(self, rms_norm: RMSNorm): + return rms_norm.scale.shape[0] + + def test_has_path(self, config): + assert _has_component(config.test) + assert not _has_component(config.a) + + def test_call_object(self, module): + obj = RMSNorm + args = (5,) + kwargs = {"eps": 1e-4} + actual = _create_component(obj, args, kwargs) + expected = module + assert isinstance(actual, RMSNorm) + assert self.get_dim(actual) == self.get_dim(expected) + assert actual.eps == expected.eps + + def test_instantiate_node(self, config, module): + actual = _instantiate_node(config.test) + expected = module + assert isinstance(actual, RMSNorm) + assert self.get_dim(actual) == self.get_dim(expected) + + with pytest.raises( + InstantiationError, match="Cannot instantiate specified object" + ): + _ = _instantiate_node(config.a) + + def test_instantiate(self, config, module): + actual = instantiate(config.test) + expected = module + assert isinstance(actual, RMSNorm) + assert self.get_dim(actual) == self.get_dim(expected) + + # Test passing in kwargs + actual = instantiate(config.test, eps=1e-4) + assert actual.eps == expected.eps + + # Test passing in positional args + del config.test.dim + actual = instantiate(config.test, 3) + assert self.get_dim(actual) == 3 diff --git a/training/tests/torchtune/config/test_parse.py b/training/tests/torchtune/config/test_parse.py new file mode 100644 index 0000000..9ce1436 --- /dev/null +++ b/training/tests/torchtune/config/test_parse.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace +from unittest.mock import patch + +import pytest +from torchtune import config + +_CONFIG = Namespace(a=1, b=2) + + +class TestParse: + def test_parse(self): + a = 1 + b = 3 + + @config.parse + def func(cfg): + assert cfg.a == a + assert cfg.b != b + + with patch( + "torchtune.config._parse.TuneRecipeArgumentParser.parse_known_args", + return_value=(_CONFIG, []), + ) as mock_parse_args: + with pytest.raises(SystemExit): + func() + mock_parse_args.assert_called_once() diff --git a/training/tests/torchtune/config/test_validate.py b/training/tests/torchtune/config/test_validate.py new file mode 100644 index 0000000..dd7e063 --- /dev/null +++ b/training/tests/torchtune/config/test_validate.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from omegaconf import OmegaConf +from torchtune import config +from torchtune.config._errors import ConfigError + +VALID_CONFIG_PATH = "tests/assets/valid_dummy_config.yaml" +INVALID_CONFIG_PATH = "tests/assets/invalid_dummy_config.yaml" + + +class TestValidate: + def test_validate(self): + conf = OmegaConf.load(VALID_CONFIG_PATH) + # Test a valid component + config.validate(conf) + # Test an invalid component + conf = OmegaConf.load(INVALID_CONFIG_PATH) + with pytest.raises(ConfigError) as excinfo: + config.validate(conf) + exc_config = excinfo.value + assert len(exc_config.errors) == 2 + for e in exc_config.errors: + assert isinstance(e, TypeError) + assert str(e) == "get_dtype got an unexpected keyword argument 'dummy'" diff --git a/training/tests/torchtune/data/test_chat_formats.py b/training/tests/torchtune/data/test_chat_formats.py new file mode 100644 index 0000000..70be9fc --- /dev/null +++ b/training/tests/torchtune/data/test_chat_formats.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import pytest +from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE +from torchtune.data import ChatMLFormat, Llama2ChatFormat, Message, MistralChatFormat + + +class TestLlama2ChatFormat: + expected_dialogue = [ + Message( + role="user", + content="[INST] <>\nYou are an AI assistant. User will you give you a task. " + "Your goal is to complete the task as faithfully as you can. While performing " + "the task think step-by-step and justify your steps.\n<>\n\nPlease " + "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old " + "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? " + "How about on an icy road? Well one father in Russia did just that, and recorded " + "the entire thing. To her credit, the child seemed to be doing a great job. " + "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n" + "Summary: [/INST] ", + ), + Message( + role="assistant", + content="A father in Russia allowed his 8-year-old child to drive his car on an " + "icy road and recorded the event. The child appeared to be handling the situation well, " + "showcasing their driving skills despite the challenging conditions.", + ), + ] + + def test_format(self): + actual = Llama2ChatFormat.format(MESSAGE_SAMPLE) + assert_dialogue_equal(actual, self.expected_dialogue) + + +class TestMistralChatFormat: + expected_dialogue = [ + Message( + role="user", + content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old " + "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? " + "How about on an icy road? Well one father in Russia did just that, and recorded " + "the entire thing. To her credit, the child seemed to be doing a great job. " + "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n" + "Summary: [/INST] ", + ), + Message( + role="assistant", + content="A father in Russia allowed his 8-year-old child to drive his car on an " + "icy road and recorded the event. The child appeared to be handling the situation well, " + "showcasing their driving skills despite the challenging conditions.", + ), + ] + + def test_format(self): + no_system_sample = MESSAGE_SAMPLE[1:] + actual = MistralChatFormat.format(no_system_sample) + assert_dialogue_equal(actual, self.expected_dialogue) + + def test_format_with_system_prompt_raises(self): + with pytest.raises( + ValueError, match="System prompts are not supported in MistralChatFormat" + ): + _ = MistralChatFormat.format(MESSAGE_SAMPLE) + + +class TestChatMLFormat: + expected_dialogue = [ + Message( + role="system", + content="<|im_start|>system\nYou are an AI assistant. User will you give you a task. " + "Your goal is to complete the task as faithfully as you can. While performing " + "the task think step-by-step and justify your steps.<|im_end|>\n", + ), + Message( + role="user", + content="<|im_start|>user\nPlease " + "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old " + "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? " + "How about on an icy road? Well one father in Russia did just that, and recorded " + "the entire thing. To her credit, the child seemed to be doing a great job. " + "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n" + "Summary:<|im_end|>\n", + ), + Message( + role="assistant", + content="<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an " + "icy road and recorded the event. The child appeared to be handling the situation well, " + "showcasing their driving skills despite the challenging conditions.<|im_end|>", + ), + ] + + def test_format(self): + actual = ChatMLFormat.format(MESSAGE_SAMPLE) + assert_dialogue_equal(actual, self.expected_dialogue) diff --git a/training/tests/torchtune/data/test_converters.py b/training/tests/torchtune/data/test_converters.py new file mode 100644 index 0000000..8c02656 --- /dev/null +++ b/training/tests/torchtune/data/test_converters.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from tests.test_utils import ( + assert_dialogue_equal, + CHAT_SAMPLE, + MESSAGE_SAMPLE, + MESSAGE_SAMPLE_TRAIN_ON_INPUT, +) +from torchtune.data import get_openai_messages, get_sharegpt_messages + + +class TestShareGPTToLlama2Messages: + samples = { + "conversations": [ + { + "from": "system", + "value": CHAT_SAMPLE["system"], + }, + { + "from": "human", + "value": CHAT_SAMPLE["user"], + }, + { + "from": "gpt", + "value": CHAT_SAMPLE["assistant"], + }, + ] + } + + def test_conversion(self): + converted_messages = get_sharegpt_messages(self.samples) + assert_dialogue_equal(converted_messages, MESSAGE_SAMPLE) + + def test_conversion_train_on_input(self): + converted_messages = get_sharegpt_messages(self.samples, train_on_input=True) + assert_dialogue_equal(converted_messages, MESSAGE_SAMPLE_TRAIN_ON_INPUT) + + +class TestOpenAIToLlama2Messages: + samples_1 = { + "id": "DUMMY", + "conversations": [ + { + "role": "system", + "content": CHAT_SAMPLE["system"], + }, + { + "role": "user", + "content": CHAT_SAMPLE["user"], + }, + { + "role": "assistant", + "content": CHAT_SAMPLE["assistant"], + }, + ], + } + + samples_2 = { + "id": "DUMMY", + "messages": [ + { + "role": "system", + "content": CHAT_SAMPLE["system"], + }, + { + "role": "user", + "content": CHAT_SAMPLE["user"], + }, + { + "role": "assistant", + "content": CHAT_SAMPLE["assistant"], + }, + ], + } + + def test_conversion_conversations_key(self): + converted_messages_1 = get_openai_messages(self.samples_1) + assert_dialogue_equal(converted_messages_1, MESSAGE_SAMPLE) + + def test_conversion_messages_key(self): + converted_messages_2 = get_openai_messages(self.samples_2) + assert_dialogue_equal(converted_messages_2, MESSAGE_SAMPLE) + + def test_conversion_conversations_key_train_on_input(self): + converted_messages_1 = get_openai_messages(self.samples_1, train_on_input=True) + assert_dialogue_equal(converted_messages_1, MESSAGE_SAMPLE_TRAIN_ON_INPUT) + + def test_conversion_messages_key_train_on_input(self): + converted_messages_2 = get_openai_messages(self.samples_2, train_on_input=True) + assert_dialogue_equal(converted_messages_2, MESSAGE_SAMPLE_TRAIN_ON_INPUT) diff --git a/training/tests/torchtune/data/test_data_utils.py b/training/tests/torchtune/data/test_data_utils.py new file mode 100644 index 0000000..1eb34b2 --- /dev/null +++ b/training/tests/torchtune/data/test_data_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from torchtune.data import Message, truncate, validate_messages + + +def test_truncate(): + tokens = [1, 2, 3, 4, -1] + + # Test no truncation + truncated_tokens = truncate( + tokens=tokens, + max_seq_len=5, + eos_id=-1, + ) + assert truncated_tokens == tokens + + masks = [True, True, False, True, False] + # Test truncated mask + truncated_masks = truncate(tokens=masks, max_seq_len=4, eos_id=False) + assert truncated_masks == [True, True, False, False] + + +def test_validate_messages(): + messages = [ + Message(role="system", content="hello"), + Message(role="user", content="hello"), + Message(role="assistant", content="world"), + ] + + # Test valid conversation with system + validate_messages(messages) + + # Test valid conversation without system + validate_messages(messages[1:]) + + # Test system not first + messages = [ + Message(role="user", content="hello"), + Message(role="system", content="hello"), + Message(role="assistant", content="world"), + ] + with pytest.raises( + ValueError, + match="System message at index 1 in messages, but system messages must come first", + ): + validate_messages(messages) + + # Test empty assistant message + messages = [ + Message(role="system", content="hello"), + Message(role="user", content="world"), + Message(role="assistant", content=""), + ] + validate_messages(messages) + + # Test single message + messages = [ + Message(role="user", content="hello"), + ] + with pytest.raises( + ValueError, match="Messages must be at least length 2, but got 1 messages" + ): + validate_messages(messages) + + # Test repeated user message + messages = [ + Message(role="user", content="hello"), + Message(role="user", content="world"), + Message(role="assistant", content="world"), + ] + with pytest.raises( + ValueError, match="Two consecutive user messages at index 1 and 0 in messages" + ): + validate_messages(messages) + + # Test assistant message comes first + messages = [ + Message(role="assistant", content="hello"), + Message(role="user", content="world"), + ] + with pytest.raises( + ValueError, + match="Assistant message before expected user message at index 0 in messages", + ): + validate_messages(messages) diff --git a/training/tests/torchtune/data/test_instruct_templates.py b/training/tests/torchtune/data/test_instruct_templates.py new file mode 100644 index 0000000..f30a70a --- /dev/null +++ b/training/tests/torchtune/data/test_instruct_templates.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.data import AlpacaInstructTemplate + + +class TestAlpacaInstructTemplate: + samples = [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + }, + { + "instruction": "Evaluate this sentence for spelling and grammar mistakes", + "input": "He finnished his meal and left the resturant", + "output": "He finished his meal and left the restaurant.", + }, + ] + expected_prompts = [ + ( + "Below is an instruction that describes a task. Write a response that appropriately " + "completes the request.\n\n" + "### Instruction:\nGive three tips for staying healthy.\n\n" + "### Response:\n" + ), + ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n" + "### Input:\nHe finnished his meal and left the resturant\n\n" + "### Response:\n" + ), + ] + + template = AlpacaInstructTemplate() + + def test_format(self): + for sample, expected_prompt in zip(self.samples, self.expected_prompts): + actual = self.template.format(sample) + assert actual == expected_prompt + + def test_format_with_column_map(self): + column_map = {"instruction": "not_an_instruction", "input": "not_an_input"} + for sample, expected_prompt in zip(self.samples, self.expected_prompts): + modified_sample = sample.copy() + modified_sample["not_an_instruction"], modified_sample["not_an_input"] = ( + modified_sample["instruction"], + modified_sample["input"], + ) + del modified_sample["instruction"], modified_sample["input"] + + actual = self.template.format(modified_sample, column_map=column_map) + + assert actual == expected_prompt diff --git a/training/tests/torchtune/data/test_messages.py b/training/tests/torchtune/data/test_messages.py new file mode 100644 index 0000000..f1df608 --- /dev/null +++ b/training/tests/torchtune/data/test_messages.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import ( + assert_dialogue_equal, + CHAT_SAMPLE, + MESSAGE_SAMPLE, + MESSAGE_SAMPLE_TRAIN_ON_INPUT, +) +from torchtune.data._messages import ( + InputOutputToMessages, + JSONToMessages, + Message, + ShareGPTToMessages, +) + + +class TestMessage: + @pytest.fixture + def text_message(self): + return Message(role="user", content="hello world") + + @pytest.fixture + def image_message(self): + return Message( + role="user", + content=[ + {"type": "text", "content": "hello"}, + {"type": "image"}, + {"type": "text", "content": " world"}, + ], + ) + + def test_message_validation(self, text_message): + message = text_message + assert message.role == "user" + assert message.content == [{"type": "text", "content": "hello world"}] + + with pytest.raises( + ValueError, + match="Only assistant messages can be tool calls. Found role user in message: hello world", + ): + message = Message(role="user", content="hello world", ipython=True) + + with pytest.raises( + ValueError, + match="Media tokens in tool calls are not supported. Both are set in message: ", + ): + message = Message( + role="user", + content=[{"type": "image"}], + ipython=True, + ) + + def test_from_dict(self): + message = Message.from_dict({"role": "user", "content": "hello world"}) + assert message.role == "user" + assert message.content[0]["content"] == "hello world" + assert not message.masked + assert not message.ipython + assert message.eot + + def test_contains_media(self, text_message, image_message): + assert not text_message.contains_media + assert image_message.contains_media + + def test_text_content(self, text_message, image_message): + assert text_message.text_content == "hello world" + assert image_message.text_content == "hello world" + + +class TestInputOutputToMessages: + @pytest.fixture + def sample(self): + return { + "maybe_input": "hello world", + "maybe_output": "hello world", + } + + def test_call(self, sample): + transform = InputOutputToMessages( + column_map={"input": "maybe_input", "output": "maybe_output"} + ) + actual = transform(sample) + expected = [ + Message(role="user", content="hello world", masked=True, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["messages"], expected) + + def test_call_train_on_input(self, sample): + transform = InputOutputToMessages( + column_map={"input": "maybe_input", "output": "maybe_output"}, + train_on_input=True, + ) + actual = transform(sample) + expected = [ + Message(role="user", content="hello world", masked=False, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["messages"], expected) + + +class TestShareGPTToMessages: + samples = { + "conversations": [ + { + "from": "system", + "value": CHAT_SAMPLE["system"], + }, + { + "from": "human", + "value": CHAT_SAMPLE["user"], + }, + { + "from": "gpt", + "value": CHAT_SAMPLE["assistant"], + }, + ] + } + + def test_call(self): + transform = ShareGPTToMessages() + converted_messages = transform(self.samples) + assert_dialogue_equal(converted_messages["messages"], MESSAGE_SAMPLE) + + def test_call_train_on_input(self): + transform = ShareGPTToMessages(train_on_input=True) + converted_messages = transform(self.samples) + assert_dialogue_equal( + converted_messages["messages"], MESSAGE_SAMPLE_TRAIN_ON_INPUT + ) + + +class TestJSONToMessages: + samples = { + "messages": [ + { + "role": "system", + "content": CHAT_SAMPLE["system"], + }, + { + "role": "user", + "content": CHAT_SAMPLE["user"], + }, + { + "role": "assistant", + "content": CHAT_SAMPLE["assistant"], + }, + ], + } + + def test_call(self): + transform = JSONToMessages() + converted_messages = transform(self.samples) + assert_dialogue_equal(converted_messages["messages"], MESSAGE_SAMPLE) + + def test_call_train_on_input(self): + transform = JSONToMessages(train_on_input=True) + converted_messages = transform(self.samples) + assert_dialogue_equal( + converted_messages["messages"], MESSAGE_SAMPLE_TRAIN_ON_INPUT + ) diff --git a/training/tests/torchtune/data/test_prompt_templates.py b/training/tests/torchtune/data/test_prompt_templates.py new file mode 100644 index 0000000..8904c8f --- /dev/null +++ b/training/tests/torchtune/data/test_prompt_templates.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE +from torchtune.data import ( + ChatMLTemplate, + GrammarErrorCorrectionTemplate, + Message, + SummarizeTemplate, +) + + +class TestChatMLTemplate: + expected_dialogue = [ + Message( + role="system", + content="<|im_start|>system\nYou are an AI assistant. User will you give you a task. " + "Your goal is to complete the task as faithfully as you can. While performing " + "the task think step-by-step and justify your steps.<|im_end|>\n", + ), + Message( + role="user", + content="<|im_start|>user\nPlease " + "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old " + "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? " + "How about on an icy road? Well one father in Russia did just that, and recorded " + "the entire thing. To her credit, the child seemed to be doing a great job. " + "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n" + "Summary:<|im_end|>\n", + ), + Message( + role="assistant", + content="<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an " + "icy road and recorded the event. The child appeared to be handling the situation well, " + "showcasing their driving skills despite the challenging conditions.<|im_end|>", + ), + ] + + def test_format(self): + actual = ChatMLTemplate()(MESSAGE_SAMPLE) + assert_dialogue_equal(actual, self.expected_dialogue) + + +class TestGrammarErrorCorrectionTemplate: + samples = [ + { + "messages": [ + Message( + role="user", + content="Bitcoin is for $7,094 this morning, which CoinDesk says.", + ), + Message( + role="assistant", + content="Bitcoin goes for $7,094 this morning, according to CoinDesk.", + ), + ] + }, + { + "messages": [ + Message( + role="user", + content="Much many brands and sellers still in the market.", + ), + Message( + role="assistant", + content="Many brands and sellers still in the market.", + ), + ], + }, + ] + expected_prompts = [ + [ + Message( + role="user", + content="Correct this to standard English: Bitcoin is for $7,094 this morning, which CoinDesk says.\n" + "---\n" + "Corrected: ", + ), + Message( + role="assistant", + content="Bitcoin goes for $7,094 this morning, according to CoinDesk.", + ), + ], + [ + Message( + role="user", + content="Correct this to standard English: Much many brands and sellers still in the market.\n" + "---\n" + "Corrected: ", + ), + Message( + role="assistant", + content="Many brands and sellers still in the market.", + ), + ], + ] + + template = GrammarErrorCorrectionTemplate() + + def test_call(self): + for sample, expected_prompt in zip(self.samples, self.expected_prompts): + actual = self.template(sample["messages"]) + assert_dialogue_equal(actual, expected_prompt) + + +class TestSummarizeTemplate: + samples = [ + { + "messages": [ + Message( + role="user", + content="Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", + ), + Message( + role="assistant", + content="Amanda baked cookies and will bring Jerry some tomorrow.", + ), + ], + }, + { + "messages": [ + Message( + role="user", + content="Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great", # noqa: B950 + ), + Message( + role="assistant", + content="Olivia and Olivier are voting for liberals in this election.", + ), + ], + }, + ] + expected_prompts = [ + [ + Message( + role="user", + content="Summarize this dialogue:\n" + "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)\n" + "---\n" + "Summary:\n", + ), + Message( + role="assistant", + content="Amanda baked cookies and will bring Jerry some tomorrow.", + ), + ], + [ + Message( + role="user", + content="Summarize this dialogue:\n" + "Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great\n" + "---\n" + "Summary:\n", + ), + Message( + role="assistant", + content="Olivia and Olivier are voting for liberals in this election.", + ), + ], + ] + + template = SummarizeTemplate() + + def test_call(self): + for sample, expected_prompt in zip(self.samples, self.expected_prompts): + actual = self.template(sample["messages"]) + assert_dialogue_equal(actual, expected_prompt) diff --git a/training/tests/torchtune/datasets/__init__.py b/training/tests/torchtune/datasets/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/datasets/test_alpaca_dataset.py b/training/tests/torchtune/datasets/test_alpaca_dataset.py new file mode 100644 index 0000000..c062a76 --- /dev/null +++ b/training/tests/torchtune/datasets/test_alpaca_dataset.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from tests.test_utils import DummyTokenizer +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import alpaca_cleaned_dataset, alpaca_dataset + + +class TestAlpacaDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + @patch("torchtune.datasets._instruct.load_dataset") + def test_label_no_masking(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is not masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + } + ] + ) + + alpaca_ds = alpaca_dataset(tokenizer=tokenizer) + input, labels = alpaca_ds[0]["tokens"], alpaca_ds[0]["labels"] + + assert len(input) == len(labels) + assert labels[-1] == tokenizer.eos_id + assert input[0] == tokenizer.bos_id + assert CROSS_ENTROPY_IGNORE_IDX not in labels + + @patch("torchtune.datasets._instruct.load_dataset") + def test_label_masking(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + } + ] + ) + + alpaca_ds = alpaca_dataset(tokenizer=tokenizer, train_on_input=False) + + # Extract the prompt and tokenize it; we'll need this to test whether we're masking the + # input correctly + sample = alpaca_ds._data[0] + prompt = alpaca_ds.template.format(sample=sample) + encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False) + + # Generate the input and labels + input, labels = alpaca_ds[0]["tokens"], alpaca_ds[0]["labels"] + + assert len(input) == len(labels) + assert labels[-1] == tokenizer.eos_id + assert input[0] == tokenizer.bos_id + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt) + + @patch("torchtune.datasets._instruct.load_dataset") + def test_alpaca_clean(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is not masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + } + ] + ) + + alpaca_ds = alpaca_cleaned_dataset(tokenizer=tokenizer) + input, labels = alpaca_ds[0]["tokens"], alpaca_ds[0]["labels"] + + assert len(input) == len(labels) + assert labels[-1] == tokenizer.eos_id + assert input[0] == tokenizer.bos_id + assert CROSS_ENTROPY_IGNORE_IDX not in labels diff --git a/training/tests/torchtune/datasets/test_chat_dataset.py b/training/tests/torchtune/datasets/test_chat_dataset.py new file mode 100644 index 0000000..a4a9278 --- /dev/null +++ b/training/tests/torchtune/datasets/test_chat_dataset.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from unittest import mock + +import pytest +from tests.test_utils import DummyChatFormat, DummyTokenizer +from torchtune.data import Message +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import ChatDataset + + +class TestChatDataset: + @pytest.fixture + def chat_format(self): + return DummyChatFormat + + @pytest.fixture + def dialogue(self): + return [ + { + "dialogue": [ + Message.from_dict( + { + "role": "system", + "content": "You are an AI assistant.", + "masked": True, + } + ), + Message.from_dict( + { + "role": "user", + "content": "What is the meaning of life?", + "masked": True, + } + ), + Message.from_dict( + { + "role": "assistant", + "content": "The meaning of life is 42.", + "masked": False, + } + ), + Message.from_dict( + { + "role": "user", + "content": "That's ridiculous.", + "masked": True, + } + ), + Message.from_dict( + {"role": "assistant", "content": "I agree.", "masked": False} + ), + ], + }, + ] + + @mock.patch("torchtune.datasets._chat.load_dataset") + def test_get_item(self, mock_load_dataset, chat_format, dialogue): + mock_load_dataset.return_value = dialogue + expected_tokenized_prompts = [ + [ + 0, + 7, + 3, + 3, + 2, + 2, + 10, + 5, + 4, + 2, + 3, + 7, + 2, + 5, + 10, + 3, + 7, + 2, + 4, + 2, + 3, + -1, + 0, + 5, + 6, + 11, + 10, + 1, + 6, + -1, + ] + ] + prompt_lengths = (15, 5) + expected_labels = [ + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + + [3, 7, 2, 4, 2, 3, -1] + + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + + [1, 6, -1] + ] + ds = ChatDataset( + tokenizer=DummyTokenizer(), + source="iam/agoofy/goober", + convert_to_messages=lambda x, y: x["dialogue"], + chat_format=chat_format, + max_seq_len=100, + train_on_input=False, + ) + assert len(ds) == 1 + mock_load_dataset.assert_called_once() + + prompt, label = ds[0]["tokens"], ds[0]["labels"] + assert prompt == expected_tokenized_prompts[0] + assert label == expected_labels[0] diff --git a/training/tests/torchtune/datasets/test_cnn_dailymail_dataset.py b/training/tests/torchtune/datasets/test_cnn_dailymail_dataset.py new file mode 100644 index 0000000..c6608ce --- /dev/null +++ b/training/tests/torchtune/datasets/test_cnn_dailymail_dataset.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from unittest.mock import patch + +import pytest + +from tests.test_utils import DummyTokenizer + +from torchtune.datasets import cnn_dailymail_articles_dataset + + +class TestCNNDailyMailArticlesDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + @patch("torchtune.datasets._text_completion.load_dataset") + @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096]) + def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len): + # Sample data from CNN / DailyMail dataset + load_dataset.return_value = [ + { + "article": "(CNN) -- An American woman died aboard a cruise ship " + "that docked at Rio de Janeiro on Tuesday, the same ship on which " + "86 passengers previously fell ill, according to the state-run " + "Brazilian news agency, Agencia Brasil. The American tourist died " + "aboard the MS Veendam, owned by cruise operator Holland America. " + "Federal Police told Agencia Brasil that forensic doctors were " + "investigating her death. The ship's doctors told police that the " + "woman was elderly and suffered from diabetes and hypertension, " + "according the agency. The other passengers came down with diarrhea " + "prior to her death during an earlier part of the trip, the ship's " + "doctors said. The Veendam left New York 36 days ago for a South " + "America tour.", + } + ] + ds = cnn_dailymail_articles_dataset( + tokenizer=tokenizer, + max_seq_len=max_seq_len, + ) + input, label = ds[0]["tokens"], ds[0]["labels"] + assert len(input) <= max_seq_len + assert len(label) <= max_seq_len + assert len(input) == len(label) + assert input[0] == tokenizer.bos_id diff --git a/training/tests/torchtune/datasets/test_concat_dataset.py b/training/tests/torchtune/datasets/test_concat_dataset.py new file mode 100644 index 0000000..32ecc3b --- /dev/null +++ b/training/tests/torchtune/datasets/test_concat_dataset.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from datasets import Dataset +from torchtune.datasets._concat import ConcatDataset + + +class TestConcatDataset: + @pytest.fixture + def datasets(self): + ds1 = Dataset.from_list([{"data": f"ds1_{i}"} for i in range(4)]) + ds2 = Dataset.from_list([{"data": f"ds2_{i}"} for i in range(8)]) + ds3 = Dataset.from_list([{"data": f"ds3_{i}"} for i in range(15)]) + ds4 = Dataset.from_list([{"data": f"ds4_{i}"} for i in range(16)]) + ds5 = Dataset.from_list([{"data": f"ds5_{i}"} for i in range(23)]) + ds6 = Dataset.from_list([{"data": f"ds6_{i}"} for i in range(42)]) + return [ds1, ds2, ds3, ds4, ds5, ds6] + + def test_length(self, datasets): + """Test the correct computation of total length""" + multi_dataset = ConcatDataset(datasets) + + # sum of individual datasets lengths + expected_length = 4 + 8 + 15 + 16 + 23 + 42 # 108 + assert len(multi_dataset) == expected_length + + def test_getitem(self, datasets): + """Test item retrieval across dataset boundaries""" + multi_dataset = ConcatDataset(datasets) + + # Testing indices across different datasets + assert multi_dataset[-1] is None # Index out of range + assert multi_dataset[0] == {"data": "ds1_0"} + assert multi_dataset[3] == {"data": "ds1_3"} + assert multi_dataset[4] == {"data": "ds2_0"} + assert multi_dataset[10] == {"data": "ds2_6"} + assert multi_dataset[20] == {"data": "ds3_8"} + assert multi_dataset[35] == {"data": "ds4_8"} + assert multi_dataset[50] == {"data": "ds5_7"} + assert multi_dataset[70] == {"data": "ds6_4"} + assert multi_dataset[90] == {"data": "ds6_24"} + assert multi_dataset[108] is None # Index out of range + + def test_invalid_index_type(self, datasets): + """Test handling of invalid index types""" + multi_dataset = ConcatDataset(datasets) + + with pytest.raises(TypeError): + multi_dataset["invalid_type"] # Non-integer index diff --git a/training/tests/torchtune/datasets/test_finetune_dataset.py b/training/tests/torchtune/datasets/test_finetune_dataset.py new file mode 100644 index 0000000..56bd72d --- /dev/null +++ b/training/tests/torchtune/datasets/test_finetune_dataset.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Mapping +from unittest import mock + +import pytest +from tests.test_utils import DummyPromptTemplate, DummyTokenizer +from torchtune.data import Message +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets._sft import SFTDataset +from torchtune.modules.transforms import Transform + + +class ToDummyMessages(Transform): + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + dialogue = sample["dialogue"] + messages = [Message.from_dict(d) for d in dialogue] + return {"messages": messages} + + +class TestSFTDataset: + @pytest.fixture + def dialogue(self): + return [ + { + "dialogue": [ + { + "role": "system", + "content": "You are an AI assistant.", + "masked": True, + }, + { + "role": "user", + "content": "What is the meaning of life?", + "masked": True, + }, + { + "role": "assistant", + "content": "The meaning of life is 42.", + "masked": False, + }, + { + "role": "user", + "content": "That's ridiculous.", + "masked": True, + }, + {"role": "assistant", "content": "I agree.", "masked": False}, + ], + }, + ] + + @mock.patch("torchtune.datasets._sft.load_dataset") + def test_get_item(self, mock_load_dataset, dialogue): + mock_load_dataset.return_value = dialogue + expected_tokenized_prompts = [ + [ + 0, + 7, + 3, + 3, + 2, + 2, + 10, + 5, + 4, + 2, + 3, + 7, + 2, + 5, + 10, + 3, + 7, + 2, + 4, + 2, + 3, + -1, + 0, + 5, + 6, + 11, + 10, + 1, + 6, + -1, + ] + ] + prompt_lengths = (14, 4) + expected_labels = [ + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + + [10, 3, 7, 2, 4, 2, 3, -1] + + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + + [10, 1, 6, -1] + ] + ds = SFTDataset( + source="iam/agoofy/goober", + message_transform=ToDummyMessages(), + model_transform=DummyTokenizer(), + prompt_template=DummyPromptTemplate(), + ) + assert len(ds) == 1 + mock_load_dataset.assert_called_once() + prompt, label = ds[0]["tokens"], ds[0]["labels"] + assert prompt == expected_tokenized_prompts[0] + assert label == expected_labels[0] diff --git a/training/tests/torchtune/datasets/test_grammar_dataset.py b/training/tests/torchtune/datasets/test_grammar_dataset.py new file mode 100644 index 0000000..63d6c6b --- /dev/null +++ b/training/tests/torchtune/datasets/test_grammar_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from tests.test_utils import DummyTokenizer +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX + +from torchtune.datasets import grammar_dataset + + +class TestGrammarDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + @patch("torchtune.datasets._sft.load_dataset") + def test_label_no_masking(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is not masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", + "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", + } + ] + ) + + grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True) + input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"] + + assert input == [ + 0, + 7, + 4, + 2, + 8, + 8, + 7, + 2, + 3, + 6, + 4, + 8, + 5, + 8, + 5, + 3, + 10, + 7, + 4, + 3, + 6, + 4, + 8, + 9, + 2, + 9, + -1, + ] + assert labels == input + + @patch("torchtune.datasets._sft.load_dataset") + def test_label_masking(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", + "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", + } + ] + ) + + grammar_ds = grammar_dataset(model_transform=tokenizer) + + # Generate the input and labels + input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"] + + assert input == [ + 0, + 7, + 4, + 2, + 8, + 8, + 7, + 2, + 3, + 6, + 4, + 8, + 5, + 8, + 5, + 3, + 10, + 7, + 4, + 3, + 6, + 4, + 8, + 9, + 2, + 9, + -1, + ] + # Check that the input is masked + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 17 diff --git a/training/tests/torchtune/datasets/test_instruct_dataset.py b/training/tests/torchtune/datasets/test_instruct_dataset.py new file mode 100644 index 0000000..c15c4d4 --- /dev/null +++ b/training/tests/torchtune/datasets/test_instruct_dataset.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from unittest import mock + +from datasets import Dataset +from tests.test_utils import DummyTokenizer +from torchtune.data import InstructTemplate +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import InstructDataset + + +def dummy_transform(sample): + sample["input"] = sample["input"] + " asdfghjkl; " + sample["instruction"] = sample["instruction"] + " asdfghjkl; " + return sample + + +class DummyTemplate(InstructTemplate): + template = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse: " + + @classmethod + def format(cls, sample, column_map): + return cls.template.format(**sample) + + +class TestInstructDataset: + template = DummyTemplate + expected_tokenized_prompts = [ + [ + 0, + 12, + 4, + 2, + 3, + 2, + 12, + 10, + 6, + 4, + 2, + 3, + 2, + 6, + 10, + 9, + 1, + 5, + 4, + 4, + 3, + 6, + 2, + 4, + -1, + ], + [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1], + [ + 0, + 12, + 4, + 2, + 3, + 2, + 12, + 10, + 6, + 4, + 2, + 3, + 2, + 6, + 10, + 9, + 1, + 5, + 4, + 4, + 3, + 6, + 2, + 4, + -1, + ], + [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1], + ] + + def get_samples(self): + return [ + { + "instruction": "This is not an instruction.", + "input": "This is not an input.", + "output": "I never know what I'm doing, do you?", + }, + { + "instruction": "This is an instruction.", + "input": "This is an input.", + "output": "I always know what I'm doing, do you?", + }, + ] + + @mock.patch("torchtune.datasets._instruct.load_dataset") + def test_get_item_no_train_on_input(self, mock_load_dataset): + mock_load_dataset.return_value = Dataset.from_list(self.get_samples()) + prompt_lengths = (16, 14) + expected_labels = [ + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + + [1, 5, 4, 4, 3, 6, 2, 4, -1], + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + + [1, 6, 4, 4, 3, 6, 2, 4, -1], + ] + + dataset = InstructDataset( + tokenizer=DummyTokenizer(), + source="iam/agoofy/goober", + template=self.template, + transform=dummy_transform, + train_on_input=False, + ) + assert len(dataset) == 2 + mock_load_dataset.assert_called_once() + + for i in range(len(dataset)): + prompt, label = dataset[i]["tokens"], dataset[i]["labels"] + assert prompt == self.expected_tokenized_prompts[i] + assert label == expected_labels[i] + + @mock.patch("torchtune.datasets._instruct.load_dataset") + def test_get_item_train_on_input(self, mock_load_dataset): + mock_load_dataset.return_value = Dataset.from_list(self.get_samples()) + expected_labels = self.expected_tokenized_prompts + + dataset = InstructDataset( + tokenizer=DummyTokenizer(), + source="iam/agoofy/goober", + template=self.template, + transform=dummy_transform, + train_on_input=True, + ) + assert len(dataset) == 2 + mock_load_dataset.assert_called_once() + + for i in range(len(dataset)): + prompt, label = dataset[i]["tokens"], dataset[i]["labels"] + assert prompt == self.expected_tokenized_prompts[i] + assert label == expected_labels[i] diff --git a/training/tests/torchtune/datasets/test_packed_dataset.py b/training/tests/torchtune/datasets/test_packed_dataset.py new file mode 100644 index 0000000..8afb532 --- /dev/null +++ b/training/tests/torchtune/datasets/test_packed_dataset.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools + +import pytest +import torch +from tests.test_utils import DummyTokenizer +from torch.utils.data import Dataset + +from torchtune.datasets import PackedDataset + + +class DummyDataset(Dataset): + def __init__(self, sample_size): + self.sample_size = sample_size + + def __getitem__(self, index): + if index >= 1000: + raise IndexError() + return { + "tokens": [index] * self.sample_size, + "labels": [index] * self.sample_size, + } + + def __len__(self): + return 1000 + + +class DummyRealDataset(Dataset): + def __init__(self): + self.samples_list = [ + "This is a packing test", + "A fantastic test. It should pack two samples.", + "This one will not be fully packed.", + ] + self.tokenizer = DummyTokenizer() + + def __getitem__(self, index): + tokens = self.tokenizer.encode(self.samples_list[index]) + return {"tokens": tokens, "labels": tokens} + + def __len__(self): + return len(self.samples_list) + + +class TestPackedDataset: + def _get_expected_mask_and_input_pos( + self, max_seq_len, sample_size, split_across_pack + ): + """ + Generate expected integer mask and position ids for given max sequence + length and sample length + """ + num_samples, remainder = divmod(max_seq_len, sample_size) + if split_across_pack and remainder > 0: + num_samples += 1 + mask = torch.block_diag( + *[ + torch.tril(torch.ones(sample_size, sample_size, dtype=torch.bool)) + for i in range(1, num_samples + 1) + ] + ) + input_pos = [list(range(sample_size)) for i in range(1, num_samples + 1)] + input_pos = list(itertools.chain(*input_pos)) + + # Emulate mask and position id padding + if not split_across_pack and remainder > 0: + mask = torch.block_diag( + mask, + torch.eye(remainder, dtype=torch.bool), + ) + input_pos.extend(list(range(sample_size, sample_size + remainder))) + + return mask[:max_seq_len, :max_seq_len], torch.tensor(input_pos[:max_seq_len]) + + def _calculate_num_packs( + self, dataset_size, max_seq_len, sample_size, split_across_pack, max_packs + ): + # First see how many samples we can fit in a single pack + num_samples_per_pack, remainder = divmod(max_seq_len, sample_size) + + # If we split across pack (and the samples don't fit perfectly in max_seq_len), we can fit more + if split_across_pack and remainder > 0: + # Now we need the fractional to see how many we can partially fit in each pack + num_samples_per_pack = max_seq_len / sample_size + + # If we don't split across pack, we will need more packs + num_packs, remainder = divmod(dataset_size, num_samples_per_pack) + + # If there's leftover, we need to add one more pack + if remainder > 0: + num_packs += 1 + + return num_packs if num_packs < max_packs else max_packs + + @pytest.mark.parametrize("max_seq_len", [25]) + @pytest.mark.parametrize("sample_size", [2, 5]) + @pytest.mark.parametrize("max_packs", [5, 200]) + @pytest.mark.parametrize("split_across_pack", [True, False]) + def test_packed_dataset( + self, max_seq_len, sample_size, max_packs, split_across_pack + ): + dataset = DummyDataset(sample_size) + packed = PackedDataset( + dataset, + max_seq_len=max_seq_len, + max_packs=max_packs, + split_across_pack=split_across_pack, + ) + + # Check we get right number of packs + correct_num_packs = self._calculate_num_packs( + len(dataset), max_seq_len, sample_size, split_across_pack, max_packs + ) + assert len(packed) == correct_num_packs + + # Check all fields are same length + assert ( + len(packed[0]["tokens"]) + == len(packed[0]["labels"]) + == len(packed[0]["mask"]) + == len(packed[0]["input_pos"]) + ) + # Check that samples are packed correctly - very last individual sample + # should have index value of the number of times dataset was iterated over + if split_across_pack: + # If we split samples, we'll know how many samples by taking the + # full length and dividing by sample size + last_index, remainder = divmod(len(packed) * max_seq_len, sample_size) + # Account for remaining sample that didn't fit in window + last_index = last_index if remainder > 0 else last_index - 1 + else: + # If we don't split samples, we know how many samples by taking + # how much fits in a single window and multiplying by max rows. + # If there is a remainder, this will end up being a pad token. + last_index = ( + (max_seq_len // sample_size) * len(packed) - 1 + if max_seq_len % sample_size == 0 + else 0 + ) + + assert packed[-1]["tokens"][-1].item() == last_index + + expected_mask, expected_input_pos = self._get_expected_mask_and_input_pos( + max_seq_len, sample_size, split_across_pack + ) + torch.testing.assert_close(packed[0]["mask"], expected_mask) + torch.testing.assert_close(packed[0]["input_pos"], expected_input_pos) + + def test_packed_dataset_real_data(self): + expected_tokenized_prompts = [ + torch.tensor([0, 4, 2, 1, 7, 4, -1, 0, 1, 9]), + torch.tensor([5, 2, 6, 4, 3, 8, -1, 0, 4, 3]), + torch.tensor([4, 3, 2, 5, 7, -1, 0, 0, 0, 0]), + ] + expected_tokenized_labels = [ + torch.tensor([0, 4, 2, 1, 7, 4, -1, 0, 1, 9]), + torch.tensor([5, 2, 6, 4, 3, 8, -1, 0, 4, 3]), + torch.tensor([4, 3, 2, 5, 7, -1, -100, -100, -100, -100]), + ] + expected_mask = [ + torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ), + torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ), + torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + ] + ), + ] + expected_input_pos = [ + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2]), + torch.tensor([3, 4, 5, 6, 7, 8, 9, 0, 1, 2]), + # Padded position ids cannot go beyond max seq_len - 1 + torch.tensor([3, 4, 5, 6, 7, 8, 9, 9, 9, 9]), + ] + packed = PackedDataset( + DummyRealDataset(), + max_seq_len=10, + split_across_pack=True, + ) + + for i in range(len(packed)): + prompt, label, mask, input_pos = ( + packed[i]["tokens"], + packed[i]["labels"], + packed[i]["mask"], + packed[i]["input_pos"], + ) + torch.testing.assert_close(prompt, expected_tokenized_prompts[i]) + torch.testing.assert_close(label, expected_tokenized_labels[i]) + torch.testing.assert_close(input_pos, expected_input_pos[i]) + torch.testing.assert_close(mask, expected_mask[i].to(dtype=torch.bool)) + + def test_pad_pack(self): + padding_idx = -8 + ignore_idx = -100 # Same as CROSS_ENTROPY_IGNORE_IDX + pack = { + "tokens": [2, 5], + "labels": [3, 7], + "seq_lens": [1, 1], + # Let the first token be the end of the previous sample (pos 8), + # and the second token the start of the next sample (pos 0). Collate + # should continue from 0 -> 1, 2, ... + "input_pos": [8, 0], + } + + dataset = DummyDataset(2) + packed = PackedDataset( + dataset, + max_seq_len=4, + ) + + pack = packed._convert_to_tensors(pack) + padded = packed._pad_pack(pack, padding_idx=padding_idx) + + padded_input = padded["tokens"] + padded_label = padded["labels"] + padded_input_pos = padded["input_pos"] + + torch.testing.assert_close( + padded_input, torch.tensor([2, 5, padding_idx, padding_idx]) + ) + torch.testing.assert_close( + padded_label, torch.tensor([3, 7, ignore_idx, ignore_idx]) + ) + torch.testing.assert_close(padded_input_pos, torch.tensor([8, 0, 1, 2])) + + def test_pack_errors_if_sample_too_long(self): + dataset = DummyDataset(8) + with pytest.raises(ValueError, match="Dataset sample is too long"): + PackedDataset( + dataset, + max_seq_len=4, + ) diff --git a/training/tests/torchtune/datasets/test_samsum_dataset.py b/training/tests/torchtune/datasets/test_samsum_dataset.py new file mode 100644 index 0000000..d9df3e8 --- /dev/null +++ b/training/tests/torchtune/datasets/test_samsum_dataset.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from tests.test_utils import DummyTokenizer +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX + +from torchtune.datasets import samsum_dataset + + +class TestSamsumDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + @patch("torchtune.datasets._sft.load_dataset") + def test_label_no_masking(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is not masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "id": "13818513", + "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", + "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", + }, + ] + ) + + samsum_ds = samsum_dataset(model_transform=tokenizer, train_on_input=True) + input, labels = samsum_ds[0]["tokens"], samsum_ds[0]["labels"] + + assert input == [ + 0, + 9, + 4, + 9, + 7, + 1, + 5, + 8, + 2, + 3, + 4, + 5, + 6, + 5, + 7, + 4, + 5, + 3, + 8, + 3, + 3, + 8, + 6, + 5, + 7, + 3, + 4, + 5, + 5, + 4, + 9, + -1, + ] + assert labels == input + + @patch("torchtune.datasets._sft.load_dataset") + def test_label_masking(self, load_dataset, tokenizer): + """ + Test whether the input and the labels are correctly created when the input is masked. + """ + + # mock the call to HF datasets + load_dataset.return_value = Dataset.from_list( + [ + { + "id": "13818513", + "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", + "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", + }, + ] + ) + + samsum_ds = samsum_dataset(model_transform=tokenizer) + + # Generate the input and labels + input, labels = samsum_ds[0]["tokens"], samsum_ds[0]["labels"] + + assert input == [ + 0, + 9, + 4, + 9, + 7, + 1, + 5, + 8, + 2, + 3, + 4, + 5, + 6, + 5, + 7, + 4, + 5, + 3, + 8, + 3, + 3, + 8, + 6, + 5, + 7, + 3, + 4, + 5, + 5, + 4, + 9, + -1, + ] + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 22 diff --git a/training/tests/torchtune/datasets/test_slimorca_dataset.py b/training/tests/torchtune/datasets/test_slimorca_dataset.py new file mode 100644 index 0000000..404d99a --- /dev/null +++ b/training/tests/torchtune/datasets/test_slimorca_dataset.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from collections import Counter +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from tests.test_utils import DummyTokenizer +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX + +from torchtune.datasets import slimorca_dataset + + +class TestSlimOrcaDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + @patch("torchtune.datasets._sft.load_dataset") + @pytest.mark.parametrize("train_on_input", [True, False]) + def test_dataset_get_item(self, mock_load_dataset, train_on_input, tokenizer): + # Sample data from slimorca dataset + mock_load_dataset.return_value = Dataset.from_list( + [ + { + "conversations": [ + { + "from": "system", + "value": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950 + }, + { + "from": "human", + "value": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950 + }, + { + "from": "gpt", + "value": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950 + }, + ] + } + ] + ) + ds = slimorca_dataset( + model_transform=tokenizer, + train_on_input=train_on_input, + ) + # Generate the input and labels + input, labels = ds[0]["tokens"], ds[0]["labels"] + + expected_counts = { + 3: 28, + 2: 20, + 4: 20, + 5: 20, + 6: 17, + 10: 8, + 1: 7, + 8: 7, + 7: 7, + 9: 2, + 11: 2, + 0: 1, + 12: 1, + 17: 1, + -1: 1, + } + assert Counter(input) == expected_counts + if train_on_input: + assert Counter(labels) == expected_counts + else: + # Check that the input is masked + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 104 diff --git a/training/tests/torchtune/datasets/test_text_completion_dataset.py b/training/tests/torchtune/datasets/test_text_completion_dataset.py new file mode 100644 index 0000000..374d257 --- /dev/null +++ b/training/tests/torchtune/datasets/test_text_completion_dataset.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from unittest import mock + +from tests.test_utils import DummyTokenizer + +from torchtune.datasets import TextCompletionDataset + + +class TestTextCompletionDataset: + expected_tokenized_prompts = [ + [0, 4, 2, 2, 7, 5, -1], + [0, 4, 2, 7, 7, 5, -1], + ] + + def get_samples(self): + return [ + { + "text": "This is an example text.", + }, + { + "text": "This is another example text.", + }, + ] + + @mock.patch("torchtune.datasets._text_completion.load_dataset") + def test_get_item(self, mock_load_dataset): + mock_load_dataset.return_value = self.get_samples() + expected_labels = self.expected_tokenized_prompts + + dataset = TextCompletionDataset( + tokenizer=DummyTokenizer(), + source="iam/agoofy/goober", + column="text", + max_seq_len=100, + ) + assert len(dataset) == 2 + mock_load_dataset.assert_called_once() + + for i in range(len(dataset)): + prompt, label = dataset[i]["tokens"], dataset[i]["labels"] + assert prompt == self.expected_tokenized_prompts[i] + assert label == expected_labels[i] + + @mock.patch("torchtune.datasets._text_completion.load_dataset") + def test_get_item_no_eos(self, mock_load_dataset): + mock_load_dataset.return_value = self.get_samples() + expected_labels = self.expected_tokenized_prompts + + dataset = TextCompletionDataset( + tokenizer=DummyTokenizer(), + source="iam/agoofy/goober", + column="text", + max_seq_len=100, + add_eos=False, + ) + assert len(dataset) == 2 + mock_load_dataset.assert_called_once() + + for i in range(len(dataset)): + prompt, label = dataset[i]["tokens"], dataset[i]["labels"] + # trimming EOS IDs from the expected tokens, assertion is against: + # [0, 4, 2, 2, 7, 5] + assert prompt == self.expected_tokenized_prompts[i][:-1] + assert label == expected_labels[i][:-1] diff --git a/training/tests/torchtune/datasets/test_wikitext_dataset.py b/training/tests/torchtune/datasets/test_wikitext_dataset.py new file mode 100644 index 0000000..ce3afbf --- /dev/null +++ b/training/tests/torchtune/datasets/test_wikitext_dataset.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from unittest.mock import patch + +import pytest + +from tests.test_utils import DummyTokenizer + +from torchtune.datasets import wikitext_dataset + + +class TestWikiTextDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + @patch("torchtune.datasets._text_completion.load_dataset") + @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096]) + def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len): + # Sample data from wikitext dataset + load_dataset.return_value = [ + { + "page": "Bart , like the rest of his family , has yellow skin . " + "Bart usually wears a red T @-@ shirt , blue shorts and blue trainers . " + "When the Simpson family goes to church in the episodes , or to school " + "events or shows , Bart wears a blue suit with a white shirt , a purple " + "tie , blue shorts and a blue jacket .", + } + ] + ds = wikitext_dataset( + tokenizer=tokenizer, + max_seq_len=max_seq_len, + ) + input, label = ds[0]["tokens"], ds[0]["labels"] + assert len(input) <= max_seq_len + assert len(label) <= max_seq_len + assert len(input) == len(label) + assert input[0] == tokenizer.bos_id diff --git a/training/tests/torchtune/models/__init__.py b/training/tests/torchtune/models/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/models/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/models/clip/__init__.py b/training/tests/torchtune/models/clip/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/models/clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/models/clip/test_clip_image_transform.py b/training/tests/torchtune/models/clip/test_clip_image_transform.py new file mode 100644 index 0000000..4f87c23 --- /dev/null +++ b/training/tests/torchtune/models/clip/test_clip_image_transform.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import PIL +import pytest + +import torch +from tests.test_utils import assert_expected + +from torchtune.models.clip._transforms import CLIPImageTransform +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestCLIPImageTransform: + @pytest.mark.parametrize( + "params", + [ + { + "image_size": (100, 400, 3), + "expected_shape": torch.Size([2, 3, 224, 224]), + "resize_to_max_canvas": False, + "expected_tile_means": [0.2230, 0.1763], + "expected_tile_max": [1.0, 1.0], + "expected_tile_min": [0.0, 0.0], + "expected_aspect_ratio": [1, 2], + }, + { + "image_size": (1000, 300, 3), + "expected_shape": torch.Size([4, 3, 224, 224]), + "resize_to_max_canvas": True, + "expected_tile_means": [0.5007, 0.4995, 0.5003, 0.1651], + "expected_tile_max": [0.9705, 0.9694, 0.9521, 0.9314], + "expected_tile_min": [0.0353, 0.0435, 0.0528, 0.0], + "expected_aspect_ratio": [4, 1], + }, + { + "image_size": (200, 200, 3), + "expected_shape": torch.Size([4, 3, 224, 224]), + "resize_to_max_canvas": True, + "expected_tile_means": [0.5012, 0.5020, 0.5011, 0.4991], + "expected_tile_max": [0.9922, 0.9926, 0.9970, 0.9908], + "expected_tile_min": [0.0056, 0.0069, 0.0059, 0.0033], + "expected_aspect_ratio": [2, 2], + }, + { + "image_size": (600, 200, 3), + "expected_shape": torch.Size([3, 3, 224, 224]), + "resize_to_max_canvas": False, + "expected_tile_means": [0.4473, 0.4469, 0.3032], + "expected_tile_max": [1.0, 1.0, 1.0], + "expected_tile_min": [0.0, 0.0, 0.0], + "expected_aspect_ratio": [3, 1], + }, + ], + ) + def test_clip_image_transform(self, params): + # Initialize the image transformation with specified parameters + image_transform = CLIPImageTransform( + image_mean=None, + image_std=None, + tile_size=224, + possible_resolutions=None, + max_num_tiles=4, + resample="bilinear", + resize_to_max_canvas=params["resize_to_max_canvas"], + ) + # Generate a deterministic image using np.arange for reproducibility + image_size = params["image_size"] + image = ( + np.random.randint(0, 256, np.prod(image_size)) + .reshape(image_size) + .astype(np.uint8) + ) + image = PIL.Image.fromarray(image) + + # Apply the transformation + output = image_transform(image=image) + output_image = output["image"] + output_ar = output["aspect_ratio"] + + # Check if the output shape matches the expected shape + assert ( + output_image.shape == params["expected_shape"] + ), f"Expected shape {params['expected_shape']} but got {output_image.shape}" + + # Check if the pixel values are within the expected range [0, 1] + assert ( + 0 <= output_image.min() <= output_image.max() <= 1 + ), f"Expected pixel values to be in range [0, 1] but got {output_image.min()} and {output_image.max()}" + + # Check if the mean, max, and min values of the tiles match the expected values + for i, tile in enumerate(output_image): + assert_expected( + tile.mean().item(), params["expected_tile_means"][i], rtol=0, atol=1e-4 + ) + assert_expected( + tile.max().item(), params["expected_tile_max"][i], rtol=0, atol=1e-4 + ) + assert_expected( + tile.min().item(), params["expected_tile_min"][i], rtol=0, atol=1e-4 + ) + + # aspect ratio matches the expected aspect ratio + assert tuple(output_ar.numpy()) == tuple( + params["expected_aspect_ratio"] + ), f"Expected aspect ratio {params['expected_aspect_ratio']} but got {tuple(output_ar.numpy())}" + + # number of tiles matches the product of the aspect ratio + expected_num_tiles = output_ar[0] * output_ar[1] + assert ( + expected_num_tiles == output_image.shape[0] + ), f"Expected {expected_num_tiles} tiles but got {output_image.shape[0]}" diff --git a/training/tests/torchtune/models/clip/test_positional_embeddings.py b/training/tests/torchtune/models/clip/test_positional_embeddings.py new file mode 100644 index 0000000..4c393d1 --- /dev/null +++ b/training/tests/torchtune/models/clip/test_positional_embeddings.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +from tests.test_utils import assert_expected, fixed_init_model, fixed_init_tensor +from torchtune.models.clip._position_embeddings import ( + TiledTokenPositionalEmbedding, + TilePositionalEmbedding, + TokenPositionalEmbedding, +) + + +class TestPositionalEmbeddings: + @pytest.fixture(autouse=True) + def setup_class(self): + + self.embed_dim = 16 + self.tile_size = 14 + self.max_num_tiles = 3 + self.bsz_and_n_imgs = 2 + self.patch_size = 2 + self.aspect_ratio = torch.tensor([[3, 1], [1, 2]]) + self.patch_grid_size = self.tile_size // self.patch_size + + input_tensor = torch.randn( + ( + self.bsz_and_n_imgs, + self.max_num_tiles, + self.patch_grid_size**2 + 1, + self.embed_dim, + ) + ) + self.input_tensor = fixed_init_tensor(input_tensor.shape, min_val=-1, max_val=1) + + def test_token_positional_embedding(self): + # call model + embedding = TokenPositionalEmbedding( + self.embed_dim, patch_size=self.patch_size, tile_size=self.tile_size + ) + fixed_init_model(embedding, min_val=-1, max_val=1) + + inpt = self.input_tensor.clone().reshape( + self.bsz_and_n_imgs * self.max_num_tiles, -1, self.embed_dim + ) + output = embedding(inpt) + + # assertion + assert_expected(output.shape, inpt.shape) + assert_expected(output.mean(), torch.tensor(-0.001458), atol=1e-3, rtol=1e-3) + + def test_tiled_token_positional_embedding(self): + # call model + embedding = TiledTokenPositionalEmbedding( + self.max_num_tiles, + self.embed_dim, + patch_size=self.patch_size, + tile_size=self.tile_size, + ) + fixed_init_model(embedding, min_val=-1, max_val=1) + + # replace gate 0 -> 0.5 + embedding.gate = torch.nn.Parameter(torch.full(embedding.gate.shape, 0.5)) + + inpt = self.input_tensor.clone() + output = embedding(inpt, self.aspect_ratio) + + # assertion + assert_expected(output.shape, self.input_tensor.shape) + assert_expected(output.mean(), torch.tensor(-0.17208), atol=1e-3, rtol=1e-3) + + def test_tile_positional_embedding(self): + # call model + embedding = TilePositionalEmbedding(self.max_num_tiles, self.embed_dim) + fixed_init_model(embedding, min_val=-1, max_val=1) + + inpt = self.input_tensor.clone() + output = embedding(inpt, self.aspect_ratio) + + # assertion + assert_expected(output.shape, self.input_tensor.shape) + assert_expected(output.mean(), torch.tensor(0.28627), atol=1e-3, rtol=1e-3) diff --git a/training/tests/torchtune/models/gemma/test_gemma_tokenizer.py b/training/tests/torchtune/models/gemma/test_gemma_tokenizer.py new file mode 100644 index 0000000..b5aa8b4 --- /dev/null +++ b/training/tests/torchtune/models/gemma/test_gemma_tokenizer.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.data import Message +from torchtune.models.gemma import gemma_tokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestGemmaTokenizer: + @pytest.fixture + def tokenizer(self): + # m.model is a pretrained Sentencepiece model using the following command: + # spm.SentencePieceTrainer.train('--input= --model_prefix=m --vocab_size=2000') + return gemma_tokenizer(str(ASSETS / "m.model")) + + def test_tokenize_messages(self, tokenizer): + messages = [ + Message( + role="user", + content="Below is an instruction that describes a task. Write a response " + "that appropriately completes the request.\n\n### Instruction:\nGenerate " + "a realistic dating profile bio.\n\n### Response:\n", + masked=True, + ), + Message( + role="assistant", + content="I'm an outgoing and friendly person who loves spending time with " + "friends and family. I'm also a big-time foodie and love trying out new " + "restaurants and different cuisines. I'm a big fan of the arts and enjoy " + "going to museums and galleries. I'm looking for someone who shares my " + "interest in exploring new places, as well as someone who appreciates a " + "good conversation over coffee.", + ), + ] + tokens, mask = tokenizer.tokenize_messages(messages) + expected_tokens = [ + 1, + 323, + 418, + 202, + 31, + 128, + 15, + 120, + 47, + 88, + 584, + 23, + 1665, + 182, + 9, + 434, + 295, + 85, + 4, + 780, + 47, + 636, + 9, + 1094, + 213, + 23, + 9, + 69, + 69, + 164, + 1153, + 299, + 35, + 961, + 132, + 237, + 7, + 5, + 761, + 4, + 12, + 0, + 313, + 120, + 47, + 88, + 584, + 166, + 493, + 171, + 54, + 299, + 9, + 906, + 244, + 19, + 186, + 767, + 303, + 671, + 92, + 209, + 24, + 190, + 52, + 38, + 4, + 12, + 0, + 1243, + 7, + 69, + 135, + 213, + 166, + 6, + 21, + 45, + 128, + 71, + 58, + 38, + 14, + 10, + 652, + 35, + 462, + 101, + 1306, + 7, + 341, + 171, + 20, + 14, + 127, + 26, + 652, + 7, + 10, + 1268, + 4, + 6, + 21, + 45, + 591, + 9, + 566, + 22, + 994, + 913, + 38, + 20, + 52, + 24, + 10, + 1306, + 734, + 14, + 71, + 365, + 1382, + 7, + 10, + 801, + 105, + 88, + 244, + 985, + 7, + 4, + 6, + 21, + 45, + 9, + 566, + 126, + 180, + 11, + 5, + 1137, + 7, + 10, + 1089, + 151, + 8, + 1156, + 213, + 342, + 7, + 10, + 384, + 104, + 54, + 470, + 4, + 6, + 21, + 45, + 287, + 14, + 33, + 125, + 135, + 24, + 101, + 512, + 66, + 7, + 28, + 822, + 15, + 542, + 69, + 59, + 110, + 14, + 365, + 229, + 7, + 3, + 36, + 267, + 36, + 125, + 135, + 24, + 101, + 1503, + 182, + 9, + 222, + 1661, + 191, + 332, + 92, + 92, + 24, + 24, + 4, + 2, + ] + expected_mask = [True] * 75 + [False] * 125 + assert expected_tokens == tokens + assert expected_mask == mask diff --git a/training/tests/torchtune/models/llama2/scripts/README.md b/training/tests/torchtune/models/llama2/scripts/README.md new file mode 100644 index 0000000..a6bafb3 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/README.md @@ -0,0 +1,15 @@ +# Verifying Correctness against Reference Implementations + +This repository puts a high bar on correctness and testing. To make sure our model and +module implementations are correct, we compare our implementation against reference implementations +where possible. This folder contains scripts used for these comparisons. + + +## Running the scripts + +You can run the scripts using the following command as an example. +Each script should print out the value being used in the associated unit tests. + +``` +python3 -m tests.llm.llama2.scripts.compare_attention +``` diff --git a/training/tests/torchtune/models/llama2/scripts/__init__.py b/training/tests/torchtune/models/llama2/scripts/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/models/llama2/scripts/compare_attention.py b/training/tests/torchtune/models/llama2/scripts/compare_attention.py new file mode 100644 index 0000000..1efc98c --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_attention.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from torch import nn + +from torchtune.modules import CausalSelfAttention, RotaryPositionalEmbeddings + + +""" +Reference implementation of Attention from: +https://github.com/facebookresearch/llama/blob/main/llama/model.py#L176 + +Replicating code here to minimize dependencies. The code is modified to +remove dependencies like FAIRSCale and features like KV Caching. +The latter is not supported by the current implementation. +""" + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, n_heads: int, n_kv_heads: int, dim: int): + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_heads = n_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = dim // n_heads + self.dim = dim + + self.wq = nn.Linear(self.dim, self.dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.dim, self.dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: torch.Tensor, + ): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = nn.functional.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + output = self.wo(output) + return output + + +""" +Reference implementation of RoPE from: +https://github.com/facebookresearch/llama/blob/main/llama/model.py#L80 + +The original code structures this as stand-alone functions instead of +a class. Replicating code here to minimize dependencies. +""" + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"{freqs_cis.shape} does not match {x.shape[1], x.shape[-1]}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, x_) + x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) + return x_out.type_as(x) + + +def compare_rope( + bsz: int, num_heads: int, embed_dim: int, seq_len: int, max_seq_len: int +) -> None: + + # make sure we have the right seed for generating outputs + torch.manual_seed(0) + + head_dim = embed_dim // num_heads + + # generate input tensor + x = torch.randn(bsz, seq_len, num_heads, head_dim) + + # Compute the reference tensors + freq_cis = precompute_freqs_cis(dim=head_dim, end=max_seq_len * 2) + x_out_ref = apply_rotary_emb(x, freqs_cis=freq_cis[:seq_len]) + + # Compute the tensors from current implementation + rope_emb = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + x_out = rope_emb(x) + + # Validate correctness + torch.testing.assert_close(x_out_ref, x_out, atol=1e-6, rtol=1e-5) + + # value: tensor(6.4543e-05) + print(x_out.mean()) + + # value: tensor(2165.7053) + print(x_out.sum()) + + # value: tensor(5.4546) + print(x_out.max()) + + curr_pos = 10 + x_out_ref = apply_rotary_emb(x, freqs_cis=freq_cis[curr_pos : curr_pos + seq_len]) + x_out = rope_emb(x, curr_pos=10) + + # Validate correctness + torch.testing.assert_close(x_out_ref, x_out, atol=1e-6, rtol=1e-5) + + # value: tensor(0.0002) + print(x_out.mean()) + + # value: tensor(5158.3159) + print(x_out.sum()) + + # value: tensor(5.4543) + print(x_out.max()) + + +def compare_attention( + bsz: int, + seq_len: int, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, +) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(16) + + head_dim = embed_dim // num_heads + + # generate input tensor used by both implementations + input_t = torch.randn(bsz, seq_len, embed_dim) + + # generate mask and frequencies tensor needed for the reference + # implementation + mask = torch.full((1, 1, seq_len, seq_len), float("-inf")) + mask = torch.triu(mask, diagonal=1) + freq_cis = precompute_freqs_cis(dim=head_dim, end=seq_len) + + # reference implementation; initialize with constant to compare outputs + attn_ref = Attention(n_heads=num_heads, n_kv_heads=num_kv_heads, dim=embed_dim) + for p in attn_ref.parameters(): + nn.init.constant_(p, 0.05) + + with torch.no_grad(): + attn_out_ref = attn_ref(input_t, freq_cis, mask) + + # current implementation; initialize with constant to compare outputs + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=0.0, + ) + for p in attn.parameters(): + nn.init.constant_(p, 0.05) + + with torch.no_grad(): + attn_out = attn(input_t) + + # value: tensor(-27.5074) + print(attn_out.mean()) + + # output tensors should be similar + torch.testing.assert_close(attn_out, attn_out_ref, atol=1e-5, rtol=1e-3) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare Attention implementations") + parser.add_argument("--bsz", type=int, default=4, help="Batch size of input tensor") + parser.add_argument( + "--seq_len", type=int, default=2048, help="input sequence length" + ) + parser.add_argument( + "--embed_dim", + type=int, + default=4096, + help="Embedding dimension used to compute the dim for RopE", + ) + parser.add_argument( + "--num_heads", + type=int, + default=32, + help="Number of heads in the attention layer", + ) + parser.add_argument( + "--num_kv_heads", + type=int, + default=8, + help="Number of key/value heads in the attention layer", + ) + parser.add_argument( + "--max_seq_len", type=int, default=4096, help="max sequence length" + ) + + args = parser.parse_args() + + compare_rope( + args.bsz, args.num_heads, args.embed_dim, args.seq_len, args.max_seq_len + ) + + compare_attention( + args.bsz, + args.seq_len, + args.embed_dim, + args.num_heads, + args.num_kv_heads, + args.max_seq_len, + ) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_decoder.py b/training/tests/torchtune/models/llama2/scripts/compare_decoder.py new file mode 100644 index 0000000..984dcbb --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from tests.torchtune.models.llama2.scripts.compare_attention import precompute_freqs_cis +from tests.torchtune.models.llama2.scripts.compare_decoder_layer import ( + RMSNormRef, + TransformerBlock, +) + +from torch import nn + +from torchtune.models.llama2 import llama2 + +""" +Reference implementation of Transformer from: +https://github.com/facebookresearch/llama/blob/main/llama/model.py#L413 + +Replicating code here to minimize dependencies. The code is modified to +include params for the constructor and remove start_pos (not supported). +""" + +# TODO: Move this to standalone ref implementation +class Transformer(nn.Module): + def __init__( + self, + vocab_size: int, + dim: int, + n_layers: int, + n_heads: int, + max_seq_len: int, + n_kv_heads: int, + ): + + super().__init__() + self.vocab_size = vocab_size + self.n_layers = n_layers + + self.tok_embeddings = nn.Embedding(vocab_size, dim) + + self.layers = torch.nn.ModuleList() + for _ in range(n_layers): + self.layers.append( + TransformerBlock(n_heads=n_heads, dim=dim, n_kv_heads=n_kv_heads) + ) + + self.norm = RMSNormRef(dim) + self.output = nn.Linear(dim, vocab_size, bias=False) + + self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2) + + def forward(self, tokens: torch.Tensor, start_pos: int = 0): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + + for layer in self.layers: + h = layer(h, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output + + +def compare_decoder( + bsz: int, + seq_len: int, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + vocab_size: int, + num_layers: int, +) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(16) + + # generate input tensor used by both implementations + x_input = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len)) + + # reference implementation; initialize with constant to compare outputs + decoder_ref = Transformer( + vocab_size=vocab_size, + dim=embed_dim, + n_layers=num_layers, + n_heads=num_heads, + max_seq_len=max_seq_len, + n_kv_heads=num_kv_heads, + ) + for p in model.parameters(): + nn.init.constant_(decoder_ref, 0.2) + + with torch.no_grad(): + output_ref = decoder_ref(x_input) + + # current implementation; initialize with constant to compare outputs + decoder = llama2( + vocab_size=vocab_size, + embed_dim=embed_dim, + num_layers=num_layers, + num_heads=num_heads, + max_seq_len=max_seq_len, + num_kv_heads=num_kv_heads, + ) + for p in model.parameters(): + nn.init.constant_(decoder, 0.2) + + with torch.no_grad(): + output = decoder(x_input) + + # value: tensor(20.4800) + print(output.mean()) + + torch.testing.assert_close(output_ref, output, atol=1e-6, rtol=1e-6) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare Attention implementations") + parser.add_argument("--bsz", type=int, default=4, help="Batch size of input tensor") + parser.add_argument( + "--seq_len", type=int, default=512, help="input sequence length" + ) + parser.add_argument( + "--embed_dim", + type=int, + default=512, + help="Embedding dimension used to compute the dim for RopE", + ) + parser.add_argument( + "--num_heads", + type=int, + default=8, + help="Number of heads in the attention layer", + ) + parser.add_argument( + "--num_kv_heads", + type=int, + default=8, + help="Number of key/value heads in the attention layer", + ) + parser.add_argument( + "--max_seq_len", type=int, default=512, help="max sequence length" + ) + parser.add_argument("--vocab_size", type=int, default=256, help="vocab size") + parser.add_argument( + "--num_layers", type=int, default=4, help="number of transformer layers" + ) + + args = parser.parse_args() + + compare_decoder( + args.bsz, + args.seq_len, + args.embed_dim, + args.num_heads, + args.num_kv_heads, + args.max_seq_len, + args.vocab_size, + args.num_layers, + ) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_decoder_layer.py b/training/tests/torchtune/models/llama2/scripts/compare_decoder_layer.py new file mode 100644 index 0000000..f7c2015 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_decoder_layer.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from tests.torchtune.models.llama2.scripts.compare_attention import ( + Attention, + precompute_freqs_cis, +) +from tests.torchtune.models.llama2.scripts.compare_feed_forward import FeedForwardRef + +from torch import nn +from torchtune.models.llama2._model_utils import scale_hidden_dim_for_mlp + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoderLayer, +) + + +""" +Reference implementation of RMSNorm from: +https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34 + +Replicating code here to minimize dependencies. The code is modified to +include params for the constructor and remove start_pos (not supported). +""" + + +class RMSNormRef(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +""" +Reference implementation of Transformer Decoder Layer from: +https://github.com/facebookresearch/llama/blob/main/llama/model.py#L351 + +Replicating code here to minimize dependencies. The code is modified to +include params for the constructor and remove start_pos (not supported). +""" + + +class TransformerBlock(nn.Module): + def __init__(self, n_heads: int, dim: int, n_kv_heads: int): + super().__init__() + self.n_heads = n_heads + self.dim = dim + # self.head_dim = args.dim // args.n_heads + self.attention = Attention(n_heads=n_heads, n_kv_heads=n_kv_heads, dim=dim) + self.feed_forward = FeedForwardRef(dim=dim, hidden_dim=4 * dim) + self.attention_norm = RMSNormRef(dim=dim) + self.ffn_norm = RMSNormRef(dim=dim) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + norm_out = self.attention_norm(x) + attn_out = self.attention.forward(norm_out, freqs_cis, mask) + h = x + attn_out + ffn_norm_out = self.ffn_norm(h) + mlp_out = self.feed_forward.forward(ffn_norm_out) + out = h + mlp_out + return out + + +def compare_decoder_layer( + bsz: int, + seq_len: int, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, +) -> None: + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(16) + + head_dim = embed_dim // num_heads + + # generate input tensor used by both implementations + input_t = torch.randn(bsz, seq_len, embed_dim) + + # generate mask and frequencies tensor needed for the reference + # implementation + mask = torch.full((1, 1, seq_len, seq_len), float("-inf")) + mask = torch.triu(mask, diagonal=1) + freq_cis = precompute_freqs_cis(dim=head_dim, end=seq_len) + + # reference implementation; initialize with constant to compare outputs + transformer_block = TransformerBlock( + n_heads=num_heads, n_kv_heads=num_kv_heads, dim=embed_dim + ) + for p in transformer_block.parameters(): + nn.init.constant_(p, 0.05) + + with torch.no_grad(): + block_out = transformer_block(x=input_t, freqs_cis=freq_cis, mask=mask) + + # current implementation; initialize with constant to compare outputs + norm_eps = 1e-5 + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=0.0, + ) + hidden_dim = _scale_hidden_dim_for_mlp(embed_dim) + mlp = FeedForward( + dim=embed_dim, hidden_dim=hidden_dim, linear_class=torch.nn.Linear + ) + transformer_layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + for p in transformer_layer.parameters(): + nn.init.constant_(p, 0.05) + + with torch.no_grad(): + layer_out = transformer_layer(input_t) + + # value: torch.tensor(18261.0156) + print(layer_out.mean()) + + torch.testing.assert_close(block_out, layer_out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare Attention implementations") + parser.add_argument("--bsz", type=int, default=4, help="Batch size of input tensor") + parser.add_argument( + "--seq_len", type=int, default=2048, help="input sequence length" + ) + parser.add_argument( + "--embed_dim", + type=int, + default=4096, + help="Embedding dimension used to compute the dim for RopE", + ) + parser.add_argument( + "--num_heads", + type=int, + default=32, + help="Number of heads in the attention layer", + ) + parser.add_argument( + "--num_kv_heads", + type=int, + default=8, + help="Number of key/value heads in the attention layer", + ) + parser.add_argument( + "--max_seq_len", type=int, default=4096, help="max sequence length" + ) + + args = parser.parse_args() + + compare_decoder_layer( + args.bsz, + args.seq_len, + args.embed_dim, + args.num_heads, + args.num_kv_heads, + args.max_seq_len, + ) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_feed_forward.py b/training/tests/torchtune/models/llama2/scripts/compare_feed_forward.py new file mode 100644 index 0000000..973cd6f --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_feed_forward.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from tests.test_utils import fixed_init_model + +from torch import nn +from torchtune.models.llama2._model_utils import scale_hidden_dim_for_mlp + +from torchtune.modules import FeedForward + + +""" +Reference implementation of FeedForward from: +https://github.com/facebookresearch/llama/blob/main/llama/model.py#L307 + +Replicating code here to minimize dependencies. +""" + + +class FeedForwardRef(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +def compare_feed_forward(embed_dim: int, hidden_dim: int) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(0) + + # generate input tensor used by both implementations + input_t = torch.randn(1, embed_dim) + + # reference implementation; initialize with constant to compare outputs + ff_ref = FeedForwardRef(dim=embed_dim, hidden_dim=4 * embed_dim) + fixed_init_model(ff_ref) + + with torch.no_grad(): + ff_out_ref = ff_ref(input_t) + + hidden_dim = _scale_hidden_dim_for_mlp(embed_dim) + ff = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=torch.nn.Linear) + fixed_init_model(ff) + + with torch.no_grad(): + ff_out = ff(input_t) + + torch.testing.assert_close(ff_out, ff_out_ref, atol=1e-5, rtol=1e-5) + print(ff_out.mean()) + print(ff_out.max()) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare RMS Norm implementations") + parser.add_argument( + "--embed_dim", + type=int, + default=4096, + help="Embedding dimension used to compute the dim for RopE", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=4096, + help="Hidden dimension in the feed forward layer", + ) + + args = parser.parse_args() + + compare_feed_forward(args.embed_dim, args.hidden_dim) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_fused_attention.py b/training/tests/torchtune/models/llama2/scripts/compare_fused_attention.py new file mode 100644 index 0000000..15d8138 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_fused_attention.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from tests.test_utils import fixed_init_model +from torch import nn, Tensor +from torchtune.modules import CausalSelfAttention, KVCache, RotaryPositionalEmbeddings + +# Copy-paste of fused attention for comparison +class FusedCausalSelfAttention(nn.Module): + """Multi-headed grouped query self-attention (GQA) layer introduced + in https://arxiv.org/pdf/2305.13245v1.pdf. + + GQA is a version of multiheaded attention (MHA) which uses fewer + key/value heads than query heads by grouping n query heads for each + key and value head. Multi-Query Attention is an extreme + version where we have a single key and value head shared by all + query heads. + + Following is an example of MHA, GQA and MQA with num_heads = 4 + + (credit for the documentation: + https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py). + + + :: + + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ │ │ │ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + MHA GQA MQA + n_kv_heads =4 n_kv_heads=2 n_kv_heads=1 + + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + head_dim (int): dimension of each head, calculated by ``embed_dim`` // ``num_heads``. + qkv_proj (nn.Module): projection layer for query, key and value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (nn.Module): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value. + If not specified, then no caching is used. + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. + + Raises: + ValueError: If `num_heads` % `num_kv_heads` != 0 + ValueError: If `embed_dim` % `num_heads` != 0 + ValueError: If `attn_dropout` < 0 or > 1 + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + qkv_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: nn.Module, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + attn_dropout: float = 0.0, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + + # Set layers + self.kv_cache = kv_cache + self.qkv_proj = qkv_proj + self.output_proj = output_proj + self.pos_embeddings = pos_embeddings + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + curr_pos: int = 0, + ) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [batch_size x seq_length x embed_dim] + mask (Optional[Tensor]): boolean mask, defaults to None. + curr_pos (int): current position in the sequence, defaults to 0. + + Returns: + Tensor: output tensor with attention applied + + Raises: + ValueError: if seq_len of x is bigger than max_seq_len + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + - qkv_d: qkv_dim computed as (n_h + 2 * n_kv) * h_d + + TODO: + - Return the attention weights + - Make application of positional embeddings optional + """ + + # input has shape [b, s, d] + bsz, seq_len, _ = x.shape + + if seq_len > self.max_seq_len: + raise ValueError( + f"seq_len ({seq_len}) of input tensor should be smaller " + f"than max_seq_len ({self.max_seq_len})" + ) + + # qkv has shape [b, s, qkv_d] + qkv = self.qkv_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + + # Each key and value either has a single query (MHA) + # or q_per_kv queries (MQA/GQA). total_qkv will be 3 + # for MHA + total_qkv = q_per_kv + 2 + + # decompose the last dimension into n_kv x total_qkv, h_d + qkv = qkv.view(bsz, seq_len, self.num_kv_heads, total_qkv, self.head_dim) + + # create the q,k and v tensors by splitting qkv + # q: [b, s, n_kv, q_per_kv, h_d] + # k: [b, s, n_kv, 1, h_d] + # v: [b, s, n_kv, 1, h_d] + q, k, v = qkv.split((q_per_kv, 1, 1), dim=3) + + # if needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim) + + # llama2 applies the RoPE embeddings on tensors with shape + # [b, s, n_h, h_d] + # Reshape the tensors before we apply RoPE + q = q.reshape(bsz, seq_len, -1, self.head_dim) + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # Apply positional embeddings + q = self.pos_embeddings(q, curr_pos) + k = self.pos_embeddings(k, curr_pos) + + # Update key-value cache + if self.kv_cache is not None: + k, v = self.kv_cache.update( + bsz=bsz, seq_len=seq_len, curr_pos=curr_pos, k_val=k, v_val=v + ) + + # [b, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/ + output = nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.attn_dropout, + is_causal=self.kv_cache is None, + ) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + return self.output_proj(output) + + +def map_state_dict( + sd, + head_dim: int, + num_heads: int, + num_kv_heads: int, +): + mapped_sd = {k: v for k, v in sd.items() if "qkv_proj" not in k} + q_per_kv = num_heads // num_kv_heads + slice_size = q_per_kv + 2 + ind = range(head_dim * num_kv_heads * slice_size) + qkv = sd["qkv_proj.weight"] + q_ind = list(filter(lambda x: (x // head_dim) % slice_size < slice_size - 2, ind)) + k_ind = list(filter(lambda x: (x // head_dim) % slice_size == slice_size - 2, ind)) + v_ind = list(filter(lambda x: (x // head_dim) % slice_size == slice_size - 1, ind)) + q = qkv.index_select(0, torch.tensor(q_ind)) + k = qkv.index_select(0, torch.tensor(k_ind)) + v = qkv.index_select(0, torch.tensor(v_ind)) + mapped_sd["q_proj.weight"] = q + mapped_sd["k_proj.weight"] = k + mapped_sd["v_proj.weight"] = v + return mapped_sd + + +def _get_mask(inpt: Tensor) -> Tensor: + seq_len = inpt.shape[1] + mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=inpt.device) + mask = torch.triu(mask, diagonal=1).type_as(inpt) + return mask + + +def compare_attn( + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + use_kv_cache: bool, +): + + torch.manual_seed(16) + inputs = torch.randn(4, 2048, 4096) + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + qkv_dim = (num_heads + 2 * num_kv_heads) * head_dim + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + if use_kv_cache: + kv_cache = KVCache( + batch_size=4, + max_seq_len=max_seq_len, + n_kv_heads=num_heads, + head_dim=head_dim, + ) + else: + kv_cache = None + + attn_ref = FusedCausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + qkv_proj=nn.Linear(embed_dim, qkv_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=kv_cache, + max_seq_len=max_seq_len, + ) + fixed_init_model(attn_ref) + attn_ref.eval() + + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=kv_cache, + max_seq_len=max_seq_len, + ) + mapped_sd = map_state_dict(attn_ref.state_dict(), head_dim, num_heads, num_kv_heads) + attn.load_state_dict(mapped_sd) + + # Compare fused and non-fused with remapped state dict + with torch.no_grad(): + if use_kv_cache: + mask = _get_mask(inputs) + out_ref = attn_ref(inputs, mask, curr_pos=0) + out = attn_ref(inputs, mask, curr_pos=0) + else: + out_ref = attn_ref(inputs) + out = attn(inputs) + print( + "These values should match the original unit test", out_ref.mean(), out.mean() + ) + torch.testing.assert_close(out_ref, out, atol=1e-8, rtol=1e-3) + + # Determine the new value with fixed initialization + fixed_init_model(attn) + with torch.no_grad(): + if use_kv_cache: + new_out = attn(inputs, mask, curr_pos=0) + else: + new_out = attn(inputs) + print(f"New unit test value: {new_out.mean()}") + + +if __name__ == "__main__": + + # compare mha + mha = { + "num_heads": 32, + "embed_dim": 4096, + "max_seq_len": 4096, + "num_kv_heads": None, + "use_kv_cache": False, + } + mqa = { + "num_heads": 32, + "embed_dim": 4096, + "max_seq_len": 4096, + "num_kv_heads": 1, + "use_kv_cache": False, + } + gqa = { + "num_heads": 32, + "embed_dim": 4096, + "max_seq_len": 4096, + "num_kv_heads": 8, + "use_kv_cache": False, + } + mha_kv = { + "num_heads": 32, + "embed_dim": 4096, + "max_seq_len": 4096, + "num_kv_heads": None, + "use_kv_cache": True, + } + mqa_kv = { + "num_heads": 32, + "embed_dim": 4096, + "max_seq_len": 4096, + "num_kv_heads": 1, + "use_kv_cache": True, + } + gqa_kv = { + "num_heads": 32, + "embed_dim": 4096, + "max_seq_len": 4096, + "num_kv_heads": 8, + "use_kv_cache": True, + } + test_cases = { + "mha": mha, + "mqa": mqa, + "gqa": gqa, + "mha_kv": mha_kv, + "mqa_kv": mqa_kv, + "gqa_kv": gqa_kv, + } + + for test_case, params in test_cases.items(): + print(f"For test case {test_case}") + compare_attn(**params) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_lora.py b/training/tests/torchtune/models/llama2/scripts/compare_lora.py new file mode 100644 index 0000000..2f2f109 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_lora.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from tests.test_utils import fixed_init_model +from torch import nn +from torchtune.modules.peft.lora import LoRALinear + +# Reference implementation from +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +class LoRALayer: + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + +class LoRALinearRef(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + **kwargs, + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__( + self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights, + ) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, "lora_A"): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def train(self, mode: bool = True): + def T(w): # noqa + return w.transpose(0, 1) if self.fan_in_fan_out else w + + nn.Linear.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): # noqa + return w.transpose(0, 1) if self.fan_in_fan_out else w + + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + result += ( + self.lora_dropout(x) + @ self.lora_A.transpose(0, 1) + @ self.lora_B.transpose(0, 1) + ) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + + +def compare_lora( + bsz: int, + seq_len: int, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float, +) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(16) + + # generate input tensor used by both implementations + x_input = torch.randn(bsz, seq_len, in_dim) + + # Initialize our implementation + lora = LoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=rank, + alpha=alpha, + use_bias=True, + dropout=dropout, + ) + fixed_init_model(lora) + + with torch.no_grad(): + output = lora(x_input) + + # Initialize reference implementation + lora_ref = LoRALinearRef( + in_features=in_dim, + out_features=out_dim, + r=rank, + lora_alpha=alpha, + lora_dropout=dropout, + ) + + sd_mapping = { + "weight": "weight", + "bias": "bias", + "lora_a.weight": "lora_A", + "lora_b.weight": "lora_B", + } + mapped_sd = {sd_mapping.get(k): v for k, v in lora.state_dict().items()} + lora_ref.load_state_dict(mapped_sd) + with torch.no_grad(): + output_ref = lora_ref(x_input) + + print(output_ref.mean()) + torch.testing.assert_close(output_ref, output, atol=1e-6, rtol=1e-6) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare LoRA linear implementations") + parser.add_argument("--bsz", type=int, default=2, help="Batch size of input tensor") + parser.add_argument("--seq_len", type=int, default=32, help="Input sequence length") + parser.add_argument( + "--in_dim", + type=int, + default=64, + help="Input embedding dimension", + ) + parser.add_argument( + "--out_dim", + type=int, + default=128, + help="Input embedding dimension", + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help="Rank of LoRA's A and B matrices", + ) + parser.add_argument( + "--alpha", + type=float, + default=1.0, + help="Scaling factor for LoRA matrices", + ) + parser.add_argument( + "--dropout", type=int, default=0.0, help="Dropout prob after linear layer" + ) + + args = parser.parse_args() + + compare_lora( + args.bsz, + args.seq_len, + args.in_dim, + args.out_dim, + args.rank, + args.alpha, + args.dropout, + ) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_lora_attention.py b/training/tests/torchtune/models/llama2/scripts/compare_lora_attention.py new file mode 100644 index 0000000..f4dc9b8 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_lora_attention.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import List + +import torch + +from tests.test_utils import fixed_init_model + +from torch import nn + +from torchtune.models.llama2._lora_llama2_builders import _lora_llama_self_attention +from torchtune.modules import CausalSelfAttention, KVCache, RotaryPositionalEmbeddings + +try: + from peft import inject_adapter_in_model, LoraConfig +except: + raise ImportError("Must have peft installed to run this comparison script") + + +def compare_lora_attention( + bsz: int, + seq_len: int, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + lora_modules: List[str], + lora_rank: int, + lora_alpha: float, +) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(16) + + # generate input tensor used by both implementations + x = torch.randn(bsz, seq_len, embed_dim) + + # Our implementation + lora_llama_attn = _lora_llama_self_attention( + lora_modules=lora_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + fixed_init_model(lora_llama_attn) + + with torch.no_grad(): + out = lora_llama_attn(x) + + batch_size = None + attn_dropout = 0.0 + # Reference implementation: wrap our native causal self-attention with PEFT LoRAConfig + # Copy-pasted from llama2.py + # https://github.com/pytorch/torchtune/blob/e983194629d7f093257225dafb7cbc4e46505cc8/torchtune/models/llama2.py#L88-L114 + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + kv_cache = ( + KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + n_kv_heads=num_heads, + head_dim=head_dim, + ) + if batch_size is not None + else None + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + llama_attn_ref = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=kv_cache, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + lora_config_ref = LoraConfig( + lora_alpha=lora_alpha, + lora_dropout=0.0, + r=lora_rank, + bias="none", + target_modules=lora_modules, + ) + + lora_llama_attn_ref = inject_adapter_in_model(lora_config_ref, llama_attn_ref) + + all_keys = ["q_proj", "k_proj", "v_proj", "output_proj"] + + mapped_sd = {} + for key in all_keys: + if key in lora_modules: + mapped_sd[f"{key}.base_layer.weight"] = lora_llama_attn.state_dict()[ + f"{key}.weight" + ] + mapped_sd[f"{key}.lora_A.default.weight"] = lora_llama_attn.state_dict()[ + f"{key}.lora_a.weight" + ] + mapped_sd[f"{key}.lora_B.default.weight"] = lora_llama_attn.state_dict()[ + f"{key}.lora_b.weight" + ] + else: + mapped_sd[f"{key}.weight"] = lora_llama_attn.state_dict()[f"{key}.weight"] + + lora_llama_attn_ref.load_state_dict(mapped_sd) + + with torch.no_grad(): + out_ref = lora_llama_attn_ref(x) + + print(lora_modules, out.mean(), out_ref.mean(), out.shape, out_ref.shape) + + # output tensors should be similar + torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-3) + + +if __name__ == "__main__": + test_cases = [ + ["q_proj", "v_proj"], + ["q_proj", "k_proj", "v_proj", "output_proj"], + ["k_proj"], + ] + for lora_modules in test_cases: + compare_lora_attention( + bsz=2, + seq_len=32, + embed_dim=64, + num_heads=4, + num_kv_heads=2, + max_seq_len=64, + lora_modules=lora_modules, + lora_rank=4, + lora_alpha=1.0, + ) diff --git a/training/tests/torchtune/models/llama2/scripts/compare_lora_llama2.py b/training/tests/torchtune/models/llama2/scripts/compare_lora_llama2.py new file mode 100644 index 0000000..cf3e532 --- /dev/null +++ b/training/tests/torchtune/models/llama2/scripts/compare_lora_llama2.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import List + +import torch + +from tests.test_utils import fixed_init_model + +from torch import nn + +from torchtune.models.llama2 import get_lora_module_names, llama2, lora_llama2 + +try: + from peft import inject_adapter_in_model, LoraConfig +except: + raise ImportError("Must have peft installed to run this comparison script") + + +def compare_lora_llama2( + bsz: int, + seq_len: int, + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + lora_modules: List[str], + lora_in_mlp: bool, + lora_in_output: bool, + lora_rank: int, + lora_alpha: float, +) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(16) + + # generate input tensor used by both implementations + x = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len)) + + # Our implementation + lora_llama = lora_llama2( + lora_attn_modules=lora_modules, + apply_lora_to_mlp=lora_in_mlp, + apply_lora_to_output=lora_in_output, + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + max_seq_len=max_seq_len, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + # This is to make final outputs less trivial + lora_llama.norm = nn.Identity() + fixed_init_model(lora_llama) + + with torch.no_grad(): + out = lora_llama(x) + + # Reference implementation: wrap our native llama2 with PEFT LoRAConfig + llama_ref = llama2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + max_seq_len=max_seq_len, + ) + + peft_lora_modules = get_lora_module_names(lora_modules, lora_in_mlp, lora_in_output) + + lora_config_ref = LoraConfig( + lora_alpha=lora_alpha, + lora_dropout=0.0, + r=lora_rank, + bias="none", + target_modules=peft_lora_modules, + ) + + lora_llama_ref = inject_adapter_in_model(lora_config_ref, llama_ref) + lora_llama_ref.norm = nn.Identity() + + mapped_sd = {} + for k, v in lora_llama.state_dict().items(): + new_k = k.replace("lora_a", "lora_A.default").replace( + "lora_b", "lora_B.default" + ) + for attn_module in lora_modules: + if attn_module in new_k: + new_k = new_k.replace( + attn_module + ".weight", attn_module + ".base_layer.weight" + ) + if lora_in_mlp and any([f"mlp.w{i}.weight" in new_k for i in range(1, 4)]): + new_k = new_k.replace(".weight", ".base_layer.weight") + + if lora_in_output and "output.weight" in new_k: + new_k = new_k.replace(".weight", ".base_layer.weight") + + mapped_sd[new_k] = v + + lora_llama_ref.load_state_dict(mapped_sd) + + with torch.no_grad(): + out_ref = lora_llama_ref(x) + + print( + lora_modules, + lora_in_mlp, + lora_in_output, + out.mean(), + out_ref.mean(), + out.shape, + out_ref.shape, + ) + + # output tensors should be similar + torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-3) + + +if __name__ == "__main__": + test_cases = [ + (["q_proj", "v_proj"], False, False), + (["q_proj", "k_proj", "v_proj", "output_proj"], True, False), + (["k_proj"], True, True), + ] + for lora_modules, lora_in_mlp, lora_in_output in test_cases: + compare_lora_llama2( + bsz=2, + seq_len=32, + vocab_size=50, + num_layers=3, + num_heads=4, + num_kv_heads=2, + embed_dim=64, + max_seq_len=64, + lora_modules=lora_modules, + lora_in_mlp=lora_in_mlp, + lora_in_output=lora_in_output, + lora_rank=4, + lora_alpha=1.0, + ) diff --git a/training/tests/torchtune/models/llama2/test_llama2_prompt_template.py b/training/tests/torchtune/models/llama2/test_llama2_prompt_template.py new file mode 100644 index 0000000..4095c9a --- /dev/null +++ b/training/tests/torchtune/models/llama2/test_llama2_prompt_template.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE +from torchtune.data import Message +from torchtune.models.llama2 import Llama2ChatTemplate + + +class TestLlama2ChatTemplate: + expected_dialogue = [ + Message( + role="user", + content="[INST] <>\nYou are an AI assistant. User will you give you a task. " + "Your goal is to complete the task as faithfully as you can. While performing " + "the task think step-by-step and justify your steps.\n<>\n\nPlease " + "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old " + "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? " + "How about on an icy road? Well one father in Russia did just that, and recorded " + "the entire thing. To her credit, the child seemed to be doing a great job. " + "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n" + "Summary: [/INST] ", + ), + Message( + role="assistant", + content="A father in Russia allowed his 8-year-old child to drive his car on an " + "icy road and recorded the event. The child appeared to be handling the situation well, " + "showcasing their driving skills despite the challenging conditions.", + ), + ] + + def test_call(self): + actual = Llama2ChatTemplate()(MESSAGE_SAMPLE) + assert_dialogue_equal(actual, self.expected_dialogue) diff --git a/training/tests/torchtune/models/llama2/test_llama2_tokenizer.py b/training/tests/torchtune/models/llama2/test_llama2_tokenizer.py new file mode 100644 index 0000000..9894686 --- /dev/null +++ b/training/tests/torchtune/models/llama2/test_llama2_tokenizer.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.data import Message +from torchtune.models.llama2 import llama2_tokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestLlama2Tokenizer: + @pytest.fixture + def tokenizer(self): + # m.model is a pretrained Sentencepiece model using the following command: + # spm.SentencePieceTrainer.train('--input= --model_prefix=m --vocab_size=2000') + return llama2_tokenizer(str(ASSETS / "m.model")) + + def test_tokenize_messages(self, tokenizer): + messages = [ + Message( + role="user", + content="Below is an instruction that describes a task. Write a response " + "that appropriately completes the request.\n\n### Instruction:\nGenerate " + "a realistic dating profile bio.\n\n### Response:\n", + masked=True, + ), + Message( + role="assistant", + content="I'm an outgoing and friendly person who loves spending time with " + "friends and family. I'm also a big-time foodie and love trying out new " + "restaurants and different cuisines. I'm a big fan of the arts and enjoy " + "going to museums and galleries. I'm looking for someone who shares my " + "interest in exploring new places, as well as someone who appreciates a " + "good conversation over coffee.", + ), + ] + tokens, mask = tokenizer.tokenize_messages(messages) + expected_tokens = [ + 1, + 323, + 418, + 202, + 31, + 128, + 15, + 120, + 47, + 88, + 584, + 23, + 1665, + 182, + 9, + 434, + 295, + 85, + 4, + 780, + 47, + 636, + 9, + 1094, + 213, + 23, + 9, + 69, + 69, + 164, + 1153, + 299, + 35, + 961, + 132, + 237, + 7, + 5, + 761, + 4, + 12, + 0, + 313, + 120, + 47, + 88, + 584, + 166, + 493, + 171, + 54, + 299, + 9, + 906, + 244, + 19, + 186, + 767, + 303, + 671, + 92, + 209, + 24, + 190, + 52, + 38, + 4, + 12, + 0, + 1243, + 7, + 69, + 135, + 213, + 166, + 6, + 21, + 45, + 128, + 71, + 58, + 38, + 14, + 10, + 652, + 35, + 462, + 101, + 1306, + 7, + 341, + 171, + 20, + 14, + 127, + 26, + 652, + 7, + 10, + 1268, + 4, + 6, + 21, + 45, + 591, + 9, + 566, + 22, + 994, + 913, + 38, + 20, + 52, + 24, + 10, + 1306, + 734, + 14, + 71, + 365, + 1382, + 7, + 10, + 801, + 105, + 88, + 244, + 985, + 7, + 4, + 6, + 21, + 45, + 9, + 566, + 126, + 180, + 11, + 5, + 1137, + 7, + 10, + 1089, + 151, + 8, + 1156, + 213, + 342, + 7, + 10, + 384, + 104, + 54, + 470, + 4, + 6, + 21, + 45, + 287, + 14, + 33, + 125, + 135, + 24, + 101, + 512, + 66, + 7, + 28, + 822, + 15, + 542, + 69, + 59, + 110, + 14, + 365, + 229, + 7, + 3, + 36, + 267, + 36, + 125, + 135, + 24, + 101, + 1503, + 182, + 9, + 222, + 1661, + 191, + 332, + 92, + 92, + 24, + 24, + 4, + 2, + ] + expected_mask = [True] * 75 + [False] * 125 + assert expected_tokens == tokens + assert expected_mask == mask diff --git a/training/tests/torchtune/models/llama2/test_lora_llama2.py b/training/tests/torchtune/models/llama2/test_lora_llama2.py new file mode 100644 index 0000000..4dcdaa0 --- /dev/null +++ b/training/tests/torchtune/models/llama2/test_lora_llama2.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy + +import pytest +import torch + +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn +from torchao.dtypes.nf4tensor import NF4Tensor +from torchtune import utils +from torchtune.models.llama2 import llama2, lora_llama2 +from torchtune.models.llama2._component_builders import lora_llama2_self_attention +from torchtune.modules.low_precision import FrozenNF4Linear +from torchtune.modules.peft import LoRALinear +from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt +from torchtune.utils.seed import set_seed + +RANK = 4 +ALPHA = 1.0 +BSZ = 2 +SEQ_LEN = 32 +EMBED_DIM = 64 +NUM_HEADS = 4 +NUM_KV_HEADS = 2 +MAX_SEQ_LEN = 64 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestLoRALlamaSelfAttention: + @pytest.fixture + def inputs(self) -> torch.Tensor: + inputs = torch.randn(BSZ, SEQ_LEN, EMBED_DIM) + return inputs + + def get_lora_llama_self_attention(self, lora_modules): + lora_llama_sa = lora_llama2_self_attention( + lora_modules=lora_modules, + embed_dim=EMBED_DIM, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + max_seq_len=MAX_SEQ_LEN, + lora_rank=RANK, + lora_alpha=ALPHA, + ) + fixed_init_model(lora_llama_sa) + return lora_llama_sa + + def test_empty_lora_modules(self): + with pytest.raises(ValueError, match="Must pass one or more of"): + _ = self.get_lora_llama_self_attention([]) + + @pytest.mark.parametrize( + "lora_modules, expected", + [ + (["q_proj", "v_proj"], torch.tensor(51.3152)), + (["q_proj", "k_proj", "v_proj", "output_proj"], torch.tensor(79.8887)), + (["k_proj"], torch.tensor(45.9261)), + ], + ) + def test_forward(self, inputs, lora_modules, expected): + lora_llama_sa = self.get_lora_llama_self_attention(lora_modules) + actual = lora_llama_sa(inputs) + assert_expected(actual.shape, (BSZ, SEQ_LEN, EMBED_DIM)) + assert_expected(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + +class TestLoRALlama2: + @pytest.fixture + def vocab_size(self): + return 50 + + @pytest.fixture + def inputs(self, vocab_size): + return torch.randint(low=0, high=vocab_size, size=(BSZ, SEQ_LEN)) + + def get_lora_llama2( + self, + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + vocab_size, + reset_norm=True, + quantize_base=False, + embed_dim=EMBED_DIM, + dtype=None, + ): + num_layers = 3 + model = lora_llama2( + lora_attn_modules=lora_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + max_seq_len=MAX_SEQ_LEN, + lora_rank=RANK, + lora_alpha=ALPHA, + quantize_base=quantize_base, + ) + # To make final outputs less trivial + if reset_norm: + model.norm = nn.Identity() + + # dtype=None means to just read dtype from parameters + # in the model. This dtype is set explicitly to bf16 currently + # when initializing QLoRA models, as ops such as `arange` aren't + # yet supported with the actual nf4 tensor dtype yet. + fixed_init_model(model, dtype=dtype) + + return model + + def get_ref_llama2(self, vocab_size, embed_dim=EMBED_DIM): + num_layers = 3 + model = llama2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + max_seq_len=MAX_SEQ_LEN, + ) + return model + + @pytest.mark.parametrize( + "lora_modules, apply_lora_to_mlp, apply_lora_to_output, expected", + [ + (["q_proj", "v_proj"], False, False, torch.tensor(5638859.0)), + ( + ["q_proj", "k_proj", "v_proj", "output_proj"], + True, + False, + torch.tensor(21187608.0), + ), + (["k_proj"], True, True, torch.tensor(32438764.0)), + ], + ) + def test_forward( + self, + vocab_size, + inputs, + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + expected, + ): + model = self.get_lora_llama2( + lora_modules, apply_lora_to_mlp, apply_lora_to_output, vocab_size + ) + actual = model(inputs) + assert_expected(actual.shape, (BSZ, SEQ_LEN, vocab_size)) + assert_expected(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + @pytest.mark.parametrize( + "lora_modules, apply_lora_to_mlp, apply_lora_to_output", + [ + (["q_proj", "v_proj"], True, False), + (["q_proj", "k_proj", "v_proj", "output_proj"], False, False), + (["k_proj"], True, True), + ], + ) + def test_lora_llama2_state_dict_parity( + self, lora_modules, apply_lora_to_mlp, apply_lora_to_output, vocab_size + ): + lora_llama = self.get_lora_llama2( + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + vocab_size, + reset_norm=False, + ) + ref_llama = self.get_ref_llama2(vocab_size) + # Ensure ref_llama state_dict can be loaded into lora_llama with only "lora" + # keys missing. + ref_llama_state_dict = ref_llama.state_dict() + missing, unexpected = lora_llama.load_state_dict( + ref_llama_state_dict, strict=False + ) + assert not unexpected + assert all(["lora" in key for key in missing]) + + def test_qlora_linear_quantize_base(self): + model = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + # quantize_base + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=torch.bfloat16, + ) + for module in model.modules(): + if isinstance(module, LoRALinear): + assert module._quantize_base + + def test_qlora_linear_quantize_base_weights(self): + # this test checks that modules that don't have LoRA applied to them + # have their base weights quantized + model = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj"], + apply_lora_to_mlp=True, + # quantize_base + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=torch.bfloat16, + ) + for name, module in model.named_modules(): + if isinstance(module, LoRALinear): + assert module._quantize_base + elif name in ["k_proj", "output_proj"]: + assert isinstance(module, FrozenNF4Linear) + assert isinstance(module.weight, NF4Tensor) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_llama2_parity(self, dtype, inputs): + with utils.set_default_dtype(dtype): + model_ref = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=False, + embed_dim=512, + dtype=dtype, + ) + qlora = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=dtype, + ) + qlora_sd = qlora.state_dict() + model_ref.load_state_dict(qlora_sd) + # Forward pass of model_ref and qlora should be the same, as QLoRA linear layers should use + # a special linear operator that runs the compute in bf16, but only saves the 4 bit tensors + # for backward. + ref_output = model_ref(inputs) + output = qlora(inputs) + torch.testing.assert_close(ref_output, output) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_llama2_state_dict(self, dtype): + with utils.set_default_dtype(dtype): + model_ref = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=False, + embed_dim=512, + dtype=dtype, + ) + high_prec_sd = model_ref.state_dict() + for v in high_prec_sd.values(): + assert v.dtype == dtype + + # ensure quantized LoRA can load a bf16 state_dict + qlora = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=dtype, + ) + qlora.load_state_dict(high_prec_sd) + # LoRALinear base weights should be nf4 still + for module in qlora.modules(): + if isinstance(module, LoRALinear): + assert isinstance(module.weight, NF4Tensor) + # saved state_dict should have bf16 weights. + qlora_sd = qlora.state_dict() + for v in qlora_sd.values(): + assert v.dtype == dtype + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_llama2_merged_state_dict(self, dtype): + with utils.set_default_dtype(dtype): + qlora = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=dtype, + reset_norm=False, # to ensure norm.scale key exists + ) + + qlora_sd = qlora.state_dict() + # Ensure checkpoint merging produces bf16 tensors + merged_ckpt = get_merged_lora_ckpt(deepcopy(qlora_sd), rank=RANK, alpha=ALPHA) + for v in merged_ckpt.values(): + # paranoid check for both, as NF4Tensor had issue where NF4Tensor.dtype would return bf16 + assert not isinstance(v, NF4Tensor) + assert v.dtype == dtype + + # Ensure checkpoint can be loaded into non-LoRA model + with utils.set_default_dtype(dtype): + llama2 = self.get_ref_llama2(vocab_size=50, embed_dim=512) + + llama2.load_state_dict(merged_ckpt) diff --git a/training/tests/torchtune/models/llama3/test_llama3.py b/training/tests/torchtune/models/llama3/test_llama3.py new file mode 100644 index 0000000..190eaf4 --- /dev/null +++ b/training/tests/torchtune/models/llama3/test_llama3.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.llama3 import llama3 +from torchtune.utils.seed import set_seed + +EMBED_DIM = 128 +NUM_LAYERS = 4 +NUM_HEADS = 16 +NUM_KV_HEADS = 8 +VOCAB_SIZE = 32000 +MAX_SEQ_LEN = 2048 +BSZ = 2 +SEQ_LEN = 100 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestLlama3: + @pytest.fixture + def inputs(self): + return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN)) + + def test_forward(self, inputs): + model = llama3( + vocab_size=VOCAB_SIZE, + num_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + fixed_init_model(model, min_val=-0.25, max_val=0.5) + actual = model(inputs) + expected = torch.tensor(3.9763) + assert actual.shape == (BSZ, SEQ_LEN, VOCAB_SIZE) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4) diff --git a/training/tests/torchtune/models/llama3/test_llama3_tokenizer.py b/training/tests/torchtune/models/llama3/test_llama3_tokenizer.py new file mode 100644 index 0000000..f1d2ad9 --- /dev/null +++ b/training/tests/torchtune/models/llama3/test_llama3_tokenizer.py @@ -0,0 +1,414 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.data._messages import Message +from torchtune.models.llama3 import llama3_tokenizer, Llama3Tokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestLlama3Tokenizer: + @pytest.fixture + def tokenizer(self): + # Pretrained tiktoken model generated via the script in + # https://gist.github.com/ebsmothers/54b133dd87db6679b14318545aaa2de4 + return llama3_tokenizer( + path=str(ASSETS / "tiktoken_small.model"), + ) + + @pytest.fixture + def user_text_a(self): + return "I can see the sun. " + + @pytest.fixture + def user_text_b(self): + return "But even if I cannot see the sun, I know that it exists." + + @pytest.fixture + def assistant_text(self): + return "And to know that the sun is there - that is living." + + @pytest.fixture + def user_text_message(self, user_text_a, user_text_b): + message = Message( + role="user", + content=user_text_a + user_text_b, + masked=True, + eot=True, + ) + expected_tokens = [ + 128000, + 128006, + 477, + 273, + 128007, + 10, + 10, + 73, + 503, + 654, + 262, + 376, + 110, + 46, + 690, + 720, + 428, + 270, + 1119, + 654, + 262, + 376, + 110, + 44, + 270, + 686, + 334, + 312, + 522, + 511, + 115, + 46, + 128009, + ] + return message, expected_tokens + + @pytest.fixture + def assistant_text_message(self, assistant_text): + message = Message( + role="assistant", + content=assistant_text, + masked=False, + eot=True, + ) + expected_tokens = [ + 128006, + 520, + 511, + 446, + 128007, + 10, + 10, + 65, + 269, + 277, + 686, + 334, + 262, + 376, + 110, + 351, + 443, + 32, + 45, + 334, + 351, + 1955, + 46, + 128009, + 128001, + ] + return message, expected_tokens + + @pytest.fixture + def user_image_text_message(self, user_text_a, user_text_b): + message = Message( + role="user", + content=[ + {"type": "image"}, + {"type": "text", "content": user_text_a + user_text_b}, + ], + masked=True, + eot=True, + ) + expected_tokens = [ + 128000, + 128006, + 477, + 273, + 128007, + 10, + 10, + 128011, + 73, + 503, + 654, + 262, + 376, + 110, + 46, + 690, + 720, + 428, + 270, + 1119, + 654, + 262, + 376, + 110, + 44, + 270, + 686, + 334, + 312, + 522, + 511, + 115, + 46, + 128009, + ] + return message, expected_tokens + + @pytest.fixture + def user_interleaved_image_text_message(self, user_text_a, user_text_b): + message = Message( + role="user", + content=[ + {"type": "image"}, + {"type": "text", "content": user_text_a}, + {"type": "image"}, + {"type": "text", "content": user_text_b}, + ], + masked=True, + eot=True, + ) + expected_tokens = [ + 128000, + 128006, + 477, + 273, + 128007, + 10, + 10, + 128011, + 73, + 503, + 654, + 262, + 376, + 110, + 46, + 128011, + 1542, + 720, + 428, + 270, + 1119, + 654, + 262, + 376, + 110, + 44, + 270, + 686, + 334, + 312, + 522, + 511, + 115, + 46, + 128009, + ] + return message, expected_tokens + + @pytest.fixture + def assistant_tool_message(self): + message = Message( + role="assistant", + content=[ + {"type": "text", "content": "locate_sun(radius=100_000_000)"}, + ], + masked=False, + ipython=True, + eot=False, + ) + expected_tokens = [ + 128006, + 520, + 511, + 446, + 128007, + 10, + 10, + 128010, + 525, + 99, + 534, + 95, + 115, + 433, + 40, + 114, + 338, + 105, + 477, + 61, + 49, + 1635, + 95, + 1635, + 48, + 95, + 1635, + 48, + 41, + 128008, + ] + return message, expected_tokens + + @pytest.fixture + def ipython_message(self): + message = Message( + role="ipython", + content=[ + {"type": "text", "content": '{"content": True}'}, + ], + masked=True, + eot=False, + ) + expected_tokens = [ + 128006, + 1558, + 121, + 483, + 279, + 128007, + 10, + 10, + 123, + 34, + 99, + 957, + 317, + 34, + 58, + 323, + 114, + 979, + 125, + 128008, + ] + return message, expected_tokens + + def test_token_ids(self, tokenizer): + assert tokenizer.bos_id == 128000 + assert tokenizer.eos_id == 128001 + assert tokenizer.pad_id == 128004 + assert tokenizer.step_id == 128005 + assert tokenizer.start_header_id == 128006 + assert tokenizer.end_header_id == 128007 + assert tokenizer.eom_id == 128008 + assert tokenizer.eot_id == 128009 + assert tokenizer.python_tag == 128010 + assert tokenizer.image_id == 128011 + + def test_tokenizer_vocab_size(self, tokenizer): + assert tokenizer.base_vocab_size == 2000 + assert tokenizer.vocab_size == 128256 + + def test_tokenize_text_messages( + self, tokenizer, user_text_message, assistant_text_message + ): + text_messages = [user_text_message[0], assistant_text_message[0]] + expected_tokens = user_text_message[1] + assistant_text_message[1] + expected_mask = ( + [True] * len(user_text_message[1]) + + [False] * (len(assistant_text_message[1]) - 1) + + [True] + ) + tokens, mask = tokenizer.tokenize_messages(text_messages) + assert tokens == expected_tokens + assert mask == expected_mask + + def test_tokenize_image_and_text_messages( + self, tokenizer, user_image_text_message, assistant_text_message + ): + image_and_text_messages = [ + user_image_text_message[0], + assistant_text_message[0], + ] + expected_tokens = user_image_text_message[1] + assistant_text_message[1] + expected_mask = ( + [True] * len(user_image_text_message[1]) + + [False] * (len(assistant_text_message[1]) - 1) + + [True] + ) + tokens, mask = tokenizer.tokenize_messages(image_and_text_messages) + assert tokens == expected_tokens + assert mask == expected_mask + + def test_tokenize_interleaved_image_and_text_messages( + self, + tokenizer, + user_interleaved_image_text_message, + assistant_text_message, + ): + interleaved_image_and_text_messages = [ + user_interleaved_image_text_message[0], + assistant_text_message[0], + ] + expected_tokens = ( + user_interleaved_image_text_message[1] + assistant_text_message[1] + ) + expected_mask = ( + [True] * len(user_interleaved_image_text_message[1]) + + [False] * (len(assistant_text_message[1]) - 1) + + [True] + ) + tokens, mask = tokenizer.tokenize_messages(interleaved_image_and_text_messages) + assert tokens == expected_tokens + assert mask == expected_mask + + def test_tokenize_tool_call_messages( + self, + tokenizer, + user_text_message, + assistant_tool_message, + ipython_message, + assistant_text_message, + ): + tool_call_messages = [ + user_text_message[0], + assistant_tool_message[0], + ipython_message[0], + assistant_text_message[0], + ] + expected_tokens = ( + user_text_message[1] + + assistant_tool_message[1] + + ipython_message[1] + + assistant_text_message[1] + ) + expected_mask = ( + [True] * len(user_text_message[1]) + + [False] * len(assistant_tool_message[1]) + + [True] * len(ipython_message[1]) + + [False] * (len(assistant_text_message[1]) - 1) + + [True] + ) + tokens, mask = tokenizer.tokenize_messages(tool_call_messages) + assert tokens == expected_tokens + assert mask == expected_mask + + def test_validate_special_tokens(self): + with pytest.raises( + ValueError, match="<|begin_of_text|> missing from special_tokens" + ): + _ = Llama3Tokenizer( + path=str(ASSETS / "tiktoken_small.model"), + # Same as LLAMA3_SPECIAL_TOKENS but one missing + special_tokens={ + "<|end_of_text|>": 128001, + "<|start_header_id|>": 128006, + "<|end_header_id|>": 128007, + "<|eot_id|>": 128009, + "<|eom_id|>": 128008, + "<|python_tag|>": 128255, + }, + ) diff --git a/training/tests/torchtune/models/llama3_1/test_position_embeddings.py b/training/tests/torchtune/models/llama3_1/test_position_embeddings.py new file mode 100644 index 0000000..c4e7dad --- /dev/null +++ b/training/tests/torchtune/models/llama3_1/test_position_embeddings.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from tests.test_utils import assert_expected +from torch import tensor + +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE + +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestLlama3ScaledRoPE: + """ + Class for testing our Scaled RoPE for LLama3.1 (RoPE) + implementation. The expected tensors are computed from the + reference implementation here: + https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L272 + + The expected values are computed using the following code: + https://gist.github.com/joecummings/4f1331a9c1e5aa15bad1641acb74fe0e + """ + + EXPECTED_FREQS_CIS_MEAN = tensor(0.1738) + EXPECTED_FREQS_CIS_SUM = tensor(91141.7656) + EXPECTED_FREQS_CIS_MAX = tensor(1.0) + + EXPECTED_X_OUT_MEAN = tensor(-2.4781e-06) + EXPECTED_X_OUT_SUM = tensor(-83.1523) + EXPECTED_X_OUT_MAX = tensor(5.4625) + + @pytest.fixture + def input_params(self): + bsz = 4 + num_heads = 32 + embed_dim = 4096 + head_dim = embed_dim // num_heads + seq_len = 2048 + max_seq_len = 4096 + return bsz, num_heads, head_dim, seq_len, max_seq_len + + @pytest.fixture + def input(self, input_params) -> tensor: + bsz, num_heads, head_dim, seq_len, _ = input_params + return torch.randn(bsz, seq_len, num_heads, head_dim) + + @pytest.fixture + def rope(self, input_params) -> Llama3ScaledRoPE: + _, _, head_dim, _, max_seq_len = input_params + return Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len) + + def test_cache_equality(self, input, rope) -> None: + # Have to explicitly call _rope_init() to initialize theta matrix + rope._rope_init() + cache = rope.cache + + assert_expected(cache.mean(), self.EXPECTED_FREQS_CIS_MEAN, atol=1e-4) + assert_expected(cache.sum(), self.EXPECTED_FREQS_CIS_SUM, atol=1e-4) + assert_expected(cache.max(), self.EXPECTED_FREQS_CIS_MAX) + + def test_forward(self, input, rope) -> None: + x_out = rope(input) + + # check the numerics of the computed tensor + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_forward_with_curr_pos(self, input, rope) -> None: + ( + _, + seq_len, + _, + _, + ) = input.shape + x_out = rope(input, input_pos=torch.arange(seq_len)) + + # these values should be exactly the same as test_forward + # since in this case input_pos covers the entire input + # sequence. This tests that input_pos works as expected i.e. + # extracts the embeddings for the relevant positions + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_forward_with_2d_pos_ids(self, input, rope) -> None: + """ + Use input_pos to indicate positions of each token relative to its sequence + when sample is packed. + """ + ( + bsz, + seq_len, + _, + _, + ) = input.shape + x_out = rope( + input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len) + ) + + # these values should be exactly the same as test_forward + # AND test_forward_with_current_pos. In this case input_pos + # covers the entire batch dim and is defined for each sample separately. + # This tests that input_pos works as expected i.e. + # extracts the embeddings for the relevant positions for each sample + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_rope_init_meta_device(self, input_params): + _, _, head_dim, _, max_seq_len = input_params + rope_on_device = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len) + with torch.device("meta"): + meta_rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len) + + meta_rope._rope_init() + for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()): + torch.testing.assert_close(p1, p2) + + # Assert meta_rope cache is no longer on meta device + assert meta_rope.cache.device != torch.device("meta") diff --git a/training/tests/torchtune/models/mistral/scripts/README.md b/training/tests/torchtune/models/mistral/scripts/README.md new file mode 100644 index 0000000..6b61ffa --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/README.md @@ -0,0 +1,13 @@ +## Verifying correctness +This directory compares the current implementation of `mistral` to the reference implementation at https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py. Additionally, `torchtune.models.mistral._component_builders.mistral_mlp` is compared in `tests.torchtune.models.mistral.scripts.compare_feed_forward.py` + +Since `torchtune.models.mistral` shares nearly all components with `torchtune.models.llama2`, please see `tests.torchtune.models.llama2.scripts` for comparison scripts for individual components. + +## Running the scripts + +You can run the scripts using the following command as an example. +Each script should print out the value being used in the associated unit tests. + +``` +python3 -m tests.torchtune.models.mistral.scripts.compare_mistral +``` diff --git a/training/tests/torchtune/models/mistral/scripts/__init__.py b/training/tests/torchtune/models/mistral/scripts/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/models/mistral/scripts/compare_feed_forward.py b/training/tests/torchtune/models/mistral/scripts/compare_feed_forward.py new file mode 100644 index 0000000..baae85b --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/compare_feed_forward.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from tests.test_utils import fixed_init_model +from tests.torchtune.models.mistral.scripts.mistral_reference import FeedForward + +from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig + +from torchtune.models.mistral._component_builders import mistral_mlp + + +def compare_feed_forward(embed_dim: int, intermediate_dim: int) -> None: + + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(MistralTestConfig.SEED) + + # generate input tensor used by both implementations + input_t = torch.randn(1, embed_dim) + + # reference implementation + ff_ref = FeedForward(dim=embed_dim, hidden_dim=intermediate_dim) + fixed_init_model(ff_ref) + + with torch.no_grad(): + ff_out_ref = ff_ref(input_t) + + ff = mistral_mlp(embed_dim, intermediate_dim) + fixed_init_model(ff) + + with torch.no_grad(): + ff_out = ff(input_t) + + torch.testing.assert_close(ff_out, ff_out_ref, atol=1e-5, rtol=1e-5) + print(f"ff_out.mean(): {ff_out.mean()}") + print(f"ff_out.max(): {ff_out.max()}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare FeedForward implementations") + parser.add_argument( + "--embed_dim", + type=int, + default=MistralTestConfig.EMBED_DIM, + help="Embedding dimension for self-attention", + ) + parser.add_argument( + "--intermediate_dim", + type=int, + default=MistralTestConfig.INTERMEDIATE_DIM, + help="Intermediate dimension for MLP", + ) + + args = parser.parse_args() + + compare_feed_forward(args.embed_dim, args.intermediate_dim) diff --git a/training/tests/torchtune/models/mistral/scripts/compare_mistral.py b/training/tests/torchtune/models/mistral/scripts/compare_mistral.py new file mode 100644 index 0000000..a861a2f --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/compare_mistral.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from tests.test_utils import fixed_init_model +from tests.torchtune.models.mistral.scripts.mistral_reference import Transformer +from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig + +from torchtune.models.mistral import mistral + + +def compare_decoder( + bsz: int, + vocab_size: int, + seq_len: int, + embed_dim: int, + intermediate_dim: int, + n_layers: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + rope_base: int, + norm_eps: float, +) -> None: + # make sure we have the right seed for generating outputs + # this should match up the seed value set in the corresponding + # unit test + torch.manual_seed(MistralTestConfig.SEED) + + head_dim = embed_dim // num_heads + + # generate input tensor used by both implementations + x_input = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len)) + + # current implementation; initialize with constant to compare outputs + mistral_model = mistral( + vocab_size=vocab_size, + embed_dim=embed_dim, + num_layers=n_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + intermediate_dim=intermediate_dim, + norm_eps=norm_eps, + rope_base=rope_base, + ) + fixed_init_model(mistral_model) + + with torch.no_grad(): + mistral_model_out = mistral_model(x_input) + + # initialize reference implementation with constant weights + ref_mistral_model = Transformer( + vocab_size=vocab_size, + n_layers=n_layers, + n_heads=num_heads, + head_dim=head_dim, + dim=embed_dim, + n_kv_heads=num_kv_heads, + hidden_dim=intermediate_dim, + max_seq_len=max_seq_len, + rope_base=rope_base, + norm_eps=norm_eps, + ) + + mapped_sd = {} + for k, v in mistral_model.state_dict().items(): + new_k = k.replace("attn", "attention") + new_k = ( + new_k.replace("q_proj", "wq") + .replace("k_proj", "wk") + .replace("v_proj", "wv") + .replace("output_proj", "wo") + ) + new_k = new_k.replace("mlp", "feed_forward") + new_k = new_k.replace("feed_forward_norm.scale", "ffn_norm.weight") + new_k = new_k.replace("sa_norm.scale", "attention_norm.weight") + + new_k = new_k.replace("norm.scale", "norm.weight") + mapped_sd[new_k] = v + + ref_mistral_model.load_state_dict(mapped_sd) + + with torch.no_grad(): + red_mistral_model_out = ref_mistral_model(x_input, torch.arange(seq_len)) + + # # value: torch.tensor(18.2749) + print(f"mistral_model_out.mean(): {mistral_model_out.mean()}") + print(f"red_mistral_model_out.mean(): {red_mistral_model_out.mean()}") + + torch.testing.assert_close( + mistral_model_out, red_mistral_model_out, atol=1e-2, rtol=1e-2 + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare Decoder implementations") + parser.add_argument( + "--bsz", + type=int, + default=MistralTestConfig.BSZ, + help="Batch size of input tensor", + ) + parser.add_argument( + "--seq_len", + type=int, + default=MistralTestConfig.SEQ_LEN, + help="input sequence length", + ) + parser.add_argument( + "--vocab_size", + type=int, + default=MistralTestConfig.VOCAB_SIZE, + help="vocab size", + ) + parser.add_argument( + "--embed_dim", + type=int, + default=MistralTestConfig.EMBED_DIM, + help="Embedding dimension used to compute the dim for RopE", + ) + parser.add_argument( + "--intermediate_dim", + type=int, + default=MistralTestConfig.INTERMEDIATE_DIM, + help="Intermediate dimension for MLP", + ) + parser.add_argument( + "--num_layers", + type=int, + default=MistralTestConfig.NUM_LAYERS, + help="number of transformer layers", + ) + parser.add_argument( + "--num_heads", + type=int, + default=MistralTestConfig.NUM_HEADS, + help="Number of heads in the attention layer", + ) + parser.add_argument( + "--num_kv_heads", + type=int, + default=MistralTestConfig.NUM_KV_HEADS, + help="Number of key/value heads in the attention layer", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=MistralTestConfig.MAX_SEQ_LEN, + help="max sequence length", + ) + parser.add_argument( + "--norm_eps", + type=float, + default=MistralTestConfig.NORM_EPS, + help="RMSNorm epsilon", + ) + parser.add_argument( + "--rope_base", + type=float, + default=MistralTestConfig.ROPE_BASE, + help="Base for the rotary positional embeddings", + ) + args = parser.parse_args() + + compare_decoder( + args.bsz, + args.vocab_size, + args.seq_len, + args.embed_dim, + args.intermediate_dim, + args.num_layers, + args.num_heads, + args.num_kv_heads, + args.max_seq_len, + args.rope_base, + args.norm_eps, + ) diff --git a/training/tests/torchtune/models/mistral/scripts/compare_mistral_classifier.py b/training/tests/torchtune/models/mistral/scripts/compare_mistral_classifier.py new file mode 100644 index 0000000..11b608b --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/compare_mistral_classifier.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from tests.test_utils import fixed_init_model +from torch import nn +from torchtune.models.mistral import mistral_classifier +from torchtune.models.mistral._component_builders import mistral_mlp +from torchtune.modules import ( + CausalSelfAttention, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) + + +# Copying our mistral implementation here to allow access to `output_proj` +def mistral( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + output_proj: nn.Linear, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, +) -> TransformerDecoder: + """ + Build the decoder associated with the mistral model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + This does NOT currently include inference-time optimizations such as + sliding-window attention + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + + Returns: + TransformerDecoder: Instantiation of mistral model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def compare_mistral_classifier( + bsz: int, + seq_len: int, + num_classes: int, + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, +): + + # setting up the right seed for generating outputs + torch.manual_seed(16) + + # generate input tensor to be used by both implementations + x = torch.randint(low=0, high=vocab_size, size=(bsz, seq_len)) + + # our implementation + classifier = mistral_classifier( + num_classes=num_classes, + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + intermediate_dim=intermediate_dim, + max_seq_len=max_seq_len, + ) + fixed_init_model(classifier) + + with torch.no_grad(): + out = classifier(x) + + # reference implementation: manually specify nn.Linear after base mistral + output_proj = nn.Linear(embed_dim, num_classes, bias=False) + classifier_ref = mistral( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + intermediate_dim=intermediate_dim, + max_seq_len=max_seq_len, + output_proj=output_proj, + ) + + fixed_init_model(classifier_ref) + + with torch.no_grad(): + out_ref = classifier_ref(x) + + print( + f"output layer: {classifier.output}\n reference output layer: {classifier_ref.output}" + ) + print(f"output mean: {out.mean()}\n reference output mean: {out_ref.mean()}") + print(f"output shape: {out.shape}\n reference output shape: {out_ref.shape}") + + # output tensors should be similar within precision tolerance + torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-3) + assert out.shape == (bsz, seq_len, num_classes) + + +if __name__ == "__main__": + # (bsz, embed_dim, seq_len, n_classes) # expected + test_cases = [ + (2, 64, 64, 2), # 22.6879 + (64, 128, 256, 200), # 36.8238 + (1, 256, 512, 1), # 110.2561 + ] + for bsz, embed_dim, seq_len, n_classes in test_cases: + compare_mistral_classifier( + bsz, + seq_len, + n_classes, + vocab_size=32000, + num_layers=4, + num_heads=16, + num_kv_heads=8, + embed_dim=embed_dim, + intermediate_dim=512, + max_seq_len=2048, + ) diff --git a/training/tests/torchtune/models/mistral/scripts/mistral_reference.py b/training/tests/torchtune/models/mistral/scripts/mistral_reference.py new file mode 100644 index 0000000..79134aa --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/mistral_reference.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from torch import nn + +""" +Reference mistral implementation from the official repo: +https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py + +Components are copied here with minimal modifications. +""" + + +""" +Reference implementation of Attention from: +https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py + +Note, there's another implementation in the same repo which uses xformers for attention: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/mistral/model.py#L60 + +The implementation for this test uses `one_file_ref.py` since the xformers attention implementation +expects the input `[b, s, ...]` to be flattened `[b * s, ...]` which makes comparison difficult. + +Replicating code here to minimize dependencies. The code is modified to +remove dependencies from xformers and features like KV Caching. +""" + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): + keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) + values = torch.repeat_interleave(values, repeats=repeats, dim=2) + return keys, values + + +class Attention(nn.Module): + def __init__(self, n_heads: int, head_dim: int, dim: int, n_kv_heads: int): + super().__init__() + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = head_dim**-0.5 + + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # removed positions as it was only used for cache retrieval + bsz, seqlen, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + key, value = repeat_kv(xk, xv, self.repeats) + + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # scores : [bsz, n_heads, seqlen | 1, seqlen] + scores = torch.matmul(query, key.transpose(2, 3)) * self.scale + print(scores.mean()) + if mask is not None: + scores += mask[None, None, ...] + print(scores.mean()) + scores = scores.float() + scores = nn.functional.softmax(scores, dim=-1).type_as(query) + output = torch.matmul(scores, value) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + return self.wo(output) + + +""" +Reference implementation of RoPE from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L47 + +The original code structures this as stand-alone functions instead of +a class. Replicating code here to minimize dependencies. +""" + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +""" +Reference impementation of FeedForward from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L152 + +The original code structures this as stand-alone functions in +`torchtune.models.mistral._component_builders.mistral_mlp` instead of +a standalone class. +""" + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +""" +Reference implementation of TransformerBlock from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L190 +""" + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class TransformerBlock(nn.Module): + def __init__( + self, + n_heads: int, + head_dim: int, + dim: int, + n_kv_heads: int, + hidden_dim: int, + norm_eps: float, + ): + super().__init__() + self.n_heads = n_heads + self.dim = dim + self.attention = Attention( + n_heads=n_heads, head_dim=head_dim, dim=dim, n_kv_heads=n_kv_heads + ) + self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim) + self.attention_norm = RMSNorm(dim=dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + def forward( + self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor] + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), freqs_cis, mask) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +""" +Reference implementation of Transformer from: +https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/one_file_ref.py#L217 +""" + + +class Transformer(nn.Module): + def __init__( + self, + vocab_size: int, + n_layers: int, + n_heads: int, + head_dim: int, + dim: int, + n_kv_heads: int, + hidden_dim: int, + max_seq_len: int, + rope_base: int, + norm_eps: float, + ): + super().__init__() + self.vocab_size = vocab_size + self.n_layers = n_layers + assert self.vocab_size > 0 + + self.tok_embeddings = nn.Embedding(vocab_size, dim) + + self.layers = torch.nn.ModuleList( + [ + TransformerBlock( + n_heads=n_heads, + head_dim=head_dim, + dim=dim, + n_kv_heads=n_kv_heads, + hidden_dim=hidden_dim, + norm_eps=norm_eps, + ) + for _ in range(n_layers) + ] + ) + + self.norm = RMSNorm(dim, eps=norm_eps) + + self.output = nn.Linear(dim, vocab_size, bias=False) + + # our RoPE implementation is a bit different from the reference: + # mistral hardcodes max_seq_len and uses a `positions` argument + # in forward to index `freqs_cis` for the current sequence length + # before using it in the attention layer. + + self.freqs_cis = precompute_freqs_cis( + head_dim, max_seq_len * 2, theta=rope_base + ) # removed .to("cuda") + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor): + _, seqlen = input_ids.shape + h = self.tok_embeddings(input_ids) + freqs_cis = self.freqs_cis[positions] + mask: Optional[torch.Tensor] = None + if input_ids.shape[1] > 1: + seqlen = input_ids.shape[1] + tensor = torch.full( + (seqlen, seqlen), + dtype=h.dtype, + fill_value=1, + device=h.device, + ) + mask = torch.tril(tensor, diagonal=0).to(h.dtype) + # removed mask banding + mask = torch.triu(mask, diagonal=-1) # setting sliding window to 1 + mask = torch.log(mask) + for layer in self.layers: + h = layer(h, freqs_cis, mask) + + return self.output(self.norm(h)).float() diff --git a/training/tests/torchtune/models/mistral/scripts/mistral_test_config.py b/training/tests/torchtune/models/mistral/scripts/mistral_test_config.py new file mode 100644 index 0000000..0084d3d --- /dev/null +++ b/training/tests/torchtune/models/mistral/scripts/mistral_test_config.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass +class MistralTestConfig: + BSZ = 2 + SEQ_LEN = 128 + EMBED_DIM = 64 + VOCAB_SIZE = 512 + NUM_LAYERS = 4 + NUM_HEADS = 4 + NUM_KV_HEADS = 2 + INTERMEDIATE_DIM = 512 + MAX_SEQ_LEN = 256 + ROPE_BASE = 10000 + NORM_EPS = 1e-5 + SEED = 16 diff --git a/training/tests/torchtune/models/mistral/test_mistral.py b/training/tests/torchtune/models/mistral/test_mistral.py new file mode 100644 index 0000000..9fb5505 --- /dev/null +++ b/training/tests/torchtune/models/mistral/test_mistral.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import fixed_init_model +from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig +from torchtune.models.mistral import mistral +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(MistralTestConfig.SEED) + + +class TestMistral: + @pytest.fixture + def inputs(self): + return torch.randint( + 0, + MistralTestConfig.VOCAB_SIZE, + (MistralTestConfig.BSZ, MistralTestConfig.SEQ_LEN), + ) + + def test_forward(self, inputs): + model = mistral( + vocab_size=MistralTestConfig.VOCAB_SIZE, + embed_dim=MistralTestConfig.EMBED_DIM, + num_heads=MistralTestConfig.NUM_HEADS, + num_layers=MistralTestConfig.NUM_LAYERS, + num_kv_heads=MistralTestConfig.NUM_KV_HEADS, + max_seq_len=MistralTestConfig.MAX_SEQ_LEN, + intermediate_dim=MistralTestConfig.INTERMEDIATE_DIM, + norm_eps=MistralTestConfig.NORM_EPS, + rope_base=MistralTestConfig.ROPE_BASE, + ) + fixed_init_model(model) + actual = model(inputs) + expected = torch.tensor(18.2749) + assert actual.shape == ( + MistralTestConfig.BSZ, + MistralTestConfig.SEQ_LEN, + MistralTestConfig.VOCAB_SIZE, + ) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4) diff --git a/training/tests/torchtune/models/mistral/test_mistral_classifier.py b/training/tests/torchtune/models/mistral/test_mistral_classifier.py new file mode 100644 index 0000000..b30c497 --- /dev/null +++ b/training/tests/torchtune/models/mistral/test_mistral_classifier.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.mistral import mistral_classifier +from torchtune.utils.seed import set_seed + +NUM_LAYERS = 4 +NUM_HEADS = 16 +NUM_KV_HEADS = 8 +VOCAB_SIZE = 32000 +MAX_SEQ_LEN = 2048 +INTERMEDIATE_DIM = 512 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestMistralClassifier: + # expected values are calculated using + # `tests.torchtune.models.scripts.compare_mistral_classifier` + @pytest.mark.parametrize( + "bsz, embed_dim, seq_len, n_classes, expected", + [ + (2, 64, 64, 2, 22.6879), + (1, 256, 256, 1, 110.2561), + ], + ) + def test_forward( + self, bsz: int, embed_dim: int, seq_len: int, n_classes: int, expected: float + ): + inputs = torch.randint(low=0, high=VOCAB_SIZE, size=(bsz, seq_len)) + model = mistral_classifier( + num_classes=n_classes, + vocab_size=VOCAB_SIZE, + num_layers=n_classes, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + intermediate_dim=INTERMEDIATE_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + fixed_init_model(model) + actual = model(inputs) + expected = torch.tensor(expected) + assert actual.shape == (bsz, seq_len, n_classes) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4) diff --git a/training/tests/torchtune/models/mistral/test_mistral_prompt_template.py b/training/tests/torchtune/models/mistral/test_mistral_prompt_template.py new file mode 100644 index 0000000..d229630 --- /dev/null +++ b/training/tests/torchtune/models/mistral/test_mistral_prompt_template.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE +from torchtune.data import Message +from torchtune.models.mistral import MistralChatTemplate + + +class TestMistralChatTemplate: + expected_dialogue = [ + Message( + role="user", + content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old " + "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? " + "How about on an icy road? Well one father in Russia did just that, and recorded " + "the entire thing. To her credit, the child seemed to be doing a great job. " + "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n" + "Summary: [/INST] ", + ), + Message( + role="assistant", + content="A father in Russia allowed his 8-year-old child to drive his car on an " + "icy road and recorded the event. The child appeared to be handling the situation well, " + "showcasing their driving skills despite the challenging conditions.", + ), + ] + + def test_format(self): + no_system_sample = MESSAGE_SAMPLE[1:] + actual = MistralChatTemplate()(no_system_sample) + assert_dialogue_equal(actual, self.expected_dialogue) + + def test_format_with_system_prompt_raises(self): + with pytest.raises( + ValueError, match="System prompts are not supported in MistralChatTemplate" + ): + _ = MistralChatTemplate()(MESSAGE_SAMPLE) diff --git a/training/tests/torchtune/models/mistral/test_mistral_tokenizer.py b/training/tests/torchtune/models/mistral/test_mistral_tokenizer.py new file mode 100644 index 0000000..7cc7e8c --- /dev/null +++ b/training/tests/torchtune/models/mistral/test_mistral_tokenizer.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.data import Message +from torchtune.models.mistral import mistral_tokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestMistralTokenizer: + @pytest.fixture + def tokenizer(self): + # m.model is a pretrained Sentencepiece model using the following command: + # spm.SentencePieceTrainer.train('--input= --model_prefix=m --vocab_size=2000') + return mistral_tokenizer(str(ASSETS / "m.model")) + + def test_tokenize_messages(self, tokenizer): + messages = [ + Message( + role="user", + content="Below is an instruction that describes a task. Write a response " + "that appropriately completes the request.\n\n### Instruction:\nGenerate " + "a realistic dating profile bio.\n\n### Response:\n", + masked=True, + ), + Message( + role="assistant", + content="I'm an outgoing and friendly person who loves spending time with " + "friends and family. I'm also a big-time foodie and love trying out new " + "restaurants and different cuisines. I'm a big fan of the arts and enjoy " + "going to museums and galleries. I'm looking for someone who shares my " + "interest in exploring new places, as well as someone who appreciates a " + "good conversation over coffee.", + ), + ] + tokens, mask = tokenizer.tokenize_messages(messages) + expected_tokens = [ + 1, + 323, + 418, + 202, + 31, + 128, + 15, + 120, + 47, + 88, + 584, + 23, + 1665, + 182, + 9, + 434, + 295, + 85, + 4, + 780, + 47, + 636, + 9, + 1094, + 213, + 23, + 9, + 69, + 69, + 164, + 1153, + 299, + 35, + 961, + 132, + 237, + 7, + 5, + 761, + 4, + 12, + 0, + 313, + 120, + 47, + 88, + 584, + 166, + 493, + 171, + 54, + 299, + 9, + 906, + 244, + 19, + 186, + 767, + 303, + 671, + 92, + 209, + 24, + 190, + 52, + 38, + 4, + 12, + 0, + 1243, + 7, + 69, + 135, + 213, + 166, + 6, + 21, + 45, + 128, + 71, + 58, + 38, + 14, + 10, + 652, + 35, + 462, + 101, + 1306, + 7, + 341, + 171, + 20, + 14, + 127, + 26, + 652, + 7, + 10, + 1268, + 4, + 6, + 21, + 45, + 591, + 9, + 566, + 22, + 994, + 913, + 38, + 20, + 52, + 24, + 10, + 1306, + 734, + 14, + 71, + 365, + 1382, + 7, + 10, + 801, + 105, + 88, + 244, + 985, + 7, + 4, + 6, + 21, + 45, + 9, + 566, + 126, + 180, + 11, + 5, + 1137, + 7, + 10, + 1089, + 151, + 8, + 1156, + 213, + 342, + 7, + 10, + 384, + 104, + 54, + 470, + 4, + 6, + 21, + 45, + 287, + 14, + 33, + 125, + 135, + 24, + 101, + 512, + 66, + 7, + 28, + 822, + 15, + 542, + 69, + 59, + 110, + 14, + 365, + 229, + 7, + 3, + 36, + 267, + 36, + 125, + 135, + 24, + 101, + 1503, + 182, + 9, + 222, + 1661, + 191, + 332, + 92, + 92, + 24, + 24, + 4, + 2, + ] + expected_mask = [True] * 75 + [False] * 125 + assert expected_tokens == tokens + assert expected_mask == mask diff --git a/training/tests/torchtune/models/phi3/__init__.py b/training/tests/torchtune/models/phi3/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/models/phi3/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/models/phi3/test_lora_phi3.py b/training/tests/torchtune/models/phi3/test_lora_phi3.py new file mode 100644 index 0000000..148e9c6 --- /dev/null +++ b/training/tests/torchtune/models/phi3/test_lora_phi3.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy + +import pytest +import torch + +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn +from torchao.dtypes.nf4tensor import NF4Tensor +from torchtune import utils +from torchtune.models.phi3 import lora_phi3, phi3 +from torchtune.models.phi3._component_builders import lora_phi3_self_attention +from torchtune.modules.peft import LoRALinear +from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt +from torchtune.utils.seed import set_seed + +RANK = 4 +ALPHA = 1.0 +BSZ = 2 +SEQ_LEN = 32 +EMBED_DIM = 64 +INTER_DIM = 128 +NUM_HEADS = 4 +NUM_KV_HEADS = 2 +MAX_SEQ_LEN = 64 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestLoRAPhi3SelfAttention: + @pytest.fixture + def inputs(self) -> torch.Tensor: + inputs = torch.randn(BSZ, SEQ_LEN, EMBED_DIM) + return inputs + + def get_lora_phi_self_attention(self, lora_modules): + lora_phi_sa = lora_phi3_self_attention( + lora_modules=lora_modules, + embed_dim=EMBED_DIM, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + max_seq_len=MAX_SEQ_LEN, + lora_rank=RANK, + lora_alpha=ALPHA, + ) + fixed_init_model(lora_phi_sa) + return lora_phi_sa + + def test_empty_lora_modules(self): + with pytest.raises(ValueError, match="Must pass one or more of"): + _ = self.get_lora_phi_self_attention([]) + + @pytest.mark.parametrize( + "lora_modules, expected", + [ + (["q_proj", "v_proj"], torch.tensor(51.88655)), + (["q_proj", "k_proj", "v_proj", "output_proj"], torch.tensor(75.80934)), + (["k_proj"], torch.tensor(44.00425)), + ], + ) + def test_forward(self, inputs, lora_modules, expected): + lora_phi_sa = self.get_lora_phi_self_attention(lora_modules) + actual = lora_phi_sa(inputs) + assert_expected(actual.shape, (BSZ, SEQ_LEN, EMBED_DIM)) + assert_expected(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + +class TestLoRAPhi3: + @pytest.fixture + def vocab_size(self): + return 50 + + @pytest.fixture + def inputs(self, vocab_size): + return torch.randint(low=0, high=vocab_size, size=(BSZ, SEQ_LEN)) + + def get_lora_phi3( + self, + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + vocab_size, + reset_norm=True, + quantize_base=False, + embed_dim=EMBED_DIM, + dtype=None, + ): + num_layers = 3 + model = lora_phi3( + lora_attn_modules=lora_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + intermediate_dim=INTER_DIM, + max_seq_len=MAX_SEQ_LEN, + lora_rank=RANK, + lora_alpha=ALPHA, + quantize_base=quantize_base, + ) + # To make final outputs less trivial + if reset_norm: + model.norm = nn.Identity() + + # dtype=None means to just read dtype from parameters + # in the model. This dtype is set explicitly to bf16 currently + # when initializing QLoRA models, as ops such as `arange` aren't + # yet supported with the actual nf4 tensor dtype yet. + fixed_init_model(model, dtype=dtype) + + return model + + def get_ref_phi3(self, vocab_size, embed_dim=EMBED_DIM): + num_layers = 3 + model = phi3( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + intermediate_dim=INTER_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + return model + + @pytest.mark.parametrize( + "lora_modules, apply_lora_to_mlp, apply_lora_to_output, expected", + [ + (["q_proj", "v_proj"], False, False, torch.tensor(2869687.75)), + ( + ["q_proj", "k_proj", "v_proj", "output_proj"], + True, + False, + torch.tensor(10674772.0), + ), + (["k_proj"], True, True, torch.tensor(16285834.0)), + ], + ) + def test_forward( + self, + vocab_size, + inputs, + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + expected, + ): + model = self.get_lora_phi3( + lora_modules, apply_lora_to_mlp, apply_lora_to_output, vocab_size + ) + actual = model(inputs) + assert_expected(actual.shape, (BSZ, SEQ_LEN, vocab_size)) + assert_expected(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + @pytest.mark.parametrize( + "lora_modules, apply_lora_to_mlp, apply_lora_to_output", + [ + (["q_proj", "v_proj"], True, False), + (["q_proj", "k_proj", "v_proj", "output_proj"], False, False), + (["k_proj"], True, True), + ], + ) + def test_lora_phi3_state_dict_parity( + self, lora_modules, apply_lora_to_mlp, apply_lora_to_output, vocab_size + ): + lora_phi = self.get_lora_phi3( + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + vocab_size, + reset_norm=False, + ) + ref_phi = self.get_ref_phi3(vocab_size) + # Ensure ref_phi state_dict can be loaded into lora_phi with only "lora" + # keys missing. + ref_phi_state_dict = ref_phi.state_dict() + missing, unexpected = lora_phi.load_state_dict(ref_phi_state_dict, strict=False) + assert not unexpected + assert all(["lora" in key for key in missing]) + + def test_lora_linear_quantize_base(self): + model = self.get_lora_phi3( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + # quantize_base + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=torch.bfloat16, + ) + for module in model.modules(): + if isinstance(module, LoRALinear): + assert module._quantize_base + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_phi3_parity(self, dtype, inputs): + with utils.set_default_dtype(dtype): + model_ref = self.get_lora_phi3( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=False, + embed_dim=512, + dtype=dtype, + ) + qlora = self.get_lora_phi3( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=dtype, + ) + qlora_sd = qlora.state_dict() + model_ref.load_state_dict(qlora_sd) + # Forward pass of model_ref and qlora should be the same, as QLoRA linear layers should use + # a special linear operator that runs the compute in bf16, but only saves the 4 bit tensors + # for backward. + ref_output = model_ref(inputs) + output = qlora(inputs) + torch.testing.assert_close(ref_output, output) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_phi3_state_dict(self, dtype): + with utils.set_default_dtype(dtype): + model_ref = self.get_lora_phi3( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=False, + embed_dim=512, + dtype=dtype, + ) + high_prec_sd = model_ref.state_dict() + for v in high_prec_sd.values(): + assert v.dtype == dtype + + # ensure quantized LoRA can load a bf16 state_dict + qlora = self.get_lora_phi3( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=dtype, + ) + qlora.load_state_dict(high_prec_sd) + # LoRALinear base weights should be nf4 still + for module in qlora.modules(): + if isinstance(module, LoRALinear): + assert isinstance(module.weight, NF4Tensor) + # saved state_dict should have bf16 weights. + qlora_sd = qlora.state_dict() + for v in qlora_sd.values(): + assert v.dtype == dtype + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_phi3_merged_state_dict(self, dtype): + with utils.set_default_dtype(dtype): + qlora = self.get_lora_phi3( + lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=dtype, + reset_norm=False, # to ensure norm.scale key exists + ) + + qlora_sd = qlora.state_dict() + # Ensure checkpoint merging produces bf16 tensors + merged_ckpt = get_merged_lora_ckpt(deepcopy(qlora_sd), rank=RANK, alpha=ALPHA) + for k, v in merged_ckpt.items(): + # paranoid check for both, as NF4Tensor had issue where NF4Tensor.dtype would return bf16 + assert not isinstance(v, NF4Tensor) + if k == "": + assert v.dtype == torch.float32 + else: + assert v.dtype == dtype + + # Ensure checkpoint can be loaded into non-LoRA model + with utils.set_default_dtype(dtype): + phi3 = self.get_ref_phi3(vocab_size=50, embed_dim=512) + + phi3.load_state_dict(merged_ckpt) diff --git a/training/tests/torchtune/models/phi3/test_phi3.py b/training/tests/torchtune/models/phi3/test_phi3.py new file mode 100644 index 0000000..e4cfca5 --- /dev/null +++ b/training/tests/torchtune/models/phi3/test_phi3.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.phi3 import phi3 +from torchtune.utils.seed import set_seed + +EMBED_DIM = 128 +INTER_DIM = 256 +NUM_LAYERS = 4 +NUM_HEADS = 16 +NUM_KV_HEADS = 8 +VOCAB_SIZE = 32000 +MAX_SEQ_LEN = 2048 +BSZ = 2 +SEQ_LEN = 100 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestPhi3: + @pytest.fixture + def inputs(self): + return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN)) + + def test_forward(self, inputs): + model = phi3( + vocab_size=VOCAB_SIZE, + num_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + intermediate_dim=INTER_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + fixed_init_model(model, min_val=-0.25, max_val=0.5) + actual = model(inputs) + expected = torch.tensor(3.9763) + assert actual.shape == (BSZ, SEQ_LEN, VOCAB_SIZE) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4) diff --git a/training/tests/torchtune/models/phi3/test_phi3_tokenizer.py b/training/tests/torchtune/models/phi3/test_phi3_tokenizer.py new file mode 100644 index 0000000..dfcf724 --- /dev/null +++ b/training/tests/torchtune/models/phi3/test_phi3_tokenizer.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.data import Message +from torchtune.models.phi3 import phi3_mini_tokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestPhi3MiniTokenizer: + @pytest.fixture + def tokenizer(self): + # m.model is a pretrained Sentencepiece model using the following command: + # spm.SentencePieceTrainer.train('--input= --model_prefix=m --vocab_size=2000') + return phi3_mini_tokenizer( + path=str(ASSETS / "m.model"), + ) + + def test_tokenize_messages(self, tokenizer): + messages = [ + Message( + role="user", + content="Below is an instruction that describes a task. Write a response " + "that appropriately completes the request.\n\n### Instruction:\nGenerate " + "a realistic dating profile bio.\n\n### Response:\n", + masked=True, + ), + Message( + role="assistant", + content="I'm an outgoing and friendly person who loves spending time with " + "friends and family. I'm also a big-time foodie and love trying out new " + "restaurants and different cuisines. I'm a big fan of the arts and enjoy " + "going to museums and galleries. I'm looking for someone who shares my " + "interest in exploring new places, as well as someone who appreciates a " + "good conversation over coffee.", + ), + ] + tokens, mask = tokenizer.tokenize_messages(messages) + expected_tokens = [ + 1, + 32010, + 323, + 418, + 202, + 31, + 128, + 15, + 120, + 47, + 88, + 584, + 23, + 1665, + 182, + 9, + 434, + 295, + 85, + 4, + 780, + 47, + 636, + 9, + 1094, + 213, + 23, + 9, + 69, + 69, + 164, + 1153, + 299, + 35, + 961, + 132, + 237, + 7, + 5, + 761, + 4, + 12, + 0, + 313, + 120, + 47, + 88, + 584, + 166, + 493, + 171, + 54, + 299, + 9, + 906, + 244, + 19, + 186, + 767, + 303, + 671, + 92, + 209, + 24, + 190, + 52, + 38, + 4, + 12, + 0, + 1243, + 7, + 69, + 135, + 213, + 166, + 32007, + 32001, + 6, + 21, + 45, + 128, + 71, + 58, + 38, + 14, + 10, + 652, + 35, + 462, + 101, + 1306, + 7, + 341, + 171, + 20, + 14, + 127, + 26, + 652, + 7, + 10, + 1268, + 4, + 6, + 21, + 45, + 591, + 9, + 566, + 22, + 994, + 913, + 38, + 20, + 52, + 24, + 10, + 1306, + 734, + 14, + 71, + 365, + 1382, + 7, + 10, + 801, + 105, + 88, + 244, + 985, + 7, + 4, + 6, + 21, + 45, + 9, + 566, + 126, + 180, + 11, + 5, + 1137, + 7, + 10, + 1089, + 151, + 8, + 1156, + 213, + 342, + 7, + 10, + 384, + 104, + 54, + 470, + 4, + 6, + 21, + 45, + 287, + 14, + 33, + 125, + 135, + 24, + 101, + 512, + 66, + 7, + 28, + 822, + 15, + 542, + 69, + 59, + 110, + 14, + 365, + 229, + 7, + 3, + 36, + 267, + 36, + 125, + 135, + 24, + 101, + 1503, + 182, + 9, + 222, + 1661, + 191, + 332, + 92, + 92, + 24, + 24, + 4, + 32007, + ] + expected_mask = [True] * 77 + [False] * 126 + assert expected_tokens == tokens + assert expected_mask == mask diff --git a/training/tests/torchtune/models/qwen2/__init__.py b/training/tests/torchtune/models/qwen2/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/models/qwen2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/models/qwen2/test_lora_qwen2.py b/training/tests/torchtune/models/qwen2/test_lora_qwen2.py new file mode 100644 index 0000000..b79b8e9 --- /dev/null +++ b/training/tests/torchtune/models/qwen2/test_lora_qwen2.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn +from torchtune.models.qwen2 import lora_qwen2, qwen2 +from torchtune.models.qwen2._component_builders import lora_qwen2_self_attention +from torchtune.utils.seed import set_seed + +RANK = 4 +ALPHA = 1.0 +BSZ = 2 +SEQ_LEN = 32 +EMBED_DIM = 64 +INTERMEDIATE_DIM = 168 +NUM_HEADS = 4 +NUM_KV_HEADS = 2 +MAX_SEQ_LEN = 64 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestLoRAQwen2SelfAttention: + @pytest.fixture + def inputs(self) -> torch.Tensor: + inputs = torch.randn(BSZ, SEQ_LEN, EMBED_DIM) + return inputs + + def get_lora_qwen2_self_attention(self, lora_modules): + lora_qwen2 = lora_qwen2_self_attention( + lora_modules=lora_modules, + embed_dim=EMBED_DIM, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + max_seq_len=MAX_SEQ_LEN, + lora_rank=RANK, + lora_alpha=ALPHA, + ) + fixed_init_model(lora_qwen2) + return lora_qwen2 + + def test_empty_lora_modules(self): + with pytest.raises(ValueError, match="Must pass one or more of"): + _ = self.get_lora_qwen2_self_attention([]) + + @pytest.mark.parametrize( + "lora_modules, expected", + [ + (["q_proj", "v_proj"], torch.tensor(83.6596)), + (["q_proj", "k_proj", "v_proj", "output_proj"], torch.tensor(129.4454)), + (["k_proj"], torch.tensor(69.3473)), + ], + ) + def test_forward(self, inputs, lora_modules, expected): + lora_qwen2_sa = self.get_lora_qwen2_self_attention(lora_modules) + actual = lora_qwen2_sa(inputs) + assert_expected(actual.shape, (BSZ, SEQ_LEN, EMBED_DIM)) + assert_expected(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + +class TestLoRAQwen2: + @pytest.fixture + def vocab_size(self): + return 50 + + @pytest.fixture + def inputs(self, vocab_size): + return torch.randint(low=0, high=vocab_size, size=(BSZ, SEQ_LEN)) + + def get_lora_qwen2( + self, + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + vocab_size, + reset_norm=True, + quantize_base=False, + embed_dim=EMBED_DIM, + dtype=None, + ): + num_layers = 3 + model = lora_qwen2( + lora_attn_modules=lora_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + intermediate_dim=INTERMEDIATE_DIM, + max_seq_len=MAX_SEQ_LEN, + lora_rank=RANK, + lora_alpha=ALPHA, + quantize_base=quantize_base, + ) + # To make final outputs less trivial + if reset_norm: + model.norm = nn.Identity() + + # dtype=None means to just read dtype from parameters + # in the model. This dtype is set explicitly to bf16 currently + # when initializing QLoRA models, as ops such as `arange` aren't + # yet supported with the actual nf4 tensor dtype yet. + fixed_init_model(model, dtype=dtype) + + return model + + def get_ref_qwen2(self, vocab_size, embed_dim=EMBED_DIM): + num_layers = 3 + model = qwen2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=embed_dim, + intermediate_dim=INTERMEDIATE_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + return model + + @pytest.mark.parametrize( + "lora_modules, apply_lora_to_mlp, apply_lora_to_output, expected", + [ + (["q_proj", "v_proj"], False, False, torch.tensor(3736558.0)), + ( + ["q_proj", "k_proj", "v_proj", "output_proj"], + True, + False, + torch.tensor(13962364.0), + ), + (["k_proj"], True, True, torch.tensor(21335964.0)), + ], + ) + def test_forward( + self, + vocab_size, + inputs, + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + expected, + ): + model = self.get_lora_qwen2( + lora_modules, apply_lora_to_mlp, apply_lora_to_output, vocab_size + ) + actual = model(inputs) + assert_expected(actual.shape, (BSZ, SEQ_LEN, vocab_size)) + assert_expected(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + @pytest.mark.parametrize( + "lora_modules, apply_lora_to_mlp, apply_lora_to_output", + [ + (["q_proj", "v_proj"], True, False), + (["q_proj", "k_proj", "v_proj", "output_proj"], False, False), + (["k_proj"], True, True), + ], + ) + def test_lora_qwen2_state_dict_parity( + self, lora_modules, apply_lora_to_mlp, apply_lora_to_output, vocab_size + ): + lora_qwen2 = self.get_lora_qwen2( + lora_modules, + apply_lora_to_mlp, + apply_lora_to_output, + vocab_size, + reset_norm=False, + ) + ref_qwen2 = self.get_ref_qwen2(vocab_size) + # Ensure ref_qwen2 state_dict can be loaded into lora_qwen2 with only "lora" + # keys missing. + ref_qwen2_state_dict = ref_qwen2.state_dict() + missing, unexpected = lora_qwen2.load_state_dict( + ref_qwen2_state_dict, strict=False + ) + assert not unexpected + assert all(["lora" in key for key in missing]) diff --git a/training/tests/torchtune/models/qwen2/test_qwen2.py b/training/tests/torchtune/models/qwen2/test_qwen2.py new file mode 100644 index 0000000..e43e62a --- /dev/null +++ b/training/tests/torchtune/models/qwen2/test_qwen2.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.qwen2 import qwen2 +from torchtune.utils.seed import set_seed + +EMBED_DIM = 128 +INTER_DIM = 256 +NUM_LAYERS = 4 +NUM_HEADS = 16 +NUM_KV_HEADS = 8 +VOCAB_SIZE = 32000 +MAX_SEQ_LEN = 2048 +BSZ = 2 +SEQ_LEN = 100 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestQwen2: + @pytest.fixture + def inputs(self): + return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN)) + + def test_forward(self, inputs): + model = qwen2( + vocab_size=VOCAB_SIZE, + num_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + intermediate_dim=INTER_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + fixed_init_model(model, min_val=-0.25, max_val=0.5) + actual = model(inputs) + expected = torch.tensor(3.9763) + assert actual.shape == (BSZ, SEQ_LEN, VOCAB_SIZE) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4) diff --git a/training/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py b/training/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py new file mode 100644 index 0000000..75203d8 --- /dev/null +++ b/training/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest + +from torchtune.data import Message +from torchtune.models.qwen2 import qwen2_tokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestQwen2Tokenizer: + @pytest.fixture + def tokenizer(self): + return qwen2_tokenizer( + path=str(ASSETS / "tiny_bpe_vocab.json"), + merges_file=str(ASSETS / "tiny_bpe_merges.txt"), + special_tokens_path=str(ASSETS / "tiny_bpe_tokenizer.json"), + ) + + def test_tokenize_messages(self, tokenizer): + messages = [ + Message( + role="user", + content="Below is an instruction that describes a task. Write a response " + "that appropriately completes the request.\n\n### Instruction:\nGenerate " + "a realistic dating profile bio.\n\n### Response:\n", + masked=True, + ), + Message( + role="assistant", + content="I'm an outgoing and friendly person who loves spending time with " + "friends and family. I'm also a big-time foodie and love trying out new " + "restaurants and different cuisines. I'm a big fan of the arts and enjoy " + "going to museums and galleries. I'm looking for someone who shares my " + "interest in exploring new places, as well as someone who appreciates a " + "good conversation over coffee.", + ), + ] + tokens, mask = tokenizer.tokenize_messages(messages) + expected_tokens = [ + 2001, + 273, + 105, + 94, + 33, + 214, + 174, + 156, + 194, + 130, + 197, + 184, + 446, + 789, + 113, + 98, + 1914, + 13, + 346, + 788, + 98, + 706, + 102, + 182, + 184, + 1916, + 176, + 762, + 83, + 113, + 103, + 874, + 269, + 13, + 94, + 94, + 2, + 2, + 2, + 483, + 197, + 25, + 94, + 885, + 98, + 1226, + 1960, + 348, + 114, + 1123, + 399, + 1583, + 78, + 13, + 94, + 94, + 2, + 2, + 2, + 360, + 1733, + 102, + 182, + 25, + 94, + 2002, + 94, + 2001, + 397, + 251, + 249, + 94, + 40, + 1791, + 194, + 453, + 70, + 78, + 114, + 120, + 967, + 176, + 618, + 628, + 1275, + 794, + 294, + 1095, + 445, + 212, + 1356, + 120, + 1299, + 13, + 223, + 1791, + 451, + 98, + 127, + 181, + 1047, + 375, + 915, + 380, + 120, + 1448, + 1732, + 114, + 453, + 447, + 1219, + 64, + 187, + 921, + 120, + 742, + 107, + 84, + 122, + 893, + 13, + 223, + 1791, + 98, + 127, + 181, + 123, + 124, + 131, + 103, + 744, + 82, + 120, + 1506, + 416, + 114, + 128, + 1429, + 182, + 253, + 82, + 120, + 163, + 330, + 105, + 262, + 13, + 223, + 1791, + 155, + 1551, + 171, + 1951, + 628, + 296, + 64, + 237, + 886, + 1390, + 130, + 883, + 1678, + 447, + 306, + 279, + 113, + 11, + 215, + 785, + 215, + 1951, + 628, + 378, + 101, + 66, + 72, + 593, + 98, + 984, + 208, + 1580, + 167, + 510, + 737, + 318, + 1278, + 13, + 2002, + 94, + 2000, + ] + expected_mask = [True] * 67 + [False] * 123 + assert expected_tokens == tokens + assert expected_mask == mask + + formatted_messages = tokenizer.decode(tokens) + expected_formatted_messages = ( + f"<|im_start|>user\n{messages[0].text_content}<|im_end|>\n" + f"<|im_start|>assistant\n{messages[1].text_content}<|im_end|>\n" + "<|endoftext|>" + ) + assert expected_formatted_messages == formatted_messages diff --git a/training/tests/torchtune/modules/__init__.py b/training/tests/torchtune/modules/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/modules/loss/__init__.py b/training/tests/torchtune/modules/loss/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/modules/loss/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/modules/loss/test_dpo_loss.py b/training/tests/torchtune/modules/loss/test_dpo_loss.py new file mode 100644 index 0000000..3e09f27 --- /dev/null +++ b/training/tests/torchtune/modules/loss/test_dpo_loss.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torchtune.modules.loss import DPOLoss, IPOLoss, RSOLoss, SimPOLoss + + +@pytest.fixture(autouse=True) +def random(): + torch.manual_seed(16) + + +class TestDPOLosses: + @pytest.fixture + def dpo_loss(self): + return DPOLoss( + beta=0.1, + label_smoothing=0.0, + ) + + @pytest.fixture + def rso_loss(self): + return RSOLoss( + gamma=0.1, + ) + + @pytest.fixture + def ipo_loss(self): + return IPOLoss( + tau=0.1, + ) + + @pytest.fixture + def simpo_loss(self): + return SimPOLoss( + beta=2.0, + gamma=0.5, + label_smoothing=0.0, + ) + + @pytest.fixture + def loss_inputs(self): + """ + kind-of-random inputs for testing the math out (below). + """ + policy_chosen_logprobs = torch.tensor([-0.5, -10.0, -1.0]) + policy_rejected_logprobs = torch.tensor([-0.1, -30.0, -21.0]) + + ref_chosen_logprobs = torch.tensor([-0.5, -10.1, -0.1]) + ref_rejected_logprobs = torch.tensor([-0.1, -20.1, -0.1]) + + return ( + policy_chosen_logprobs, + policy_rejected_logprobs, + ref_chosen_logprobs, + ref_rejected_logprobs, + ) + + def test_dpo_loss(self, dpo_loss, loss_inputs): + """ + here's the maths (see `loss_inputs`): + ratios = torch.tensor([-0.4, 20.0, 20.0]) + ref_ratios = torch.tensor([-0.4, 10, 0.0]) + + logits is ratios - ref_ratios + + logits = torch.tensor([0.0, 10.0, 20.0]) + scaled_logits = torch.tensor([0.0, 1.0, 2.0]) + + since label_smoothing is zero, loss is NLL with temperature scaled logits + logsigmoid is log(1/1+exp(-scaled_logits)) + exp(-scaled_logits) is [1, 1/e, 1/e^2] + logsigmoid is -log([1 / 2, 1 / (1 + 1/e), 1 / (1 + 1/e^2)]) + + expected_losses = -torch.tensor( + [1 / 2, 1 / (1 + torch.exp(torch.tensor(-1.0))), 1 / (1 + torch.exp(torch.tensor(-2.0)))] + ).log() + expected_losses = -expected_logsigmoids + """ + exp_scaled_logits = torch.exp(torch.tensor([0.0, -1.0, -2.0])) + expected_losses = -(1 / (1 + exp_scaled_logits)).log() + losses, *_ = dpo_loss(*loss_inputs) + + torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5) + + def test_rso_loss(self, rso_loss, loss_inputs): + """ + # maths: + ratios = torch.tensor([-0.4, 20.0, 20.0]) + ref_ratios = torch.tensor([-0.4, 10, 0.0]) + + # logits is ratios - ref_ratios + + logits = torch.tensor([0.0, 10.0, 20.0]) + scaled_logits = torch.tensor([0.0, 1.0, 2.0]) + + # hinge loss doesn't use label smoothing + # loss = relu(1 - scaled_logits) = max(0, 1 - scaled_logits) + expected_losses = torch.tensor([1.0, 0.0, 0.0]) + """ + + expected_losses = torch.tensor([1.0, 0.0, 0.0]) + + losses, *_ = rso_loss(*loss_inputs) + + torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5) + + def test_ipo_loss(self, ipo_loss, loss_inputs): + """ + # maths: + ratios = torch.tensor([-0.4, 20.0, 20.0]) + ref_ratios = torch.tensor([-0.4, 10, 0.0]) + + # logits is ratios - ref_ratios + + logits = torch.tensor([0.0, 10.0, 20.0]) + + # ipo loss is (logits - 1 / (2 * tau)) ** 2 + = [-5, 5, 15] ** 2 + = [25, 25, 225] + """ + expected_losses = torch.tensor([25.0, 25.0, 225.0]) + losses, *_ = ipo_loss(*loss_inputs) + torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5) + + def test_simpo_loss(self, simpo_loss, loss_inputs): + """ + here's the maths (see `loss_inputs`): + ratios = torch.tensor([-0.4, 20.0, 20.0]) + gamma_logratios = 0.25 + + logits is ratios - gamma_logratios + + logits = torch.tensor([-0.65, 19.75, 19.75]) + scaled_logits = beta * logits = torch.tensor([-1.3, 39.5, 39.5]) + + since label_smoothing is zero, loss is NLL with temperature scaled logits + """ + policy_chosen_logprobs, policy_rejected_logprobs, *_ = loss_inputs + exp_scaled_logits = torch.exp(torch.tensor([1.3, -39.5, -39.5])) + + expected_losses = -(1 / (1 + exp_scaled_logits)).log() + losses, *_ = simpo_loss(policy_chosen_logprobs, policy_rejected_logprobs) + + torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5) diff --git a/training/tests/torchtune/modules/loss/test_ppo_loss.py b/training/tests/torchtune/modules/loss/test_ppo_loss.py new file mode 100644 index 0000000..6445da3 --- /dev/null +++ b/training/tests/torchtune/modules/loss/test_ppo_loss.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torchtune.modules.loss import PPOLoss + + +@pytest.fixture(autouse=True) +def random(): + torch.manual_seed(16) + + +class TestPPOLoss: + @pytest.fixture + def loss_fn(self): + return PPOLoss( + value_clip_range=0.2, + value_coeff=0.1, + epsilon=0.2, + ) + + def test_policy_loss_clipped_for_high_logprobs(self, loss_fn): + # fixed old policy logprobs, advantages, returns + pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2]) + advantages = torch.tensor([1.0, 1.0, 1.0]) + values = torch.tensor([1.0, 1.0, 1.0]) + returns = torch.tensor([1.0, 1.0, 1.0]) + + pi_logprobs_high = torch.tensor([1.5, 1.8, 2.2]) + # ratio will be [e, e, e] + # clipped ratio becomes [1.2, 1.2, 1.2] (1+epsilon) + # objective becomes max(-e, -1.2) since advantages is 1 + expected_loss = torch.tensor(-1.2) + expected_ratios = torch.exp(torch.ones((3))) + + _, policy_loss, _, ratios, _ = loss_fn( + pi_old_logprobs, pi_logprobs_high, advantages, values, values, returns + ) + + torch.testing.assert_close( + policy_loss.mean(), expected_loss, atol=1e-4, rtol=1e6 + ) + torch.testing.assert_close(ratios, expected_ratios.mean(), atol=1e-4, rtol=1e6) + + def test_policy_loss_clipped_for_low_logprobs(self, loss_fn): + # fixed old policy logprobs, advantages, returns + pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2]) + advantages = torch.tensor([1.0, 1.0, 1.0]) + values = torch.tensor([1.0, 1.0, 1.0]) + returns = torch.tensor([1.0, 1.0, 1.0]) + + pi_logprobs_low = torch.tensor([-0.5, -0.2, 0.2]) + # ratio will be [1/e, 1/e, 1/e] (~0.367) + # clipped ratio becomes [0.8, 0.8, 0.8] (1-epsilon) + # objective becomes max(1/e, 0.8) since advantages is 1 + expected_loss = torch.tensor(0.8) + expected_ratios = 1 / torch.exp(torch.ones((3))) + + _, policy_loss, _, ratios, _ = loss_fn( + pi_old_logprobs, pi_logprobs_low, advantages, values, values, returns + ) + + torch.testing.assert_close( + policy_loss.mean(), expected_loss, atol=1e-4, rtol=1e6 + ) + torch.testing.assert_close(ratios, expected_ratios.mean(), atol=1e-4, rtol=1e6) + + def test_policy_loss_not_clipped(self, loss_fn): + # fixed old policy logprobs, advantages, returns + pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2]) + advantages = torch.tensor([1.0, 1.0, 1.0]) + values = torch.tensor([1.0, 1.0, 1.0]) + returns = torch.tensor([1.0, 1.0, 1.0]) + + pi_logprobs_unclipped = torch.tensor([0.6, 0.9, 1.3]) + # ratio will be [e^0.1, e^0.1, e^0.1] (~1.1) + # ratio is not clipped since it is within [1-epsilon, 1+epsilon], [0.8, 1.2] + expected_loss = torch.tensor(0.1).exp() + expected_ratios = torch.exp(torch.ones(3) * 0.1) + + _, policy_loss, _, ratios, _ = loss_fn( + pi_old_logprobs, pi_logprobs_unclipped, advantages, values, values, returns + ) + + torch.testing.assert_close( + policy_loss.mean(), expected_loss, atol=1e-4, rtol=1e6 + ) + torch.testing.assert_close(ratios, expected_ratios.mean(), atol=1e-4, rtol=1e6) + + def test_policy_loss_lower_for_higher_advantages(self, loss_fn): + pi_logprobs = torch.tensor([-0.5, -0.8, -1.2]) + + advantages_high = torch.tensor([1.0, 2.0, 3.0]) + advantages_low = torch.tensor([0.5, 1.0, 1.5]) + values = torch.tensor([1.0, 1.0, 1.0]) + returns = torch.tensor([1.0, 1.0, 1.0]) + + _, policy_loss_low, *_ = loss_fn( + pi_logprobs, pi_logprobs, advantages_high, values, values, returns + ) + _, policy_loss_high, *_ = loss_fn( + pi_logprobs, pi_logprobs, advantages_low, values, values, returns + ) + + assert policy_loss_low.mean() < policy_loss_high.mean() + + def test_value_loss_lower_for_values_similar_to_return(self, loss_fn): + # fix pi_logrobs, pi_old_logprobs, returns, advantages + pi_logprobs = torch.tensor([-0.5, -0.8, -1.2]) + returns = torch.tensor([1.0, 1.0, 1.0]) + advantages = torch.tensor([1.0, 1.0, 1.0]) + + # values estimates are similar to returns + values_similar = torch.tensor([0.9, 1.0, 1.1]) + # value estimates are less similar to returns + values_less_similar = torch.tensor([0.5, 1.5, 2.0]) + + _, _, value_loss_lower, *_ = loss_fn( + pi_logprobs, + pi_logprobs, + advantages, + values_similar, + values_similar, + returns, + ) + _, _, value_loss_higher, *_ = loss_fn( + pi_logprobs, + pi_logprobs, + advantages, + values_similar, + values_less_similar, + returns, + ) + assert value_loss_lower.mean() < value_loss_higher.mean() diff --git a/training/tests/torchtune/modules/low_precision/__init__.py b/training/tests/torchtune/modules/low_precision/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/modules/low_precision/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/modules/low_precision/test_nf4_dispatch_registration.py b/training/tests/torchtune/modules/low_precision/test_nf4_dispatch_registration.py new file mode 100644 index 0000000..f851efe --- /dev/null +++ b/training/tests/torchtune/modules/low_precision/test_nf4_dispatch_registration.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchao.dtypes import to_nf4 + + +class TestNF4DispatchRegistration: + """ + Class for testing NF4Tensor dispatch ops. + """ + + def test_inplace_copy_copies_expected_attributes(self): + """ + This test ensures that we're copying over all relevant attributes when implementing + torch.ops.aten.copy_.default. If this test fails, we would need to update our implementation + in _register_nf4_dispatch_ops to cover the newly added attributes. + """ + expected_inplace_copy_attrs = [ + "block_size", + "n_blocks", + "scaler_block_size", + "quantized_scalers", + "quantization_factor", + "scaler_mean", + "quantized_data", + "nf4", + ] + + z = to_nf4(torch.rand(512, 512, dtype=torch.bfloat16)) + inplace_copy_attr_set = set(z.__dict__.keys()) + assert set(expected_inplace_copy_attrs) == inplace_copy_attr_set diff --git a/training/tests/torchtune/modules/low_precision/test_nf4_linear.py b/training/tests/torchtune/modules/low_precision/test_nf4_linear.py new file mode 100644 index 0000000..c3b6320 --- /dev/null +++ b/training/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# # Copyright (c) Meta Platforms, Inc. and affiliates. +# # All rights reserved. +# # +# # This source code is licensed under the BSD-style license found in the +# # LICENSE file in the root directory of this source tree. + + +# import bitsandbytes as bnb +# import pytest +# import torch +# from torchao.dtypes.nf4tensor import NF4Tensor +# from torchtune.modules.low_precision import FrozenNF4Linear +# from torchtune.utils.seed import set_seed + + +# @pytest.fixture(autouse=True) +# def random(): +# set_seed(31) + + +# def _build_bnb_linear(input_weight): +# """ +# Builds a bnb.nn.LinearNF4 from a given input weight +# """ +# param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4") +# bnb_linear = bnb.nn.LinearNF4( +# input_weight.size(0), input_weight.size(1), bias=False +# ) +# bnb_linear.weight = param +# bnb_linear.cuda() +# return bnb_linear + + +# class TestNF4Linear: +# """ +# Class for testing our NF4Linear implementation. +# """ + +# def test_bias_unsupported(self): +# with pytest.raises(RuntimeError, match="does not currently support biases"): +# _ = FrozenNF4Linear(1, 1, bias=True) + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_parameters(self, dtype): +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# params = list(nf4_linear.parameters()) +# assert len(params) == 1 +# assert isinstance(params[0], NF4Tensor) + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_state_dict(self, dtype): +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# state_dict = nf4_linear.state_dict() +# assert len(state_dict) == 1 +# assert isinstance(state_dict["weight"], NF4Tensor) + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_output_dtype(self, dtype): +# # Test to ensure W4 A16 produces A16 / W4A32 produces A32 +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) +# out = nf4_linear(inp) +# assert out.dtype == dtype + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_backward_dtype(self, dtype): +# # Test to ensure backward pass gives activation a bf16 gradient and no gradient +# # to the linear's weight, as it is frozen. +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) +# nf4_linear(inp).sum().backward() +# assert inp.grad is not None and inp.grad.dtype == dtype +# assert nf4_linear.weight.grad is None + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_nf4_reconstruction_vs_bnb(self, dtype): +# """ +# Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when +# reconstructing the respective original weights. +# """ +# dim = 512 +# nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) +# orig_weight = nf4_linear.weight.get_original_weight().clone().detach() +# bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + +# # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65 +# bnb_reconstruction = bnb_nf4_linear( +# torch.eye(dim, dim, dtype=dtype, device="cuda") +# ) +# # Ensure nf4_linear and bnb reconstructions are close to each other. +# assert torch.allclose( +# bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2 +# ) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_nf4_bnb_linear(self, dtype): +# """ +# This test ensures that nf4_linear is "no worse" than BNB by ensuring the +# error compared to a bf16 linear is not more than BNB's implementation. +# """ +# dim = 512 +# nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) +# orig_weight = nf4_linear.weight.get_original_weight().clone().detach() +# bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) +# bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) + +# inp = torch.randn(2, 512, dtype=dtype, device="cuda") + +# out_nf4 = nf4_linear(inp) +# out_bnb = bnb_nf4_linear(inp) +# out_ref = bf16_linear(inp) + +# err_bnb = out_bnb - out_ref +# err_native = out_nf4 - out_ref +# assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2) diff --git a/training/tests/torchtune/modules/peft/__init__.py b/training/tests/torchtune/modules/peft/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/modules/peft/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/modules/peft/test_lora.py b/training/tests/torchtune/modules/peft/test_lora.py new file mode 100644 index 0000000..d48f2be --- /dev/null +++ b/training/tests/torchtune/modules/peft/test_lora.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import pytest + +import torch +from tests.test_utils import fixed_init_model +from torch import nn +from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 +from torchtune import utils +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook +from torchtune.modules.peft import LoRALinear +from torchtune.utils.seed import set_seed + +RANK = 4 +ALPHA = 1.0 +BSZ = 2 +SEQ_LEN = 32 +EXPECTED_VAL = 1.1252 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestLoRALinear: + """ + Class for testing our LoRALinear implementation. Expected values are computed + from the reference implementation and calculated in scripts/compare_lora.py. + """ + + @pytest.fixture + def in_dim(self) -> int: + return 64 + + @pytest.fixture + def out_dim(self) -> int: + return 128 + + @pytest.fixture + def inputs(self, in_dim) -> torch.Tensor: + inputs = torch.randn(BSZ, SEQ_LEN, in_dim) + return inputs + + @pytest.fixture + def lora_linear(self, in_dim, out_dim) -> LoRALinear: + lora_linear = LoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=RANK, + alpha=ALPHA, + use_bias=True, + ) + fixed_init_model(lora_linear) + return lora_linear + + @pytest.fixture + def qlora_linear(self, in_dim, out_dim) -> LoRALinear: + with utils.set_default_dtype(torch.bfloat16): + qlora_linear = LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + fixed_init_model(qlora_linear, dtype=torch.bfloat16) + return qlora_linear + + @torch.no_grad() + def set_dummy_weights_for_merge(self, lora_module): + lora_module.lora_a.weight = nn.Parameter( + torch.zeros_like(lora_module.lora_a.weight) + ) + lora_module.lora_b.weight = nn.Parameter( + torch.zeros_like(lora_module.lora_b.weight) + ) + lora_module.weight = nn.Parameter(torch.zeros_like(lora_module.weight)) + lora_module.bias = nn.Parameter(torch.zeros_like(lora_module.bias)) + + # Hardcode some very specific nonzero values to make verification easy + lora_module.weight[4, 5] = 1 + lora_module.bias[7] = 2 + lora_module.lora_a.weight[1, 25] = 3 + lora_module.lora_b.weight[32, 1] = 12 + + def test_forward(self, inputs, lora_linear, out_dim) -> None: + expected = torch.tensor(EXPECTED_VAL) + actual = lora_linear(inputs) + assert actual.shape == (BSZ, SEQ_LEN, out_dim) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + def test_lora_weight_nf4_when_quantized(self, qlora_linear): + assert isinstance(qlora_linear.weight, NF4Tensor) + + def test_quantize_with_bias_raises(self): + with pytest.raises(NotImplementedError, match="does not support bias"): + LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=True, + quantize_base=True, + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qlora_parity(self, dtype): + with utils.set_default_dtype(dtype): + qlora_linear = LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + lora_linear = LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=False, + ) + + # set weight of lora_linear to unquantized weight of qlora_linear and check + # parity. + lora_linear.weight.data = qlora_linear.weight.to(dtype) + + # Ensure forward passes are the same. This is because LoRALinear should use a special + # quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor) + # for autograd. + inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype) + lora_linear_out = lora_linear(inputs) + qlora_linear_out = qlora_linear(inputs) + torch.testing.assert_close(lora_linear_out, qlora_linear_out) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_quantized_state_dict(self, dtype): + with utils.set_default_dtype(dtype): + lora_linear = LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + + lora_linear._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + dtype=dtype, + offload_to_cpu=False, + ) + ) + sd = lora_linear.state_dict() + # No nf4 tensors, all have type dtype + for v in sd.values(): + assert v.dtype == dtype + assert not isinstance(v, NF4Tensor) + + # Load back in results in re-quant and creates the same nf4 tensor. + # This also ensures that LoRALinear can load a bf16 state_dict. + lora_linear_reload = LoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + # Zero out weight to verify reloading works + lora_linear_reload.weight = nn.Parameter( + to_nf4( + torch.zeros_like( + lora_linear.weight.get_original_weight(), + dtype=dtype, + device=lora_linear.weight.device, + ) + ) + ) + # nf4 tensors should be different + assert not torch.allclose( + lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data + ) + # but should be the same after loading + lora_linear_reload.load_state_dict(sd) + assert torch.allclose( + lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data + ) diff --git a/training/tests/torchtune/modules/peft/test_peft_utils.py b/training/tests/torchtune/modules/peft/test_peft_utils.py new file mode 100644 index 0000000..0847b34 --- /dev/null +++ b/training/tests/torchtune/modules/peft/test_peft_utils.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy + +import pytest +import torch + +from torch import nn +from torchtune.models.llama2 import llama2, lora_llama2 +from torchtune.modules.peft import LoRALinear +from torchtune.modules.peft.peft_utils import ( + AdapterModule, + disable_adapter, + get_adapter_params, + get_merged_lora_ckpt, + set_trainable_params, + validate_missing_and_unexpected_for_lora, + validate_state_dict_for_lora, +) + +N_LAYERS = 3 +IN_DIM = 5 +OUT_DIM = 10 +VOCAB_SIZE = 50 +NUM_HEADS = 4 +NUM_KV_HEADS = 2 +EMBED_DIM = 64 +MAX_SEQ_LEN = 64 +RANK = 2 +ALPHA = 1 + + +class DummyAdapterModule(nn.Module, AdapterModule): + def __init__(self, in_dim, out_dim): + super().__init__() + self.adapter = nn.Linear(in_dim, out_dim, bias=False) + self.linear = nn.Linear(in_dim, out_dim) + + def adapter_params(self): + return ["adapter.weight"] + + def forward(self, x): + return self.adapter(x) + self.non_adapter(x) + + +class DummyAdapterParentModel(nn.Module, AdapterModule): + def __init__(self, in_dim, out_dim): + super().__init__() + self.dummy_adapter_module = DummyAdapterModule(in_dim, out_dim) + self.parent_adapter = nn.Linear(in_dim, out_dim) + self.parent_base_model = nn.Linear(in_dim, out_dim) + + def adapter_params(self): + return ["parent_adapter.weight", "parent_adapter.bias"] + + def forward(self, x): + return ( + self.dummy_adapter_module(x) + + self.parent_adapter(x) + + self.parent_base_model(x) + ) + + +@pytest.fixture +def dummy_adapter_parent_model(): + return nn.ModuleList( + [DummyAdapterParentModel(IN_DIM, OUT_DIM) for _ in range(N_LAYERS)] + ) + + +@pytest.fixture +def dummy_model_expected_adapter_keys(): + keys = [] + for i in range(N_LAYERS): + keys.extend( + [ + f"{i}.parent_adapter.weight", + f"{i}.parent_adapter.bias", + f"{i}.dummy_adapter_module.adapter.weight", + ] + ) + return keys + + +@pytest.fixture +def dummy_model_expected_base_model_keys(): + keys = [] + for i in range(N_LAYERS): + keys.extend( + [ + f"{i}.parent_base_model.weight", + f"{i}.parent_base_model.bias", + f"{i}.dummy_adapter_module.linear.weight", + f"{i}.dummy_adapter_module.linear.bias", + ] + ) + return keys + + +@pytest.fixture +def lora_llama2_model(): + return lora_llama2( + lora_attn_modules=["q_proj", "v_proj"], + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + lora_rank=4, + lora_alpha=1.0, + ) + + +@pytest.fixture +def lora_llama2_model_all_keys(lora_llama2_model): + return lora_llama2_model.state_dict().keys() + + +@pytest.fixture +def lora_llama2_expected_adapter_keys(): + keys = [] + for i in range(N_LAYERS): + keys.extend( + [ + f"layers.{i}.attn.q_proj.lora_a.weight", + f"layers.{i}.attn.q_proj.lora_b.weight", + f"layers.{i}.attn.v_proj.lora_a.weight", + f"layers.{i}.attn.v_proj.lora_b.weight", + ] + ) + return keys + + +@pytest.fixture +def lora_llama2_expected_base_model_keys(): + + base_model = llama2( + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_KV_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + return base_model.state_dict().keys() + + +class TestPeftUtils: + @pytest.mark.parametrize( + "model_name, expected_keys", + [ + ("dummy_adapter_parent_model", "dummy_model_expected_adapter_keys"), + ("lora_llama2_model", "lora_llama2_expected_adapter_keys"), + ], + ) + def test_get_adapter_params(self, request, model_name, expected_keys): + model = request.getfixturevalue(model_name) + adapter_params = get_adapter_params(model) + expected = request.getfixturevalue(expected_keys) + assert set(expected) == set(adapter_params.keys()) + + @pytest.mark.parametrize( + "model_name, expected_trainable_keys, expected_frozen_keys", + [ + ( + "dummy_adapter_parent_model", + "dummy_model_expected_adapter_keys", + "dummy_model_expected_base_model_keys", + ), + ( + "lora_llama2_model", + "lora_llama2_expected_adapter_keys", + "lora_llama2_expected_base_model_keys", + ), + ], + ) + def test_set_trainable_params( + self, request, model_name, expected_trainable_keys, expected_frozen_keys + ): + model = request.getfixturevalue(model_name) + adapter_params = get_adapter_params(model) + expected_trainable = request.getfixturevalue(expected_trainable_keys) + expected_frozen = request.getfixturevalue(expected_frozen_keys) + set_trainable_params(model, adapter_params) + for k, v in model.named_parameters(): + if k in expected_trainable: + assert v.requires_grad + elif k in expected_frozen: + assert not v.requires_grad + else: + raise AssertionError(f"{k} not in expected keys") + + @pytest.mark.parametrize( + ( + """ + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + full_model_state_dict_keys, + lora_state_dict_keys, + base_model_state_dict_keys, + expected + """ + ), + [ + ( + ["q_proj", "k_proj"], + False, + False, + ["q_proj.lora_a.weight", "dummy_param.weight"], + ["q_proj.lora_a.weight"], + ["dummy_param.weight"], + "", + ), + ( + ["v_proj"], + False, + False, + ["param_a", "param_b"], + None, + ["param_a", "param_b"], + "", + ), + ( + ["output_proj"], + False, + True, + ["output_proj.weight", "output_proj.lora_a.weight"], + ["output_proj.lora_a.weight"], + ["output_proj.weight"], + "", + ), + (["q_proj"], False, False, ["param_a"], [], [], "Missing non-LoRA"), + ( + ["k_proj", "output_proj"], + False, + True, + ["k_proj.lora_a.weight", "param_a"], + ["k_proj.lora_a.weight", "param_a"], + ["param_a"], + "found in LoRA", + ), + ( + ["k_proj"], + False, + False, + ["k_proj.lora_a.weight"], + [], + ["k_proj.lora_a.weight"], + "found in base model", + ), + ( + ["k_proj"], + False, + False, + ["k_proj.lora_a.weight"], + [], + None, + "Missing LoRA", + ), + (["q_proj"], False, False, [], ["a"], ["a"], "overlapping"), + ( + ["v_proj"], + False, + False, + ["dummy_param.weight"], + ["v_proj.lora_a.weight"], + ["dummy_param.weight"], + "Extra", + ), + ( + ["w1", "w2", "w3"], + True, + False, + ["w1.lora_a.weight", "w2.weight", "q_proj.weight"], + ["w1.lora_a.weight"], + ["q_proj.weight"], + "Missing non-LoRA key", + ), + ( + ["q_proj", "output"], + False, + True, + [ + "q_proj.lora_a", + "output.weight", + "output.lora_a", + "output_proj.lora_b", + ], + ["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"], + ["output.weight"], + "Missing non-LoRA key", + ), + ( + ["q_proj", "v_proj"], + False, + False, + "lora_llama2_model_all_keys", + "lora_llama2_expected_adapter_keys", + "lora_llama2_expected_base_model_keys", + "", + ), + ], + ) + def test_validate_lora_state_dict( + self, + request, + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + full_model_state_dict_keys, + lora_state_dict_keys, + base_model_state_dict_keys, + expected, + ): + if isinstance(full_model_state_dict_keys, str): + full_model_state_dict_keys = request.getfixturevalue( + full_model_state_dict_keys + ) + if isinstance(lora_state_dict_keys, str): + lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys) + if isinstance(base_model_state_dict_keys, str): + base_model_state_dict_keys = request.getfixturevalue( + base_model_state_dict_keys + ) + if expected: + with pytest.raises(AssertionError, match=expected): + validate_state_dict_for_lora( + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + full_model_state_dict_keys=full_model_state_dict_keys, + lora_state_dict_keys=lora_state_dict_keys, + base_model_state_dict_keys=base_model_state_dict_keys, + ) + else: + validate_state_dict_for_lora( + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + full_model_state_dict_keys=full_model_state_dict_keys, + lora_state_dict_keys=lora_state_dict_keys, + base_model_state_dict_keys=base_model_state_dict_keys, + ) + + @pytest.mark.parametrize( + ( + """ + base_missing, + base_unexpected, + lora_missing, + lora_unexpected, + expected + """ + ), + [ + (["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"), + (["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"), + ( + ["k_proj.lora"], + ["output.weight"], + ["q_proj.base_weight"], + [], + "loading base model", + ), + ( + ["k_proj.lora"], + [], + ["q_proj.base_weight"], + ["output.weight"], + "loading adapter", + ), + (["k_proj.lora"], [], ["q_proj.base_weight"], [], ""), + ], + ) + def test_validate_missing_and_unexpected_for_lora( + self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected + ): + lora_attn_modules = ["q_proj", "k_proj"] + apply_lora_to_mlp = True + apply_lora_to_output = False + if expected: + with pytest.raises(AssertionError, match=expected): + validate_missing_and_unexpected_for_lora( + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + base_missing, + base_unexpected, + lora_missing, + lora_unexpected, + ) + else: + validate_missing_and_unexpected_for_lora( + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + base_missing, + base_unexpected, + lora_missing, + lora_unexpected, + ) + + +class TestGetMergedLoRACkpt: + def dummy_model(self): + model = nn.Sequential( + LoRALinear(in_dim=4, out_dim=6, rank=RANK, alpha=ALPHA), + nn.Linear(6, 3), + ) + model[0].lora_a.weight = nn.Parameter( + torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + ) + model[0].lora_b.weight = nn.Parameter( + torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) + ) + model[0].weight = nn.Parameter(3 * torch.ones((6, 4))) + return model + + def test_get_merged_lora_ckpt(self): + dummy_model = self.dummy_model() + merged_sd = get_merged_lora_ckpt( + deepcopy(dummy_model.state_dict()), rank=RANK, alpha=ALPHA + ) + expected_merged_weight = torch.Tensor( + [ + [8.5, 10.0, 11.5, 13.0], + [14.5, 18.0, 21.5, 25.0], + [20.5, 26.0, 31.5, 37.0], + [26.5, 34.0, 41.5, 49.0], + [32.5, 42.0, 51.5, 61.0], + [38.5, 50.0, 61.5, 73.0], + ] + ) + assert merged_sd.keys() == {"0.weight", "1.weight", "1.bias"} + torch.testing.assert_close(merged_sd["0.weight"], expected_merged_weight) + + merged_model = nn.Sequential(nn.Linear(4, 6, bias=False), nn.Linear(6, 3)) + merged_model.load_state_dict(merged_sd, strict=True) + + inputs = torch.randn(2, 8, 4) + torch.testing.assert_close(dummy_model(inputs), merged_model(inputs)) + + +class TestDisableAdapter: + def dummy_model(self): + model_ori = nn.Sequential( + nn.Linear(2, 6, bias=False), + nn.Linear(6, 3), + ) + model_lora = nn.Sequential( + LoRALinear(in_dim=2, out_dim=6, rank=RANK, alpha=ALPHA), + nn.Linear(6, 3), + ) + # TODO: fix weight initialization to use fixed_init_model + for p in model_ori.parameters(): + nn.init.constant_(p, 1.0) + for p in model_lora.parameters(): + nn.init.constant_(p, 1.0) + return model_ori, model_lora + + def test_disable_adapter(self): + model_ori, model_lora = self.dummy_model() + inputs = torch.randn(2, 2) + + ori_outputs = model_ori(inputs) + + with disable_adapter(model_lora): + lora_outputs = model_lora(inputs) + + assert model_lora[0].disabled is False + torch.testing.assert_close(ori_outputs, lora_outputs) diff --git a/training/tests/torchtune/modules/rlhf/__init__.py b/training/tests/torchtune/modules/rlhf/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/modules/rlhf/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/modules/rlhf/test_collate.py b/training/tests/torchtune/modules/rlhf/test_collate.py new file mode 100644 index 0000000..1fecff7 --- /dev/null +++ b/training/tests/torchtune/modules/rlhf/test_collate.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import torch + +from torchtune.modules.rlhf import left_padded_collate + + +class TestLeftPaddedCollate: + def test_left_padded_collate(self): + """ + Tests that input sequences are left-padded to the max seq len. + """ + padding_idx = -8 + tokens = [ + { + "tokens": [ + 1, + 2, + ], + }, + { + "tokens": [3], + }, + { + "tokens": [4, 5, 6, 7], + }, + ] + padded_tokens = left_padded_collate(batch=tokens, padding_idx=padding_idx) + + expected_padded_tokens = torch.tensor( + [ + [padding_idx, padding_idx, 1, 2], + [padding_idx, padding_idx, padding_idx, 3], + [4, 5, 6, 7], + ] + ) + torch.testing.assert_close(padded_tokens, expected_padded_tokens) diff --git a/training/tests/torchtune/modules/rlhf/test_generation.py b/training/tests/torchtune/modules/rlhf/test_generation.py new file mode 100644 index 0000000..2613a4f --- /dev/null +++ b/training/tests/torchtune/modules/rlhf/test_generation.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.llama2 import llama2 +from torchtune.modules import rlhf +from torchtune.utils._generation import sample + + +class TestGenerateNextTokenWithLogits: + @pytest.fixture + def generation_model(self): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.eval() + return model + + def test_generate_next_token_with_logits(self, generation_model): + + inputs = torch.tensor( + [ + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + ] + ) + + input_pos = torch.tensor( + [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + ] + ) + + torch.manual_seed(42) + logits, generation = rlhf.generate_next_token_with_logits( + generation_model, input_pos, inputs + ) + + torch.manual_seed(42) + expected_logits = generation_model(inputs, input_pos=input_pos) + expected_generation = sample(logits[:, -1], temperature=1.0, top_k=None) + + torch.testing.assert_close(logits, expected_logits, atol=1e-4, rtol=1e-5) + torch.testing.assert_close(generation, expected_generation, atol=0, rtol=0) + + +class TestGenerate: + """ + Test class for text generation functionality in :func:`~torchtune.modules.rlhf.generate`. + See `torchtune.tests.utils.test_generation` for context. + """ + + @pytest.fixture + def generation_model(self): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.eval() + return model + + @pytest.fixture + def prompt_tokens(self): + """ + Pytest fixture to create a list of prompt tokens for testing. + """ + return torch.arange(2, 10) + + @pytest.fixture + def prompt_tokens_batched(self): + """ + Pytest fixture to create a list of batched prompt tokens for testing. + """ + return torch.arange(2, 10).repeat(3, 1) + + @pytest.fixture + def prompt_tokens_padded(self): + """ + Pytest fixture to create a list of left-padded prompt tokens for testing. + """ + return torch.cat([torch.tensor([0, 0]), torch.arange(2, 10)]) + + @pytest.fixture + def prompt_tokens_batched_left_padded(self): + """ + Pytest fixture to create a list of left-padded batched prompt tokens for testing. + """ + return torch.cat([torch.tensor([0, 0]), torch.arange(2, 10)]).repeat(3, 1) + + def test_reproducability_with_and_without_padding_batched( + self, + generation_model, + prompt_tokens_batched_left_padded, + prompt_tokens_batched, + ): + """ + Test to check if the `generate` function produces the same output for inputs that are left padded + and for the same inputs that are not left padded, for a batch of inputs with varying sequence lengths. + """ + temperature = 0.6 + top_k = 100 + + torch.manual_seed(42) + outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens_batched_left_padded, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.manual_seed(42) + expected_outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens_batched, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.testing.assert_close(outputs[:, 2:], expected_outputs, atol=0, rtol=0) + + def test_reproducability_with_and_without_padding( + self, generation_model, prompt_tokens, prompt_tokens_padded + ): + """ + Test to check if the `generate` function produces the same output for inputs that are left padded + and for the same inputs that are not left padded. + """ + temperature = 0.6 + top_k = 100 + + torch.manual_seed(42) + + outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens_padded, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.manual_seed(42) + expected_outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.testing.assert_close(outputs[:, 2:], expected_outputs, atol=0, rtol=0) + + +class TestGetCausalMask: + @pytest.fixture + def left_padded_prompt_tokens(self): + """ + Pytest fixture to create a list of left-padded prompt tokens for testing. + """ + return torch.cat([torch.tensor([0, 0]), torch.arange(2, 6)]).unsqueeze(0) + + @pytest.fixture + def left_padded_prompt_tokens_batched(self): + """ + Pytest fixture to create a list of left-padded batched prompt tokens for testing. + """ + return torch.tensor( + [[0, 0, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 1]] + ) + + @pytest.fixture + def right_padded_prompt_tokens(self): + """ + Pytest fixture to create a list of right-padded prompt tokens for testing. + """ + return torch.cat([torch.arange(2, 6), torch.tensor([0, 0])]).unsqueeze(0) + + @pytest.fixture + def right_padded_prompt_tokens_batched(self): + """ + Pytest fixture to create a list of right-padded batched prompt tokens for testing. + """ + return torch.tensor( + [[1, 2, 3, 4, 5, 0], [1, 2, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6]] + ) + + @pytest.fixture + def mixed_padded_prompt_tokens(self): + """ + Pytest fixture to create a list of mixed padded prompt tokens for testing. + """ + return torch.cat( + [torch.tensor([0, 0]), torch.arange(2, 6), torch.tensor([0, 0])] + ).unsqueeze(0) + + @pytest.fixture + def mixed_padded_prompt_tokens_batched(self): + """ + Pytest fixture to create a list of mixed padded batched prompt tokens for testing. + """ + return torch.tensor( + [[0, 0, 1, 2, 0, 0], [0, 1, 2, 3, 4, 0], [0, 0, 0, 1, 0, 0]] + ) + + def test_get_causal_mask_for_left_padded_inputs(self, left_padded_prompt_tokens): + """ + Test to check if the `get_causal_mask` function produces the right output for left-padded prompts. + """ + expected_casual_mask = torch.tensor( + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, True, True, False, False], + [False, False, True, True, True, False], + [False, False, True, True, True, True], + ] + ).unsqueeze(0) + + causal_mask = rlhf.get_causal_mask(left_padded_prompt_tokens != 0) + torch.testing.assert_close(causal_mask, expected_casual_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_left_padded_inputs_batched( + self, left_padded_prompt_tokens_batched + ): + """ + Test to check if the `get_causal_mask` function produces the right output for left-padded batched prompts. + """ + expected_causal_mask = torch.tensor( + [ + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, True, True, False], + [False, False, False, True, True, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, True, False], + [False, True, True, True, True, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + ] + ) + + causal_mask = rlhf.get_causal_mask(left_padded_prompt_tokens_batched != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_right_padded_inputs(self, right_padded_prompt_tokens): + """ + Test to check if the `get_causal_mask` function produces the right output for right-padded prompts. + """ + expected_causal_mask = torch.tensor( + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ] + ).unsqueeze(0) + + causal_mask = rlhf.get_causal_mask(right_padded_prompt_tokens != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_right_padded_inputs_batched( + self, right_padded_prompt_tokens_batched + ): + """ + Test to check if the `get_causal_mask` function produces the right output for right-padded batched prompts. + """ + expected_causal_mask = torch.tensor( + [ + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, True], + ], + ] + ) + + causal_mask = rlhf.get_causal_mask(right_padded_prompt_tokens_batched != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_mixed_padding_inputs(self, mixed_padded_prompt_tokens): + """ + Test to check if the `get_causal_mask` function produces the right output for mixed padded prompts. + """ + expected_causal_mask = torch.tensor( + [ + [True, False, False, False, False, False, False, False], + [False, True, False, False, False, False, False, False], + [False, False, True, False, False, False, False, False], + [False, False, True, True, False, False, False, False], + [False, False, True, True, True, False, False, False], + [False, False, True, True, True, True, False, False], + [False, False, False, False, False, False, True, False], + [False, False, False, False, False, False, False, True], + ] + ).unsqueeze(0) + + causal_mask = rlhf.get_causal_mask(mixed_padded_prompt_tokens != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_mixed_padded_inputs_batched( + self, mixed_padded_prompt_tokens_batched + ): + """ + Test to check if the `get_causal_mask` function produces the right output for mixed-padded batched prompts. + """ + expected_causal_mask = torch.tensor( + [ + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, True, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + ] + ) + + causal_mask = rlhf.get_causal_mask(mixed_padded_prompt_tokens_batched != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) diff --git a/training/tests/torchtune/modules/rlhf/test_rewards.py b/training/tests/torchtune/modules/rlhf/test_rewards.py new file mode 100644 index 0000000..0e8ec99 --- /dev/null +++ b/training/tests/torchtune/modules/rlhf/test_rewards.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtune.modules import rlhf + + +class TestGetRewards: + def test_get_rewards(self): + scores = torch.tensor([1.0, 2.0, 3.0]) + logprobs = torch.tensor( + [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.6, 0.7, 0.8], + ] + ) + ref_logprobs = torch.tensor( + [ + [0.2, 0.3, 0.4], + [0.6, 0.7, 0.8], + [0.9, 1.0, 1.1], + ] + ) + kl_controller_value = 0.5 + + # expected kl is logprobs - ref_logprobs + expected_kl = torch.tensor( + [ + [-0.1, -0.1, -0.1], + [-0.2, -0.2, -0.2], + [-0.3, -0.3, -0.3], + ] + ) + + # expected kl_rewards is -kl_controller_value * kl + expected_kl_rewards = torch.tensor( + [ + [0.05, 0.05, 0.05], + [0.1, 0.1, 0.1], + [0.15, 0.15, 0.15], + ] + ) + + # expected rewards is kl_rewards[:, -1] + scores + expected_rewards = torch.tensor( + [ + [0.05, 0.05, 1.05], + [0.1, 0.1, 2.1], + [0.15, 0.15, 3.15], + ] + ) + + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + scores, logprobs, ref_logprobs, kl_controller_value + ) + + torch.testing.assert_close(kl, expected_kl, rtol=1e-4, atol=1e-4) + torch.testing.assert_close( + kl_rewards, expected_kl_rewards, rtol=1e-4, atol=1e-4 + ) + torch.testing.assert_close(rewards, expected_rewards, rtol=1e-4, atol=1e-4) + + +class TestWhiten: + def test_whiten_with_shift_mean(self): + x = torch.normal(1, 2, size=(100, 100)) + + expected_mean, expected_var = x.mean(), x.var() # should be ~1.0, ~4.0 + expected = (x - expected_mean) / (torch.sqrt(expected_var) + 1e-8) + expected += expected_mean + output = rlhf.whiten(x, shift_mean=True) + + torch.testing.assert_close(output, expected, rtol=1e-4, atol=1e-4) + + def test_whiten_without_shift_mean(self): + x = torch.normal(1, 2, size=(100, 100)) + + expected_mean, expected_var = x.mean(), x.var() # should be ~1.0, ~4.0 + expected = (x - expected_mean) / (torch.sqrt(expected_var) + 1e-8) + output = rlhf.whiten(x, shift_mean=False) + + torch.testing.assert_close(output, expected, rtol=1e-4, atol=1e-4) + + def test_masked_whiten(self): + x_mean_1 = torch.normal(1, 2, size=(50, 100)) + x_mean_2 = torch.normal(2, 1, size=(50, 100)) + x = torch.cat([x_mean_1, x_mean_2], dim=0) + mask = torch.ones_like(x, dtype=torch.bool) + mask[:50] = False + + expected_mean, expected_var = ( + x_mean_2.mean(), + x_mean_2.var(), + ) # should be ~2.0, ~1.0 + expected = (x - expected_mean) / (torch.sqrt(expected_var) + 1e-8) + expected += expected_mean + + output = rlhf.whiten(x, mask=mask) + + torch.testing.assert_close(output, expected, rtol=1e-4, atol=1e-4) + + +class TestMaskedMean: + def test_masked_single_batch_mean(self): + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + mask = torch.tensor([True, True, True, False, False]) + + expected_mean = torch.tensor(2.0) + output = rlhf.masked_mean(x, mask) + + torch.testing.assert_close(output, expected_mean, rtol=1e-4, atol=1e-4) + + def test_masked_multi_batch_mean(self): + x = torch.tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [2.0, 3.0, 4.0, 5.0, 6.0], + ] + ) + mask = torch.tensor( + [[True, True, True, False, False], [False, False, True, True, True]] + ) + + expected_means = torch.tensor([2.0, 5.0]) + output = rlhf.masked_mean(x, mask, dim=1) + + torch.testing.assert_close(output, expected_means, rtol=1e-4, atol=1e-4) + + +class TestMaskedVar: + def test_masked_var(self): + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + mask = torch.tensor([True, True, True, False, False]) + + expected_var = torch.tensor(1.0) + output = rlhf.masked_var(x, mask) + + torch.testing.assert_close(output, expected_var, rtol=1e-4, atol=1e-4) + + +class TestEstimateAdvantages: + def test_estimate_returns(self): + values = torch.tensor([[0, 0, 0, 1]]) + rewards = torch.tensor([[0, 0, 0, 1]]) + gamma = 0.9 + lmbda = 0.95 + + final_reward = 1.0 + expected_returns = torch.tensor( + [ + [ + final_reward * gamma * gamma * gamma * lmbda * lmbda, + final_reward * gamma * gamma * lmbda, + final_reward * gamma, + final_reward, + ] + ] + ) + + _, returns = rlhf.estimate_advantages(values, rewards, gamma, lmbda) + torch.testing.assert_close(returns, expected_returns, rtol=1e-4, atol=1e-4) + + def test_estimate_advantages_with_whitening(self): + values = torch.tensor([[0, 0, 0, 1]]) + rewards = torch.tensor([[0, 0, 0, 1]]) + gamma = 0.9 + lmbda = 0.95 + + final_reward = 1.0 + returns = torch.tensor( + [ + [ + final_reward * gamma * gamma * gamma * lmbda * lmbda, + final_reward * gamma * gamma * lmbda, + final_reward * gamma, + final_reward, + ] + ] + ) + + # see `torchtune.modules.rlhf.estimate_advantages` + expected_advantages = returns - values + expected_whitened_advantages = rlhf.whiten(expected_advantages, shift_mean=True) + advantages, _ = rlhf.estimate_advantages(values, rewards, gamma, lmbda) + torch.testing.assert_close( + expected_whitened_advantages, advantages, rtol=1e-4, atol=1e-4 + ) + + def test_estimate_advantages_with_masks(self): + values = torch.tensor([[0, 0, 0, 1]]) + rewards = torch.tensor([[0, 0, 0, 1]]) + masks = torch.tensor([[True, True, True, False]]) + gamma = 0.9 + lmbda = 0.95 + + final_reward = 1.0 + returns = torch.tensor( + [ + [ + final_reward * gamma * gamma * gamma * lmbda * lmbda, + final_reward * gamma * gamma * lmbda, + final_reward * gamma, + final_reward, + ] + ] + ) + + # see `torchtune.modules.rlhf.estimate_advantages` + expected_advantages = returns - values + expected_advantages = rlhf.whiten(expected_advantages, mask=masks) + expected_advantages[..., -1] = 0.0 + + advantages, _ = rlhf.estimate_advantages( + values, rewards, gamma, lmbda, masks=masks + ) + torch.testing.assert_close( + advantages, expected_advantages, rtol=1e-4, atol=1e-4 + ) diff --git a/training/tests/torchtune/modules/rlhf/test_sequence_processing.py b/training/tests/torchtune/modules/rlhf/test_sequence_processing.py new file mode 100644 index 0000000..43accdf --- /dev/null +++ b/training/tests/torchtune/modules/rlhf/test_sequence_processing.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtune.modules import rlhf + + +class TestTruncateSequenceAtFirstStopToken: + def test_truncate_sequences(self): + stop_token_ids = torch.tensor([2, 869]) + fill_value = 0 + sequences = torch.tensor( + [ + [869, 30, 869], + [2, 30, 869], + [869, 30, 2], + [50, 30, 869], + [13, 30, 2], + [13, 30, 5], + [13, 2, 20], + [13, 2, 2], + [2, 2, 2], + ] + ) + eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token( + sequences, stop_token_ids, fill_value + ) + + expected_eos_mask = torch.tensor( + [ + [False, True, True], + [False, True, True], + [False, True, True], + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, True], + [False, False, True], + [False, True, True], + ] + ) + + expected_sequences = torch.tensor( + [ + [869, fill_value, fill_value], + [2, fill_value, fill_value], + [869, fill_value, fill_value], + [50, 30, 869], + [13, 30, 2], + [13, 30, 5], + [13, 2, fill_value], + [13, 2, fill_value], + [2, fill_value, fill_value], + ] + ) + + assert expected_eos_mask.eq(eos_mask).all() + assert expected_sequences.eq(truncated_sequences).all() diff --git a/training/tests/torchtune/modules/test_attention.py b/training/tests/torchtune/modules/test_attention.py new file mode 100644 index 0000000..e63f8f7 --- /dev/null +++ b/training/tests/torchtune/modules/test_attention.py @@ -0,0 +1,347 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import pytest + +import torch + +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn, Tensor + +from torchtune.modules import CausalSelfAttention, KVCache, RotaryPositionalEmbeddings +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestCausalSelfAttention: + """ + Class for testing our CausalSelfAttention implementation. + + The expected tensors are computed from the reference implementation + below by using the same seed, same params and same initialization used + in the fixtures below. + https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450 + """ + + @pytest.fixture + def input_params(self) -> Tuple[int, int, int]: + batch_size = 4 + seq_len = 2048 + embed_dim = 4096 + return batch_size, seq_len, embed_dim + + @pytest.fixture + def input(self, input_params: Tuple[int, int, int]) -> Tensor: + batch_size, seq_len, embed_dim = input_params + x = torch.randn(batch_size, seq_len, embed_dim) + return x + + @pytest.fixture + def attn_params_gqa(self) -> Tuple[int, int, int, int]: + num_heads = 32 + num_kv_heads = 8 + embed_dim = 4096 + max_seq_len = 4096 + return num_heads, num_kv_heads, embed_dim, max_seq_len + + @pytest.fixture + def input_max_len_exceeded( + self, + input_params: Tuple[int, int, int], + attn_params_gqa: Tuple[int, int, int, int], + ) -> Tensor: + batch_size, seq_len, embed_dim = input_params + _, _, _, max_seq_len = attn_params_gqa + seq_len = max_seq_len + 1 + return torch.randn(batch_size, seq_len, embed_dim) + + @pytest.fixture + def input_max_bs_exceeded( + self, + input_params: Tuple[int, int, int], + attn_params_gqa: Tuple[int, int, int, int], + ) -> Tensor: + batch_size, seq_len, embed_dim = input_params + _, _, _, max_seq_len = attn_params_gqa + batch_size += 1 + return torch.randn(batch_size, seq_len, embed_dim) + + @pytest.fixture + def attn_params_mha(self) -> Tuple[int, Optional[int], int, int]: + num_heads = 32 + embed_dim = 4096 + max_seq_len = 4096 + return num_heads, None, embed_dim, max_seq_len + + @pytest.fixture + def attn_params_mqa(self) -> Tuple[int, int, int, int]: + num_heads = 32 + num_kv_heads = 1 + embed_dim = 4096 + max_seq_len = 4096 + return num_heads, num_kv_heads, embed_dim, max_seq_len + + @pytest.fixture + def gqa(self, attn_params_gqa: Tuple[int, int, int, int]) -> CausalSelfAttention: + num_heads, num_kv_heads, embed_dim, max_seq_len = attn_params_gqa + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + ) + + fixed_init_model(attn) + attn.eval() + return attn + + @pytest.fixture + def gqa_kv_cache( + self, attn_params_gqa: Tuple[int, int, int, int] + ) -> CausalSelfAttention: + num_heads, num_kv_heads, embed_dim, max_seq_len = attn_params_gqa + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + head_dim = embed_dim // num_heads + kv_cache = KVCache( + batch_size=4, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + dtype=torch.float32, + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=kv_cache, + max_seq_len=max_seq_len, + ) + fixed_init_model(attn) + attn.eval() + return attn + + @pytest.fixture + def mha(self, attn_params_mha: Tuple[int, int, int, int]) -> CausalSelfAttention: + num_heads, num_kv_heads, embed_dim, max_seq_len = attn_params_mha + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + ) + fixed_init_model(attn) + attn.eval() + return attn + + @pytest.fixture + def mha_kv_cache( + self, attn_params_mha: Tuple[int, int, int, int] + ) -> CausalSelfAttention: + num_heads, num_kv_heads, embed_dim, max_seq_len = attn_params_mha + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + kv_cache = KVCache( + batch_size=4, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + dtype=torch.float32, + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=kv_cache, + max_seq_len=max_seq_len, + ) + fixed_init_model(attn) + attn.eval() + return attn + + @pytest.fixture + def mqa(self, attn_params_mqa: Tuple[int, int, int, int]) -> CausalSelfAttention: + num_heads, num_kv_heads, embed_dim, max_seq_len = attn_params_mqa + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + ) + fixed_init_model(attn) + attn.eval() + return attn + + @pytest.fixture + def mqa_kv_cache( + self, attn_params_mqa: Tuple[int, int, int, int] + ) -> CausalSelfAttention: + num_heads, num_kv_heads, embed_dim, max_seq_len = attn_params_mqa + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + kv_cache = KVCache( + batch_size=4, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + dtype=torch.float32, + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=kv_cache, + max_seq_len=max_seq_len, + ) + fixed_init_model(attn) + attn.eval() + return attn + + def test_forward_gqa(self, input: Tensor, gqa: CausalSelfAttention) -> None: + with torch.no_grad(): + output = gqa(input) + assert_expected( + output.mean(), torch.tensor(-2545.42236328125), atol=1e-8, rtol=1e-3 + ) + assert_expected(output.shape, input.shape) + + def test_forward_gqa_kv_cache( + self, input: Tensor, gqa_kv_cache: CausalSelfAttention, attn_params_gqa + ) -> None: + + _, _, _, max_seq_len = attn_params_gqa + _, seq_len, _ = input.shape + + causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)) + input_pos = torch.arange(seq_len) + mask = causal_mask[None, input_pos] + + with torch.no_grad(): + output = gqa_kv_cache(input, mask=mask, input_pos=input_pos) + assert_expected( + output.mean(), torch.tensor(-2545.42236328125), atol=1e-8, rtol=1e-3 + ) + assert_expected(output.shape, input.shape) + + def test_forward_mha(self, input: Tensor, mha: CausalSelfAttention) -> None: + with torch.no_grad(): + output = mha(input) + assert_expected( + output.mean(), torch.tensor(-2597.248046875), atol=1e-8, rtol=1e-3 + ) + assert_expected(output.shape, input.shape) + + def test_forward_mha_kv_cache( + self, input: Tensor, mha_kv_cache: CausalSelfAttention, attn_params_mha + ) -> None: + + _, _, _, max_seq_len = attn_params_mha + _, seq_len, _ = input.shape + + causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)) + input_pos = torch.arange(seq_len) + mask = causal_mask[None, input_pos] + + with torch.no_grad(): + output = mha_kv_cache(input, mask=mask, input_pos=input_pos) + assert_expected( + output.mean(), torch.tensor(-2597.248046875), atol=1e-8, rtol=1e-3 + ) + assert_expected(output.shape, input.shape) + + def test_forward_mqa(self, input: Tensor, mqa: CausalSelfAttention) -> None: + with torch.no_grad(): + output = mqa(input) + assert_expected( + output.mean(), torch.tensor(-2108.07666015625), atol=1e-8, rtol=1e-3 + ) + assert_expected(output.shape, input.shape) + + def test_forward_mqa_kv_cache( + self, input: Tensor, mqa_kv_cache: CausalSelfAttention, attn_params_mqa + ) -> None: + _, _, _, max_seq_len = attn_params_mqa + _, seq_len, _ = input.shape + + causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)) + input_pos = torch.arange(seq_len) + mask = causal_mask[None, input_pos] + + with torch.no_grad(): + output = mqa_kv_cache(input, mask=mask, input_pos=input_pos) + assert_expected( + output.mean(), torch.tensor(-2108.076660156255), atol=1e-8, rtol=1e-3 + ) + assert_expected(output.shape, input.shape) + + def test_max_seq_len_exceeded( + self, + input_max_len_exceeded: Tensor, + gqa: CausalSelfAttention, + ) -> None: + with pytest.raises(Exception): + _ = gqa(input_max_len_exceeded) + + def test_batch_size_exceeded( + self, + input_max_bs_exceeded: Tensor, + gqa_kv_cache: CausalSelfAttention, + ) -> None: + with pytest.raises(Exception): + _ = gqa_kv_cache(input_max_bs_exceeded) diff --git a/training/tests/torchtune/modules/test_cosine_with_warmup.py b/training/tests/torchtune/modules/test_cosine_with_warmup.py new file mode 100644 index 0000000..274170e --- /dev/null +++ b/training/tests/torchtune/modules/test_cosine_with_warmup.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import pytest + +import torch +import torch.optim as optim + +from tests.test_utils import assert_expected + +from torchtune.modules import get_cosine_schedule_with_warmup + + +class TestCosineLR: + @pytest.fixture + def scheduler(self): + optimizer = optim.SGD([torch.ones(1)], lr=0.2) + scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=10, + num_training_steps=100, + num_cycles=1.0, + ) + return scheduler + + def test_cosine_schedule_init(self, scheduler): + optimizer = scheduler.optimizer + assert_expected(optimizer.param_groups[0]["lr"], 0.0) + + def test_cosine_schedule_mid_warmup(self, scheduler): + optimizer = scheduler.optimizer + scheduler.last_epoch = 5 - 1 + optimizer.step() + scheduler.step() + assert_expected(optimizer.param_groups[0]["lr"], 0.1) + + def test_cosine_schedule_warmup(self, scheduler): + optimizer = scheduler.optimizer + scheduler.last_epoch = 10 - 1 + optimizer.step() + scheduler.step() + assert_expected(optimizer.param_groups[0]["lr"], 0.2) + + def test_cosine_schedule_minimum_value(self, scheduler): + optimizer = scheduler.optimizer + scheduler.last_epoch = 55 - 1 + optimizer.step() + scheduler.step() + assert_expected(optimizer.param_groups[0]["lr"], 0.0) + + def test_cosine_schedule_complete_cycle(self, scheduler): + optimizer = scheduler.optimizer + scheduler.last_epoch = 100 - 1 + optimizer.step() + scheduler.step() + assert_expected(optimizer.param_groups[0]["lr"], 0.2) diff --git a/training/tests/torchtune/modules/test_feed_forward.py b/training/tests/torchtune/modules/test_feed_forward.py new file mode 100644 index 0000000..53b1aed --- /dev/null +++ b/training/tests/torchtune/modules/test_feed_forward.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import pytest + +import torch + +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn, Tensor + +from torchtune.modules import FeedForward +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestFeedForward: + """Class for testing FFN implementation.""" + + @pytest.fixture + def input_params(self) -> Tuple[int, int]: + dim = 4096 + hidden_dim = 11008 # Scaled for SwiGLU + return dim, hidden_dim + + @pytest.fixture + def input(self, input_params: Tuple[int, int]) -> Tensor: + dim, _ = input_params + return torch.randn(1, dim) + + @pytest.fixture + def ffn(self, input_params: Tuple[int, int]) -> FeedForward: + dim, hidden_dim = input_params + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + ff = FeedForward( + gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj + ).eval() + fixed_init_model(ff) + ff.eval() + return ff + + def test_forward(self, input: Tensor, ffn: FeedForward) -> None: + with torch.no_grad(): + x_out = ffn(input) + assert_expected(x_out.mean(), torch.tensor(251.5356), atol=1e-7, rtol=1e-3) + assert_expected(x_out.max(), torch.tensor(503.0614), atol=1e-7, rtol=1e-3) diff --git a/training/tests/torchtune/modules/test_layernorm.py b/training/tests/torchtune/modules/test_layernorm.py new file mode 100644 index 0000000..c1530a1 --- /dev/null +++ b/training/tests/torchtune/modules/test_layernorm.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +from tests.test_utils import assert_expected + +from torchtune.modules.layer_norm import Fp32LayerNorm +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestLayerNorm: + """ + Class for testing our LayerNorm, which is just a wrapper around torch.nn.LayerNorm + to support fp16 training. + """ + + @pytest.fixture + def dim(self) -> int: + return 8 + + @pytest.fixture + def eps(self) -> float: + return 1e-6 + + @pytest.fixture + def input_random_fp16(self, dim) -> torch.Tensor: + return torch.randn(dim, dtype=torch.float16) + + @pytest.fixture + def layer_norm(self, dim, eps) -> Fp32LayerNorm: + return Fp32LayerNorm(dim, eps=eps) + + def test_forward_fp16(self, layer_norm, input_random_fp16, eps, dim) -> None: + output_fp16 = layer_norm(input_random_fp16) + + # assert dtype as fp16 + assert ( + output_fp16.dtype == torch.float16 + ), "Expected output to be fp16, but got {output_fp16.dtype=}" + + # assert value as fp32 + expected_output = torch.nn.LayerNorm(dim, eps=eps)(input_random_fp16.float()) + output_fp32 = layer_norm(input_random_fp16.float()) + assert_expected( + output_fp32.mean(), expected_output.mean(), atol=1e-8, rtol=1e-8 + ) diff --git a/training/tests/torchtune/modules/test_position_embeddings.py b/training/tests/torchtune/modules/test_position_embeddings.py new file mode 100644 index 0000000..06b9447 --- /dev/null +++ b/training/tests/torchtune/modules/test_position_embeddings.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import pytest +import torch + +from tests.test_utils import assert_expected +from torch import tensor +from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings + +from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestRotaryPositionEmbedding: + """ + Class for testing our Rotary Positional Embeddings (RoPE) + implementation. The expected tensors are computed from the + reference implementation here: + https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450 + """ + + EXPECTED_X_OUT_MEAN = tensor(6.4543e-05) + EXPECTED_X_OUT_SUM = tensor(2165.7053) + EXPECTED_X_OUT_MAX = tensor(5.4546) + + @pytest.fixture + def input_params(self) -> Tuple[int, int, int, int]: + bsz = 4 + num_heads = 32 + embed_dim = 4096 + head_dim = embed_dim // num_heads + seq_len = 2048 + max_seq_len = 4096 + return bsz, num_heads, head_dim, seq_len, max_seq_len + + @pytest.fixture + def input(self, input_params: Tuple[int, int, int, int]) -> tensor: + bsz, num_heads, head_dim, seq_len, _ = input_params + return torch.randn(bsz, seq_len, num_heads, head_dim) + + @pytest.fixture + def rope( + self, input_params: Tuple[int, int, int, int] + ) -> RotaryPositionalEmbeddings: + _, _, head_dim, _, max_seq_len = input_params + return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + + def test_forward(self, input: tensor, rope: RotaryPositionalEmbeddings) -> None: + x_out = rope(input) + + # check the numerics of the computed tensor + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_forward_with_curr_pos( + self, input: tensor, rope: RotaryPositionalEmbeddings + ) -> None: + ( + _, + seq_len, + _, + _, + ) = input.shape + x_out = rope(input, input_pos=torch.arange(seq_len)) + + # these values should be exactly the same as test_forward + # since in this case input_pos covers the entire input + # sequence. This tests that input_pos works as expected i.e. + # extracts the embeddings for the relevant positions + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_forward_with_packed_pos( + self, input: tensor, rope: RotaryPositionalEmbeddings + ) -> None: + """ + Use input_pos to indicate positions of each token relative to its sequence + when sample is packed. + """ + ( + bsz, + seq_len, + _, + _, + ) = input.shape + x_out = rope( + input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len) + ) + + # these values should be exactly the same as test_forward + # AND test_forward_with_current_pos. In this case input_pos + # covers the entire batch dim and is defined for each sample separately. + # This tests that input_pos works as expected i.e. + # extracts the embeddings for the relevant positions for each sample + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_rope_init_meta_device(self, input_params): + _, _, head_dim, _, max_seq_len = input_params + rope_on_device = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len + ) + with torch.device("meta"): + meta_rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len + ) + + meta_rope._rope_init() + for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()): + torch.testing.assert_close(p1, p2) + + +class TestPhi3RotaryPositionalEmbeddings: + """ + Class for testing the Phi3 models RoPE Embeddings. The expected tensors are + computed from the reference implementation here: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py + """ + + @pytest.fixture + def input_params(self) -> Tuple[int, int, int, int]: + bsz = 4 + num_heads = 32 + embed_dim = 3072 + seq_len = 60 + max_seq_len = 4096 + head_dim = embed_dim // num_heads + return bsz, num_heads, head_dim, seq_len, max_seq_len + + @pytest.fixture + def input(self, input_params: Tuple[int, int, int, int]) -> tensor: + bsz, num_heads, head_dim, seq_len, _ = input_params + return torch.randn(bsz, seq_len, num_heads, head_dim) + + @pytest.fixture + def rope_phi3( + self, input_params: Tuple[int, int, int, int] + ) -> Phi3RotaryPositionalEmbeddings: + _, _, head_dim, _, max_seq_len = input_params + return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + + def test_forward( + self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings + ) -> None: + x_out = rope_phi3(input) + + # check the numerics of the computed tensor + assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4) + assert_expected(x_out.sum(), tensor(-381.0620)) + + # check shapes + assert_expected(x_out.shape, input.shape) diff --git a/training/tests/torchtune/modules/test_rms_norm.py b/training/tests/torchtune/modules/test_rms_norm.py new file mode 100644 index 0000000..6c22892 --- /dev/null +++ b/training/tests/torchtune/modules/test_rms_norm.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +from tests.test_utils import assert_expected +from torch.nn.functional import normalize + +from torchtune.modules.rms_norm import RMSNorm +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestRMSNorm: + """ + Class for testing our RMSNorm implementation. Expected tensors + are generated using torch.nn.functional.normalization: + + RMSNorm(x) = normalize(x, p=2, dim=-1) * (dim ** 0.5) + """ + + @pytest.fixture + def dim(self) -> int: + return 8 + + @pytest.fixture + def eps(self) -> float: + return 1e-6 + + @pytest.fixture + def input_ones(self, dim) -> torch.Tensor: + return torch.ones(dim, dtype=torch.float) + + @pytest.fixture + def input_random(self, dim) -> torch.Tensor: + return torch.randn(dim, dtype=torch.float) + + @pytest.fixture + def input_random_fp16(self, dim) -> torch.Tensor: + return torch.randn(dim, dtype=torch.float16) + + @pytest.fixture + def rms_norm(self, dim, eps) -> RMSNorm: + return RMSNorm(dim, eps=eps) + + def test_forward(self, rms_norm, input_ones, input_random, dim) -> None: + output_ones = rms_norm(input_ones) + output_random = rms_norm(input_random) + + expected_random = normalize(input_random, p=2, dim=-1) * (dim**0.5) + + assert_expected(output_ones, input_ones) + assert_expected(output_random, expected_random) + + def test_forward_fp16(self, rms_norm, input_random_fp16, dim) -> None: + output_fp16 = rms_norm(input_random_fp16) + + # convert input to float since rms_norm computes in fp32 + expected_fp16 = normalize(input_random_fp16.float(), p=2, dim=-1) * (dim**0.5) + + assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3) + assert output_fp16.dtype == torch.float32 diff --git a/training/tests/torchtune/modules/test_transformer_decoder.py b/training/tests/torchtune/modules/test_transformer_decoder.py new file mode 100644 index 0000000..3f546bb --- /dev/null +++ b/training/tests/torchtune/modules/test_transformer_decoder.py @@ -0,0 +1,316 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import pytest + +import torch + +from tests.test_utils import assert_expected + +from torch import nn, Tensor + +from torchtune.models.llama2 import llama2 +from torchtune.models.llama2._component_builders import llama2_mlp + +from torchtune.models.llama2._model_utils import scale_hidden_dim_for_mlp +from torchtune.modules import ( + CausalSelfAttention, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestTransformerDecoderLayer: + """ + Class for testing our TransformerDecoderLayer implementation. + + The expected tensors are computed from the reference implementation + below by using the same seed, same params and same initialization used + in the fixtures below. + https://github.com/facebookresearch/llama/blob/main/llama/model.py#L351 + """ + + @pytest.fixture + def input_params(self) -> Tuple[int, int, int]: + batch_size = 4 + seq_len = 2048 + embed_dim = 4096 + return batch_size, seq_len, embed_dim + + @pytest.fixture + def input(self, input_params: Tuple[int, int, int]) -> Tensor: + batch_size, seq_len, embed_dim = input_params + return torch.randn(batch_size, seq_len, embed_dim) + + @pytest.fixture + def layer_params(self) -> Tuple[int, int, int, int]: + num_heads = 32 + num_kv_heads = 8 + embed_dim = 4096 + max_seq_len = 4096 + return num_heads, num_kv_heads, embed_dim, max_seq_len + + @pytest.fixture + def transformer_layer( + self, layer_params: Tuple[int, int, int, int] + ) -> TransformerDecoderLayer: + num_heads, num_kv_heads, embed_dim, max_seq_len = layer_params + head_dim = embed_dim // num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + ) + hidden_dim = scale_hidden_dim_for_mlp(embed_dim) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + transformer_layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim), + mlp_norm=RMSNorm(dim=embed_dim), + ) + # TODO: fix weight initialization to use fixed_init_model + for p in transformer_layer.parameters(): + nn.init.constant_(p, 0.05) + transformer_layer.eval() + return transformer_layer + + def test_forward( + self, input: Tensor, transformer_layer: TransformerDecoderLayer + ) -> None: + with torch.no_grad(): + output = transformer_layer(input) + assert_expected(output.mean(), torch.tensor(18261.0156), atol=1e-8, rtol=1e-3) + assert_expected(output.shape, input.shape) + + +class TestTransformerDecoder: + """ + Class for testing our TransformerDecoderLayer implementation. + + The expected tensors are computed from the reference implementation + below by using the same seed, same params and same initialization used + in the fixtures below. + https://github.com/facebookresearch/llama/blob/main/llama/model.py#L413 + """ + + @pytest.fixture + def input_params(self) -> Tuple[int, int, int]: + batch_size = 4 + seq_len = 512 + vocab_size = 256 + return batch_size, seq_len, vocab_size + + @pytest.fixture + def input(self, input_params: Tuple[int, int, int]) -> Tensor: + batch_size, seq_len, vocab_size = input_params + return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + + @pytest.fixture + def decoder_params(self) -> Tuple[int, int, int, int, int, int]: + vocab_size = 256 + embed_dim = 512 + num_layers = 2 + num_heads = 8 + max_seq_len = 512 + num_kv_heads = 8 + return vocab_size, embed_dim, num_layers, num_heads, max_seq_len, num_kv_heads + + @pytest.fixture + def input_max_len_exceeded( + self, + input_params: Tuple[int, int, int], + decoder_params: Tuple[int, int, int, int, int, int], + ) -> Tensor: + batch_size, seq_len, vocab_size = input_params + _, _, _, _, max_seq_len, _ = decoder_params + seq_len = max_seq_len + 1 + return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + + @pytest.fixture + def input_max_bs_exceeded( + self, + input_params: Tuple[int, int, int], + decoder_params: Tuple[int, int, int, int, int, int], + ) -> Tensor: + batch_size, seq_len, vocab_size = input_params + _, _, _, _, max_seq_len, _ = decoder_params + batch_size = batch_size + 1 + return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + + @pytest.fixture + def decoder( + self, decoder_params: Tuple[int, int, int, int, int, int] + ) -> TransformerDecoder: + ( + vocab_size, + embed_dim, + num_layers, + num_heads, + max_seq_len, + num_kv_heads, + ) = decoder_params + decoder = llama2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + max_seq_len=max_seq_len, + ) + # TODO: fix weight initialization to use fixed_init_model + for p in decoder.parameters(): + nn.init.constant_(p, 0.2) + decoder.eval() + return decoder + + @pytest.fixture + def decoder_with_kv_cache_enabled( + self, decoder_params: Tuple[int, int, int, int, int, int] + ) -> TransformerDecoder: + ( + vocab_size, + embed_dim, + num_layers, + num_heads, + max_seq_len, + num_kv_heads, + ) = decoder_params + decoder = llama2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + max_seq_len=max_seq_len, + ) + # TODO: fix weight initialization to use fixed_init_model + for p in decoder.parameters(): + nn.init.constant_(p, 0.2) + decoder.eval() + decoder.setup_caches(batch_size=4, dtype=torch.float32) + return decoder + + def test_forward( + self, + input: Tensor, + input_params: Tuple[int, int, int], + decoder: TransformerDecoder, + ) -> None: + batch_size, seq_len, vocab_size = input_params + with torch.no_grad(): + output = decoder(input) + assert_expected(output.mean(), torch.tensor(20.4800), atol=1e-8, rtol=1e-6) + assert_expected(output.shape, torch.Size([batch_size, seq_len, vocab_size])) + + def test_max_seq_len_exceeded( + self, + input_max_len_exceeded: Tensor, + decoder: TransformerDecoder, + ) -> None: + with pytest.raises(Exception): + output = decoder(input_max_len_exceeded) + + def test_kv_cache( + self, + input: Tensor, + decoder_with_kv_cache_enabled: TransformerDecoder, + decoder: TransformerDecoder, + ) -> None: + _, seq_len = input.shape + input_pos = torch.arange(seq_len) + + with torch.no_grad(): + output_cache = decoder_with_kv_cache_enabled(input, input_pos=input_pos) + output_no_cache = decoder(input) + assert_expected(output_cache.mean(), output_no_cache.mean()) + + def test_kv_cache_reset_values( + self, + input: Tensor, + decoder_with_kv_cache_enabled: TransformerDecoder, + ) -> None: + _, seq_len = input.shape + input_pos = torch.arange(seq_len) + + with torch.no_grad(): + _ = decoder_with_kv_cache_enabled(input, input_pos=input_pos) + kv_cache_k_val = decoder_with_kv_cache_enabled.layers[ + 0 + ].attn.kv_cache.k_cache.clone() + kv_cache_v_val = decoder_with_kv_cache_enabled.layers[ + 0 + ].attn.kv_cache.v_cache.clone() + + decoder_with_kv_cache_enabled.reset_caches() + kv_cache_k_val_reset = decoder_with_kv_cache_enabled.layers[ + 0 + ].attn.kv_cache.k_cache.clone() + kv_cache_v_val_reset = decoder_with_kv_cache_enabled.layers[ + 0 + ].attn.kv_cache.v_cache.clone() + + assert not torch.allclose(kv_cache_k_val, kv_cache_k_val_reset) + assert not torch.allclose(kv_cache_v_val, kv_cache_v_val_reset) + + def test_kv_cache_reset_values_fails_when_not_enabled_first( + self, + decoder: TransformerDecoder, + ) -> None: + with pytest.raises(RuntimeError, match="Key value caches are not setup"): + decoder.reset_caches() + + def test_kv_cache_batch_size_exceeded( + self, + input_max_bs_exceeded: Tensor, + decoder_with_kv_cache_enabled: TransformerDecoder, + ) -> None: + with pytest.raises(ValueError): + decoder_with_kv_cache_enabled(input_max_bs_exceeded) + + def test_rms_norm_propagation( + self, decoder_params: Tuple[int, int, int, int, int, int] + ): + ( + vocab_size, + embed_dim, + num_layers, + num_heads, + max_seq_len, + num_kv_heads, + ) = decoder_params + rms_norm_eps = 1e-2 + decoder = llama2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + max_seq_len=max_seq_len, + norm_eps=rms_norm_eps, + ) + rms_norms = [m for m in decoder.modules() if isinstance(m, RMSNorm)] + assert len(rms_norms) > 0 + for rms_norm in rms_norms: + assert rms_norm.eps == rms_norm_eps diff --git a/training/tests/torchtune/modules/test_vision_transformer.py b/training/tests/torchtune/modules/test_vision_transformer.py new file mode 100644 index 0000000..91260e5 --- /dev/null +++ b/training/tests/torchtune/modules/test_vision_transformer.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +from tests.test_utils import assert_expected, fixed_init_model, fixed_init_tensor +from torchtune.models.clip._component_builders import clip_vision_encoder + + +@pytest.fixture +def transformer_config(): + return { + "embed_dim": 32, + "cls_output_dim": 64, + "num_layers": 2, + "num_heads": 4, + "tile_size": 49, + "patch_size": 9, + "max_num_tiles": 4, + "in_channels": 3, + "output_cls_projection": False, + "out_indices": None, + } + + +@pytest.fixture +def vision_transformer(transformer_config): + vision_transformer = clip_vision_encoder(**transformer_config).eval() + fixed_init_model(vision_transformer, min_val=-1, max_val=1) + return vision_transformer + + +class TestVisionTransformer: + @pytest.fixture(autouse=True) + def setup_class(self, transformer_config): + self.batch_size = 1 + self.n_imgs = 2 + num_channels = transformer_config["in_channels"] + + # generate aspect ratios up to max_num_tiles, shape (bsz, num_conccurent_media, 2) + self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape( + self.batch_size, self.n_imgs, 2 + ) + + self.num_tiles = 4 + assert ( + self.num_tiles <= transformer_config["max_num_tiles"] + ), "For this test to be valid, num_tiles should be <= max_num_tiles" + assert ( + torch.prod(self.aspect_ratio, dim=-1).max() <= self.num_tiles + ), "For this test to be vlaid, prod(aspect_ratio).max() should match num_tiles" + + # generate image + image = torch.rand( + ( + self.batch_size, + self.n_imgs, + self.num_tiles, + num_channels, + transformer_config["tile_size"], + transformer_config["tile_size"], + ) + ) + self.image = fixed_init_tensor(image.shape, min_val=-1, max_val=1) + + def test_vision_transformer_without_hidden_layers( + self, vision_transformer, transformer_config + ): + # call model + output, _ = vision_transformer(self.image, self.aspect_ratio) + + # assertion + expected_shape = ( + self.batch_size, + self.n_imgs, + self.num_tiles, + vision_transformer.get_image_tokens_per_tile(), + transformer_config["embed_dim"], + ) + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, but got {output.shape}" + + assert_expected(output.mean(), torch.tensor(1.0172), atol=1e-3, rtol=1e-3) + + def test_fails_if_ar_none_and_multiple_tiles(self, vision_transformer): + """ + If aspect_ratio is none, then num_tiles shouldnt be greater than 1. + Here the test passes if something actually fails under these conditions. + """ + assert self.image.shape[2] > 1, "This test is not valid for num_tiles=1" + try: + vision_transformer(self.image, aspect_ratio=None) + pytest.fail( + "Expected ValueError: If num_tiles>1, aspect_ratio should not be None" + ) + except ValueError: + pass # If ValueError is raised, the test passes + + def test_vision_transformer_with_cls_projection(self, transformer_config): + transformer_config = transformer_config.copy() + transformer_config["output_cls_projection"] = True + + # call model + model_with_cls = clip_vision_encoder(**transformer_config).eval() + fixed_init_model(model_with_cls, min_val=-1, max_val=1) + output, _ = model_with_cls(self.image, self.aspect_ratio) + + # assertion + expected_shape = ( + self.batch_size, + self.n_imgs, + self.num_tiles, + 1, + transformer_config["cls_output_dim"], + ) + + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, but got {output.shape}" + + assert_expected(output.mean(), torch.tensor(9.6240), atol=1e-3, rtol=1e-3) + + def test_vision_transformer_return_hidden_layers(self, transformer_config): + transformer_config = transformer_config.copy() + transformer_config["out_indices"] = [ + 0, + 1, + ] + + # call model + model_with_hidden = clip_vision_encoder(**transformer_config) + fixed_init_model(model_with_hidden, min_val=-1, max_val=1) + x, hidden_layers = model_with_hidden(self.image, self.aspect_ratio) + + # assertion x + expected_shape_x = ( + self.batch_size, + self.n_imgs, + self.num_tiles, + model_with_hidden.get_image_tokens_per_tile(), + transformer_config["embed_dim"], + ) + + assert ( + x.shape == expected_shape_x + ), f"Expected shape {expected_shape_x}, but got {x.shape=}" + + assert_expected(x.mean(), torch.tensor(1.0172), atol=1e-3, rtol=1e-3) + + # assertion hidden + num_hidden_layers_expected = len(transformer_config["out_indices"]) + + expected_shape_hidden_layers = ( + self.batch_size, + self.n_imgs, + self.num_tiles, + model_with_hidden.get_image_tokens_per_tile(), + transformer_config["embed_dim"], + ) + + assert ( + len(hidden_layers) == num_hidden_layers_expected + ), f"Expected {num_hidden_layers_expected} hidden layers, but got {len(hidden_layers)}" + + for hidden_layer in hidden_layers: + assert ( + hidden_layer.shape == expected_shape_hidden_layers + ), f"Expected shape {expected_shape_hidden_layers}, but got {hidden_layer.shape=}" + + assert_expected( + torch.stack(hidden_layers, dim=-1).mean(), + torch.tensor(6.6938), + atol=1e-3, + rtol=1e-3, + ) + + def test_vision_transformer_single_tile(self, transformer_config): + transformer_config = transformer_config.copy() + transformer_config["max_num_tiles"] = 1 + + # get single tile: (bsz, n_imgs, 1, num_channels, tile_size, tile_size) + images = self.image[:, :, [0], :, :, :] + + # call model + model_with_multiple_tiles = clip_vision_encoder(**transformer_config) + fixed_init_model(model_with_multiple_tiles, min_val=-1, max_val=1) + output, _ = model_with_multiple_tiles(images, aspect_ratio=None) + + # assertion + expected_shape = ( + self.batch_size, + self.n_imgs, + 1, + model_with_multiple_tiles.get_image_tokens_per_tile(), + transformer_config["embed_dim"], + ) + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, but got {output.shape}" + + assert_expected(output.mean(), torch.tensor(0.5458), atol=1e-3, rtol=1e-3) diff --git a/training/tests/torchtune/modules/tokenizers/test_sentencepiece.py b/training/tests/torchtune/modules/tokenizers/test_sentencepiece.py new file mode 100644 index 0000000..5600282 --- /dev/null +++ b/training/tests/torchtune/modules/tokenizers/test_sentencepiece.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.modules.tokenizers import SentencePieceBaseTokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestSentencePieceBaseTokenizer: + @pytest.fixture + def tokenizer(self): + # m.model is a pretrained Sentencepiece model using the following command: + # spm.SentencePieceTrainer.train('--input= --model_prefix=m --vocab_size=2000') + sp_tokenizer = SentencePieceBaseTokenizer(str(ASSETS / "m.model")) + return sp_tokenizer + + def test_encode(self, tokenizer): + assert tokenizer.encode("Hello world!") == [ + tokenizer.bos_id, + 12, + 1803, + 1024, + 103, + tokenizer.eos_id, + ] + assert tokenizer.encode("Hello world!", add_eos=False) == [ + tokenizer.bos_id, + 12, + 1803, + 1024, + 103, + ] + assert tokenizer.encode("Hello world!", add_bos=False) == [ + 12, + 1803, + 1024, + 103, + tokenizer.eos_id, + ] + assert tokenizer.encode("Hello world!", add_eos=False, add_bos=False) == [ + 12, + 1803, + 1024, + 103, + ] + + def test_decode(self, tokenizer): + assert tokenizer.decode([1, 12, 1803, 1024, 103, 2]) == "Hello world!" + + def test_token_ids(self, tokenizer): + assert tokenizer.eos_id == 2 + assert tokenizer.pad_id == -1 + assert tokenizer.bos_id == 1 + + def test_tokenizer_vocab_size(self, tokenizer): + assert tokenizer.vocab_size == 2000 + + def test_encode_without_leading_whitespace(self, tokenizer): + s1 = "Hello" + s2 = "I'm an outgoing and friendly person." + # TODO: investigate why test tokenizer model does not encode whitespace + tokenizer.encodes_whitespace = True + s1_tokens = tokenizer.encode(s1, add_bos=False, add_eos=False) + s2_tokens = tokenizer.encode(s2, add_bos=False, add_eos=False) + # Set prefix="pre" since "\n" is not in the test tokenizer's vocab + s2_tokens_no_whitespace = tokenizer.encode( + s2, add_bos=False, add_eos=False, trim_leading_whitespace=True, prefix="pre" + ) + s1s2_tokens = tokenizer.encode(s1 + s2, add_bos=False, add_eos=False) + assert (s1_tokens + s2_tokens) != s1s2_tokens + assert (s1_tokens + s2_tokens_no_whitespace) == s1s2_tokens diff --git a/training/tests/torchtune/modules/tokenizers/test_tiktoken.py b/training/tests/torchtune/modules/tokenizers/test_tiktoken.py new file mode 100644 index 0000000..1d65636 --- /dev/null +++ b/training/tests/torchtune/modules/tokenizers/test_tiktoken.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import pytest +from torchtune.models.llama3._tokenizer import CL100K_PATTERN +from torchtune.modules.tokenizers import TikTokenBaseTokenizer + +ASSETS = Path(__file__).parent.parent.parent.parent / "assets" + + +class TestTikTokenBaseTokenizer: + @pytest.fixture + def tokenizer(self): + # Pretrained tiktoken model generated via the script in + # https://gist.github.com/ebsmothers/54b133dd87db6679b14318545aaa2de4 + return TikTokenBaseTokenizer( + path=str(ASSETS / "tiktoken_small.model"), + name="test_tiktoken", + pattern=CL100K_PATTERN, + bos_id=0, + eos_id=-1, + special_tokens={ + "<|test_token_0|>": 2000, + "<|test_token_1|>": 2001, + }, + ) + + @pytest.fixture + def texts(self): + return [ + "I can see the sun. But even if I cannot see the sun, I know that it exists.", + "And to know that the sun is there - that is living.", + ] + + @pytest.fixture + def token_ids(self): + return [ + 0, + 73, + 503, + 654, + 262, + 376, + 110, + 46, + 690, + 720, + 428, + 270, + 1119, + 654, + 262, + 376, + 110, + 44, + 270, + 686, + 334, + 312, + 522, + 511, + 115, + 46, + -1, + ] + + def test_encode(self, tokenizer, texts, token_ids): + assert tokenizer.encode(texts[0]) == token_ids + + def test_decode(self, tokenizer, texts, token_ids): + assert tokenizer.decode(token_ids) == texts[0] + + def test_encode_and_decode(self, tokenizer, texts): + token_ids = tokenizer.encode(texts[0]) + decoded_text = tokenizer.decode(token_ids) + assert texts[0] == decoded_text + + def test_tokenizer_vocab_size(self, tokenizer): + assert tokenizer.base_vocab_size == 2000 + assert tokenizer.vocab_size == 2002 + + def test_split_long_repetitions(self, tokenizer): + normal_str = "Here is a normal string" + ten_spaces = "".join(10 * [" "]) + space_str = ten_spaces.join( + ["Here", "is", "a", "string", "with", "long", "spaces"] + ) + no_space_str = "".join(10 * ["ab"]) + + actual_split = tokenizer._split_long_repetitions(normal_str, 5) + expected_split = ["Here is a norma", "l strin", "g"] + for actual_substr, expected_substr in zip(actual_split, expected_split): + assert actual_substr == expected_substr + with pytest.raises(StopIteration): + next(actual_split) + + actual_split = tokenizer._split_long_repetitions(space_str, 9) + expected_split = [ + "Here" + ten_spaces[:-1], + " is" + ten_spaces[:-1], + " a" + ten_spaces[:-1], + " string" + ten_spaces[:-1], + " with" + ten_spaces[:-1], + " long" + ten_spaces[:-1], + " spaces", + ] + for actual_substr, expected_substr in zip(actual_split, expected_split): + assert actual_substr == expected_substr + with pytest.raises(StopIteration): + next(actual_split) + + actual_split = tokenizer._split_long_repetitions(no_space_str, 4) + expected_split = ["abab"] * 5 + for actual_substr, expected_substr in zip(actual_split, expected_split): + assert actual_substr == expected_substr + with pytest.raises(StopIteration): + next(actual_split) diff --git a/training/tests/torchtune/modules/transforms/test_get_canvas_best_fit.py b/training/tests/torchtune/modules/transforms/test_get_canvas_best_fit.py new file mode 100644 index 0000000..16779ed --- /dev/null +++ b/training/tests/torchtune/modules/transforms/test_get_canvas_best_fit.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from torchtune.modules.transforms import find_supported_resolutions, get_canvas_best_fit + + +class TestUtils: + @pytest.mark.parametrize( + "params", + [ + { + "max_num_tiles": 1, + "tile_size": 224, + "expected_resolutions": [(224, 224)], + }, + { + "max_num_tiles": 2, + "tile_size": 100, + "expected_resolutions": [(100, 200), (200, 100), (100, 100)], + }, + { + "max_num_tiles": 3, + "tile_size": 50, + "expected_resolutions": [ + (50, 150), + (150, 50), + (50, 100), + (100, 50), + (50, 50), + ], + }, + { + "max_num_tiles": 4, + "tile_size": 300, + "expected_resolutions": [ + (300, 1200), + (600, 600), + (300, 300), + (1200, 300), + (300, 900), + (900, 300), + (300, 600), + (600, 300), + ], + }, + ], + ) + def test_find_supported_resolutions(self, params): + max_num_tiles = params["max_num_tiles"] + tile_size = params["tile_size"] + expected_resolutions = params["expected_resolutions"] + resolutions = find_supported_resolutions(max_num_tiles, tile_size) + + assert len(set(resolutions)) == len(resolutions), "Resolutions should be unique" + assert set(resolutions) == set( + expected_resolutions + ), f"Expected resolutions {expected_resolutions} but got {resolutions}" + + @pytest.mark.parametrize( + "params", + [ + { + "image_size": (800, 600), + "possible_resolutions": [ + (224, 896), + (448, 448), + (224, 224), + (896, 224), + (224, 672), + (672, 224), + (224, 448), + (448, 224), + ], + "resize_to_max_canvax": False, + "expected_best_resolution": (448, 448), + }, + { + "image_size": (200, 300), + "possible_resolutions": [ + (224, 896), + (448, 448), + (224, 224), + (896, 224), + (224, 672), + (672, 224), + (224, 448), + (448, 224), + ], + "resize_to_max_canvax": False, + "expected_best_resolution": (224, 448), + }, + { + "image_size": (200, 500), + "possible_resolutions": [ + (224, 896), + (448, 448), + (224, 224), + (896, 224), + (224, 672), + (672, 224), + (224, 448), + (448, 224), + ], + "resize_to_max_canvax": True, + "expected_best_resolution": (224, 672), + }, + { + "image_size": (200, 200), + "possible_resolutions": [ + (224, 896), + (448, 448), + (224, 224), + (896, 224), + (224, 672), + (672, 224), + (224, 448), + (448, 224), + ], + "resize_to_max_canvax": False, + "expected_best_resolution": (224, 224), + }, + { + "image_size": (200, 100), + "possible_resolutions": [ + (224, 896), + (448, 448), + (224, 224), + (896, 224), + (224, 672), + (672, 224), + (224, 448), + (448, 224), + ], + "resize_to_max_canvax": True, + "expected_best_resolution": (448, 224), + }, + ], + ) + def test_get_canvas_best_fit(self, params): + image_size = params["image_size"] + possible_resolutions = params["possible_resolutions"] + expected_best_resolution = params["expected_best_resolution"] + resize_to_max_canvax = params["resize_to_max_canvax"] + + possible_resolutions = torch.tensor(possible_resolutions) + + image = torch.rand(*image_size) + best_resolution = get_canvas_best_fit( + image, possible_resolutions, resize_to_max_canvax + ) + + assert ( + tuple(best_resolution) == expected_best_resolution + ), f"Expected best resolution {expected_best_resolution} but got {best_resolution}" diff --git a/training/tests/torchtune/modules/transforms/test_resize_with_pad.py b/training/tests/torchtune/modules/transforms/test_resize_with_pad.py new file mode 100644 index 0000000..29b4fce --- /dev/null +++ b/training/tests/torchtune/modules/transforms/test_resize_with_pad.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch +import torchvision + +from torchtune.modules.transforms import resize_with_pad + + +class TestTransforms: + @pytest.mark.parametrize( + "params", + [ + { + "image_size": (200, 100), + "target_size": (1000, 1200), + "max_upscaling_size": 600, + "expected_resized_size": (600, 300), + }, + { + "image_size": (2000, 200), + "target_size": (1000, 1200), + "max_upscaling_size": 600, + "expected_resized_size": (1000, 100), + }, + { + "image_size": (400, 200), + "target_size": (1000, 1200), + "max_upscaling_size": 2000, + "expected_resized_size": (1000, 500), + }, + { + "image_size": (400, 200), + "target_size": (1000, 1200), + "max_upscaling_size": None, + "expected_resized_size": (1000, 500), + }, + { + "image_size": (1000, 500), + "target_size": (400, 300), + "max_upscaling_size": None, + "expected_resized_size": [400, 200], + }, + ], + ) + def test_resize_with_pad(self, params): + + image_size = params["image_size"] + target_size = params["target_size"] + max_upscaling_size = params["max_upscaling_size"] + expected_resized_size = params["expected_resized_size"] + + image = torch.rand(3, *image_size) # Create a random image tensor + + resized_image = resize_with_pad( + image=image, + target_size=target_size, + resample=torchvision.transforms.InterpolationMode["BILINEAR"], + max_upscaling_size=max_upscaling_size, + ) + + # assert everything beyond resize has value == 0 + assert torch.all( + resized_image[:, (expected_resized_size[0] + 1) :, :] == 0 + ), "Expected everything beyond resize to be pad with fill=0" + + assert torch.all( + resized_image[:, :, (expected_resized_size[1] + 1) :] == 0 + ), "Expected everything beyond resize to be pad with fill=0" + + assert torch.all( + resized_image[:, : expected_resized_size[0], : expected_resized_size[1]] + != 0 + ), "Expected no padding where the image is supposed to be" + + # output should have shape target_size + assert ( + resized_image.shape[-2:] == target_size + ), f"Expected output with shape {target_size} but got {resized_image.shape[-2:]}" diff --git a/training/tests/torchtune/modules/transforms/test_tile_crop.py b/training/tests/torchtune/modules/transforms/test_tile_crop.py new file mode 100644 index 0000000..7afde49 --- /dev/null +++ b/training/tests/torchtune/modules/transforms/test_tile_crop.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +from torchtune.modules.transforms import tile_crop + + +class TestTransforms: + @pytest.mark.parametrize( + "params", + [ + { + "expected_output_shape": torch.Size([24, 3, 50, 50]), + "image_size": (3, 200, 300), + "status": "Passed", + "tile_size": 50, + }, + { + "expected_output_shape": torch.Size([6, 3, 200, 200]), + "image_size": (3, 400, 600), + "status": "Passed", + "tile_size": 200, + }, + { + "expected_output_shape": torch.Size([1, 3, 250, 250]), + "image_size": (3, 250, 250), + "status": "Passed", + "tile_size": 250, + }, + { + "error": "Image size 250x250 is not divisible by tile size 500", + "image_size": (3, 250, 250), + "status": "Failed", + "tile_size": 500, + }, + { + "error": "Image size 250x250 is not divisible by tile size 80", + "image_size": (3, 250, 250), + "status": "Failed", + "tile_size": 80, + }, + ], + ) + def test_tile_crop(self, params): + image_size = params["image_size"] + tile_size = params["tile_size"] + status = params["status"] + + image = torch.rand(*image_size) # Create a random image tensor + + if status == "Passed": + tiles = tile_crop(image, tile_size) + expected_output_shape = params["expected_output_shape"] + assert ( + tiles.shape == expected_output_shape + ), f"Expected shape {expected_output_shape} but got {tiles.shape}" + + # check if first and last tile matches the image + first_tile = image[..., :tile_size, :tile_size] + last_tile = image[..., -tile_size:, -tile_size:] + assert torch.equal( + tiles[0], first_tile + ), "Expected first tile to match the image" + assert torch.equal( + tiles[-1], last_tile + ), "Expected last tile to match the image" + + elif status == "Failed": + with pytest.raises(Exception) as exc_info: + tile_crop(image, tile_size) + expected_error = params["error"] + actual_error = str(exc_info.value) + assert ( + str(exc_info.value) == params["error"] + ), f"Expected error message '{expected_error}' but got '{actual_error}'" diff --git a/training/tests/torchtune/modules/transforms/test_transforms.py b/training/tests/torchtune/modules/transforms/test_transforms.py new file mode 100644 index 0000000..0436a34 --- /dev/null +++ b/training/tests/torchtune/modules/transforms/test_transforms.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torchtune.modules.transforms import VisionCrossAttentionMask + + +IMAGE_TOKEN_ID = 1 + + +class TestVisionCrossAttentionMask: + @pytest.fixture + def num_tiles(self): + return 2 + + @pytest.fixture + def tile_size(self): + return 4 + + @pytest.fixture + def patch_size(self): + return 2 + + @pytest.fixture + def image_num_tokens(self, num_tiles, tile_size, patch_size): + return ((tile_size // patch_size) ** 2 + 1) * num_tiles + + @pytest.fixture + def tokens(self): + # This tests image tokens not at start, consecutive images, and image + # with text until end. + # text = 2, image = 1 + return [2, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 2, IMAGE_TOKEN_ID, 2, 2] + + @pytest.fixture + def images(self, num_tiles, tokens): + n_img = len([token_id for token_id in tokens if token_id == IMAGE_TOKEN_ID]) + n_channels = 3 + tile_size = 2 + return [ + torch.ones(num_tiles, n_channels, tile_size, tile_size) + for _ in range(n_img) + ] + + @pytest.fixture + def cross_attn_mask_transform(self, tile_size, patch_size): + # patches per tile = 4 + return VisionCrossAttentionMask( + tile_size=tile_size, + patch_size=patch_size, + image_token_id=IMAGE_TOKEN_ID, + ) + + def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens): + actual = cross_attn_mask_transform._get_image_attention_intervals(tokens) + expected = [[2, 6], [3, 6], [6, 9]] + assert actual == expected + + def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens): + sample = {"tokens": tokens, "images": images} + dummy_kwargs = {"hello": 8} + sample.update(dummy_kwargs) + actual = cross_attn_mask_transform(sample) + expected = [ + torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool) + for _ in range(len(images)) + ] + expected[0][2:6, :] = True + expected[1][3:6, :] = True + expected[2][6:9, :] = True + for i in range(len(images)): + torch.testing.assert_close(actual["encoder_mask"][i], expected[i]) + torch.testing.assert_close(actual["images"][i], images[i]) + + assert actual["tokens"] == tokens + assert actual["hello"] == dummy_kwargs["hello"] diff --git a/training/tests/torchtune/utils/__init__.py b/training/tests/torchtune/utils/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/tests/torchtune/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/tests/torchtune/utils/test_argparse.py b/training/tests/torchtune/utils/test_argparse.py new file mode 100644 index 0000000..67018d8 --- /dev/null +++ b/training/tests/torchtune/utils/test_argparse.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from unittest import mock + +import pytest +from omegaconf import OmegaConf + +from torchtune.utils import TuneRecipeArgumentParser + +_CONFIG = {"a": 1, "b": 2} + + +class TestArgParse: + @pytest.fixture + def parser(self): + parser = TuneRecipeArgumentParser("Test parser") + return parser + + @mock.patch("torchtune.utils.argparse.OmegaConf.load", return_value=_CONFIG) + def test_parse_known_args(self, mock_load, parser): + """ + Test that the parser can load a config and override parameters provided on CLI. + The actual load is mocked to return the test config above. + """ + config_args, cli_args = parser.parse_known_args( + ["--config", "test.yaml", "b=3", "c=4"] + ) + assert config_args.a == 1, f"a == {config_args.a} not 1 as set in the config." + assert config_args.b == 2, f"b == {config_args.b} not 2 as set in the config." + + cli_kwargs = OmegaConf.from_dotlist(cli_args) + assert ( + cli_kwargs.b == 3 + ), f"b == {cli_kwargs.b} not 3 as set in the command args." + assert ( + cli_kwargs.c == 4 + ), f"c == {cli_kwargs.c} not 4 as set in the command args." + + with pytest.raises(ValueError, match="Additional flag arguments not supported"): + _ = parser.parse_known_args( + ["--config", "test.yaml", "--b", "3"], + ) diff --git a/training/tests/torchtune/utils/test_checkpointer.py b/training/tests/torchtune/utils/test_checkpointer.py new file mode 100644 index 0000000..bc81b62 --- /dev/null +++ b/training/tests/torchtune/utils/test_checkpointer.py @@ -0,0 +1,791 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json + +from pathlib import Path +from typing import Tuple + +import pytest +import torch +from torch import randn + +from torchtune.models import gemma, llama2, mistral +from torchtune.modules.peft.peft_utils import ( + get_adapter_params, + get_lora_module_names, + validate_missing_and_unexpected_for_lora, +) + +from torchtune.utils._checkpointing import FullModelHFCheckpointer +from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load +from torchtune.utils.constants import ADAPTER_CONFIG, ADAPTER_KEY +from torchtune.utils.seed import set_seed + +_VOCAB_SIZE = 100 +_DIM = 64 +_HIDDEN_DIM = 256 +_NUM_HEADS = 4 +_NUM_KV_HEADS = 4 +_HEAD_DIM = 16 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestHFLlama2FullModelCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict_1(self, weight_dtype): + """ + State dict for a HF format checkpoint. This state dict is "complete" and + can be loaded into a torchtune model once correctly converted. + """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.rotary_emb.inv_freq": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": torch.randn(_DIM, dtype=weight_dtype), + "lm_head.weight": torch.randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + } + return state_dict + + @pytest.fixture + def state_dict_2(self, weight_dtype): + """ + State dict for a HF format checkpoint. This state dict is "incomplete" and + should be used along with ``state_dict_1`` to test multi-file checkpointing. Specifically + it's missing the embedding, norm and lm_head keys. + """ + state_dict = { + "model.layers.1.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.1.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.rotary_emb.inv_freq": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.1.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.1.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + } + return state_dict + + @pytest.fixture + def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2): + """ + Fixture which creates two checkpoint files for the Llama2 model. The + state dict follows the HF_FORMAT for the checkpoint format. + + The state dicts are structured in such a way that both single file and + multiple file checkpoints can be tested. + * The first checkpoint contains layer0 + embed + norm + lm_head keys + and can be tested in isolation + * The second checkpoint contains all layer1 keys and should be tested + in the multiple file checkpoint test along with the first checkpoint + + The model corresponds to the following config: + * vocab_size: 100 + * num_layers: 1 for single checkpoint and 2 for multiple checkpoint + * num_heads: 4 + * num_kv_heads: 4 + * embed_dim: 64 + * max_seq_len: 128 + """ + checkpoint_file_1 = tmp_path / "llama2_hf_checkpoint_01.pt" + checkpoint_file_2 = tmp_path / "llama2_hf_checkpoint_02.pt" + + torch.save(state_dict_1, checkpoint_file_1) + torch.save(state_dict_2, checkpoint_file_2) + + config = { + "hidden_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 4, + } + config_file = Path.joinpath(tmp_path, "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + return (checkpoint_file_1, checkpoint_file_2) + + @pytest.fixture + def single_file_checkpointer( + self, llama2_hf_checkpoints, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file, _ = llama2_hf_checkpoints + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file], + model_type="LLAMA2", + output_dir=tmp_path, + ) + + @pytest.fixture + def multi_file_checkpointer( + self, llama2_hf_checkpoints, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file_1, checkpoint_file_2], + model_type="LLAMA2", + output_dir=tmp_path, + ) + + def test_load_save_checkpoint_single_file( + self, + single_file_checkpointer: FullModelHFCheckpointer, + llama2_hf_checkpoints: Tuple[Path, Path], + ): + """ + Test ``load_checkpoint`` and ``save_checkpoint`` method within the + FullModelHFCheckpointer for a single checkpoint file. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the llama2 torchtune implementation + * Saved checkpoint keys match the original checkpoint + """ + # Read the state dict directly from file using torch.load. This will be the state + # dict we test against + checkpoint_file, _ = llama2_hf_checkpoints + orig_state_dict = safe_torch_load(checkpoint_file) + + # Converted state dict from the checkpointer + state_dict = single_file_checkpointer.load_checkpoint() + + # Check that we've loaded all the keys; We ignore inv_freq as is standard practice + assert len(state_dict["model"].keys()) + 1 == len(orig_state_dict.keys()) + + # the keys in original state dict should match up with the keys in the weight_map + for key in orig_state_dict.keys(): + if "inv_freq" in key: + continue + assert key in single_file_checkpointer._weight_map + + # loading the state dict into the model implementation should work correctly + model = llama2.llama2( + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. This is fine, breaking this logic + # should be something we capture through this test + output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") + output_state_dict = safe_torch_load(output_file) + + # We ignore inv_freq as is standard practice and so output dict will have one less key + assert len(output_state_dict.keys()) + 1 == len(orig_state_dict.keys()) + + def test_save_load_checkpoint_multiple_file( + self, + multi_file_checkpointer: FullModelHFCheckpointer, + llama2_hf_checkpoints: Tuple[Path, Path], + ): + """ + Test ``load_checkpoint`` method within the FullModelCheckpointer for multiple + checkpoint file. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the llama2 torchtune implementation + """ + # Read the state dict directly from files + checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints + orig_state_dict_1 = safe_torch_load(checkpoint_file_1) + orig_state_dict_2 = safe_torch_load(checkpoint_file_2) + + # merged state dict from checkpointer + state_dict = multi_file_checkpointer.load_checkpoint() + + # We ignore inv_freq as is standard practice + assert len(state_dict["model"].keys()) + 2 == len( + orig_state_dict_1.keys() + ) + len(orig_state_dict_2.keys()) + + # the keys in the weight_map should match up with the keys in the weight_map + for key in orig_state_dict_1.keys(): + if "inv_freq" in key: + continue + assert key in multi_file_checkpointer._weight_map + + for key in orig_state_dict_2.keys(): + if "inv_freq" in key: + continue + assert key in multi_file_checkpointer._weight_map + + # finally loading into the model should work + model = llama2.llama2( + vocab_size=_VOCAB_SIZE, + num_layers=2, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + multi_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. This is fine, breaking this logic + # should be something we capture through this test + output_file_1 = Path.joinpath(checkpoint_file_1.parent, "hf_model_0001_1.pt") + output_file_2 = Path.joinpath(checkpoint_file_2.parent, "hf_model_0002_1.pt") + output_state_dict_1 = safe_torch_load(output_file_1) + output_state_dict_2 = safe_torch_load(output_file_2) + + assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys()) + assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys()) + + def test_load_save_adapter_only( + self, tmp_path, single_file_checkpointer, llama2_hf_checkpoints + ): + """ """ + state_dict = single_file_checkpointer.load_checkpoint() + + with pytest.raises( + ValueError, match="Adapter checkpoint not found in state_dict" + ): + single_file_checkpointer.save_checkpoint( + state_dict, epoch=2, adapter_only=True + ) + + state_dict[ADAPTER_KEY] = {} + single_file_checkpointer.save_checkpoint(state_dict, epoch=2, adapter_only=True) + + output_file_1 = Path.joinpath(tmp_path, "hf_model_0001_2.pt") + output_file_2 = Path.joinpath(tmp_path, "adapter_2.pt") + + with pytest.raises(ValueError, match="Unable to load checkpoint from"): + output_state_dict_1 = safe_torch_load(output_file_1) + + output_state_dict_2 = safe_torch_load(output_file_2) + # Check that the empty adapter we saved is the one loaded succesfully + assert len(output_state_dict_2.keys()) == 0 + + def test_save_checkpoint_in_peft_format( + self, + single_file_checkpointer: FullModelHFCheckpointer, + llama2_hf_checkpoints: Tuple[Path, Path], + ): + """ + Test save_checkpoint method within the FullModelCheckpointer for + integration with HF PEFT (i.e. save_in_peft_format=True). + + We test that: + * The file adapter_config.json contains the fields required by PEFT + and the correct values + * The state dict keys of the saved adapter checkpoint are remapped as expected + * The state dict values of the saved adapter checkpoint (after key remapping) + match those in torchtune for parameters that are not permuted by HF + # The state dict values of the saved adapter checkpoint (after key remapping) + do not match those in torchtune for parameters that are permuted by HF, but the + sums along the dimension of permutation match + """ + + # Define LoRA params for this test + lora_attn_modules = ["q_proj", "output_proj"] + apply_lora_to_mlp = True + apply_lora_to_output = True + lora_rank = 4 + lora_alpha = 8 + + checkpoint_file, _ = llama2_hf_checkpoints + state_dict = single_file_checkpointer.load_checkpoint() + + # Build LoRA Llama2 model and load in base model weights + model = llama2.lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + max_seq_len=128, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + missing, unexpected = model.load_state_dict(state_dict["model"], strict=False) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + base_missing=missing, + base_unexpected=unexpected, + ) + + # LoRA B params are zero-initialized, randomly initialize them to make + # the test of their permutation on checkpoint save nontrivial + lora_b_sd = { + k: torch.randn_like(v) + for k, v in model.state_dict().items() + if "lora_b" in k + } + model.load_state_dict(lora_b_sd, strict=False) + + # Construct the adapter weights and config and save using checkpointer + adapter_params = get_adapter_params(model) + adapter_key_filter = lambda x: x in adapter_params + expected_adapter_state_dict = { + k: v for k, v in model.state_dict().items() if adapter_key_filter(k) + } + adapter_config = { + "r": lora_rank, + "lora_alpha": lora_alpha, + "target_modules": get_lora_module_names( + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + ), + "peft_type": "LORA", + } + state_dict.update({ADAPTER_KEY: expected_adapter_state_dict}) + state_dict.update({ADAPTER_CONFIG: adapter_config}) + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Load saved adapter weights and config from file for comparison + adapter_weights_file = Path.joinpath( + checkpoint_file.parent, "adapter_model.bin" + ) + actual_adapter_state_dict = safe_torch_load(adapter_weights_file) + + adapter_config_file = Path.joinpath( + checkpoint_file.parent, "adapter_config.json" + ) + with open(adapter_config_file, "r") as f: + adapter_config = json.load(f) + + expected_target_modules = [ + "down_proj", + "gate_proj", + "lm_head", + "o_proj", + "q_proj", + "up_proj", + ] + assert sorted(adapter_config["target_modules"]) == expected_target_modules + + # Map PEFT keys back to torchtune keys + peft_to_tt = { + "o_proj": "output_proj", + "gate_proj": "w1", + "down_proj": "w2", + "up_proj": "w3", + "lm_head": "output", + } + for k, v in actual_adapter_state_dict.items(): + new_k = k.replace("base_model.model.", "").replace("self_attn", "attn") + if "lm_head" not in new_k: + new_k = new_k.replace("model.", "") + for kk, vv in peft_to_tt.items(): + if kk in k: + new_k = new_k.replace(kk, vv) + new_k = new_k.replace("lora_A", "lora_a").replace("lora_B", "lora_b") + + # LoRA B matrix for Q should not match due to Q and K permutation + # However, since they're permuted along embed dim, their sum along that axis should match + if "lora_b" in new_k and "q_proj" in new_k: + assert not torch.allclose( + actual_adapter_state_dict[k], expected_adapter_state_dict[new_k] + ) + torch.testing.assert_close( + actual_adapter_state_dict[k].sum(dim=0), + expected_adapter_state_dict[new_k].sum(dim=0), + ) + + # All other matrices should match exactly + if "lora_b" not in new_k: + torch.testing.assert_close( + actual_adapter_state_dict[k], expected_adapter_state_dict[new_k] + ) + + +class TestHFMistralRewardModelFullModelCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict(self, weight_dtype): + """ + State dict for a HF format mistral reward model checkpoint. This state dict is + "complete" and can be loaded into a torchtune model once correctly converted. + """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": randn(_DIM, dtype=weight_dtype), + "score.weight": randn(1, _DIM, dtype=weight_dtype), + # adding bias to ensure it doesn't cause an unexpected key + "score.bias": randn(1, _DIM, dtype=weight_dtype), + } + return state_dict + + @pytest.fixture + def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict): + """ + Fixture which creates a checkpoint file for the Mistral reward model. The + state dict follows the HF_FORMAT for the checkpoint format. + + The state dicts supports testing for a single-file checkpoint. + Multiple file checkpoints are already tested for Llama2. + * The checkpoint contains layer0 + embed + norm + score keys + and can be tested in isolation + + The model corresponds to the following config: + * num_layers: 1 + * num_heads: 4 + * num_kv_heads: 4 + * embed_dim: 64 + * max_seq_len: 128 + * num_classes: 1 + * intermediate_dim: 256 + + """ + checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt" + + torch.save(state_dict, checkpoint_file) + + config = { + "hidden_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "num_classes": 1, + } + config_file = Path.joinpath(tmp_path, "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + return checkpoint_file + + @pytest.fixture + def single_file_checkpointer( + self, mistral_reward_model_hf_checkpoint, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file = mistral_reward_model_hf_checkpoint + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file], + model_type="REWARD", + output_dir=tmp_path, + ) + + def test_load_save_checkpoint_single_file( + self, + single_file_checkpointer: FullModelHFCheckpointer, + mistral_reward_model_hf_checkpoint: Path, + ): + """ + Test ``load_checkpoint`` and ``save_checkpoint`` method within the + FullModelHFCheckpointer for a single checkpoint file for a mistral reward model. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the `mistral_classifier` torchtune implementation + * Saved checkpoint keys match the original checkpoint + """ + # Read the state dict directly from file using torch.load. This will be the state + # dict we test against + checkpoint_file = mistral_reward_model_hf_checkpoint + orig_state_dict = safe_torch_load(checkpoint_file) + + # Converted state dict from the checkpointer + state_dict = single_file_checkpointer.load_checkpoint() + # Check that we've loaded all the keys minus the output bias + assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) - 1 + + # the keys in original state dict should match up with the keys in the weight_map + for key in orig_state_dict.keys(): + if "inv_freq" in key or "output.bias" in key: + continue + assert key in single_file_checkpointer._weight_map + + # loading the state dict into the model implementation should work correctly + model = mistral.mistral_classifier( + num_classes=1, + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + intermediate_dim=_HIDDEN_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. This is fine, breaking this logic + # should be something we capture through this test + output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") + output_state_dict = safe_torch_load(output_file) + + assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1 + + +class TestHFGemmaFullModelCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict(self, weight_dtype): + """ + State dict for a HF format Gemma checkpoint. This state dict is + "complete" and can be loaded into a TorchTune model once correctly converted. + """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _NUM_HEADS * _HEAD_DIM, dtype=weight_dtype + ), + # setting num_kv_heads to 1 + "model.layers.0.self_attn.k_proj.weight": randn( + _HEAD_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _HEAD_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _NUM_HEADS * _HEAD_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": randn(_DIM, dtype=weight_dtype), + } + return state_dict + + @pytest.fixture + def gemma_hf_checkpoint(self, tmp_path, state_dict): + """ + Fixture which creates a checkpoint file for Gemma. The + state dict follows the HF_FORMAT for the checkpoint format. + + The state dicts supports testing for a single-file checkpoint. + Multiple file checkpoints are already tested for Llama2. + + The model corresponds to the following config: + * num_layers: 1 + * num_heads: 4 + * num_kv_heads: 1 + * embed_dim: 64 + * max_seq_len: 128 + * num_classes: 1 + * intermediate_dim: 256 + * head_dim : 16 + + """ + checkpoint_file = tmp_path / "gemma_hf_checkpoint.pt" + + torch.save(state_dict, checkpoint_file) + + config = { + "hidden_size": _DIM, + "num_attention_heads": _NUM_HEADS, + "num_key_value_heads": 1, + "head_dim": _HEAD_DIM, + "intermediate_size": _HIDDEN_DIM, + } + config_file = Path.joinpath(tmp_path, "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + return checkpoint_file + + @pytest.fixture + def single_file_checkpointer( + self, gemma_hf_checkpoint, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file = gemma_hf_checkpoint + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file], + model_type="GEMMA", + output_dir=tmp_path, + ) + + def test_load_save_checkpoint_single_file( + self, + single_file_checkpointer: FullModelHFCheckpointer, + gemma_hf_checkpoint: Path, + ): + """ + Test ``load_checkpoint`` and ``save_checkpoint`` method within the + FullModelHFCheckpointer for a single checkpoint file for Gemma. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the `gemma` TorchTune implementation + * lm_head weights are tied to the embed_tokens weights during saving + * lmhead weights are popped during loading + """ + # Read the state dict directly from file using torch.load. This will be the state + # dict we test against + checkpoint_file = gemma_hf_checkpoint + orig_state_dict = safe_torch_load(checkpoint_file) + + # Converted state dict from the checkpointer + + state_dict = single_file_checkpointer.load_checkpoint() + assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) + + # the keys in original state dict should match up with the keys in the weight_map + for key in orig_state_dict.keys(): + if "inv_freq" in key: + continue + assert key in single_file_checkpointer._weight_map + + # loading the state dict into the model implementation should work correctly + model = gemma.gemma( + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + head_dim=_HEAD_DIM, + num_kv_heads=1, + embed_dim=_DIM, + intermediate_dim=_HIDDEN_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. This is fine, breaking this logic + # should be something we capture through this test + output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") + output_state_dict = safe_torch_load(output_file) + + assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) + + +class TestCheckpointerUtils: + @pytest.fixture + def model_checkpoint(self, tmp_path): + """ + Fixture which creates a checkpoint file for testing checkpointer utils. + """ + checkpoint_file = tmp_path / "model_checkpoint_01.pt" + + state_dict = { + "token_embeddings.weight": torch.ones(1, 10), + "output.weight": torch.ones(1, 10), + } + + torch.save(state_dict, checkpoint_file) + + return checkpoint_file + + @pytest.mark.parametrize("weights_only", [True, False]) + def test_safe_torch_load(self, model_checkpoint, weights_only): + state_dict = safe_torch_load(Path(model_checkpoint), weights_only) + + assert "token_embeddings.weight" in state_dict + assert "output.weight" in state_dict + + assert state_dict["token_embeddings.weight"].shape[1] == 10 + assert state_dict["output.weight"].shape[0] == 1 diff --git a/training/tests/torchtune/utils/test_collate.py b/training/tests/torchtune/utils/test_collate.py new file mode 100644 index 0000000..31403f2 --- /dev/null +++ b/training/tests/torchtune/utils/test_collate.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import torch + +from torchtune.utils.collate import padded_collate + + +class TestBatchPadSequence: + def test_padded_collate(self): + """ + Tests that shorter input, label sequences are padded to the max seq len. + """ + padding_idx = -8 + ignore_idx = -9 + token_pairs = [ + { + "tokens": [1, 2, 3], + "labels": [4, 5, 6], + }, + { + "tokens": [7], + "labels": [10], + }, + ] + padded = padded_collate( + batch=token_pairs, + padding_idx=padding_idx, + ignore_idx=ignore_idx, + ) + padded_input = padded["tokens"][1] + padded_label = padded["labels"][1] + torch.testing.assert_close( + padded_input, torch.tensor([7, padding_idx, padding_idx]) + ) + torch.testing.assert_close( + padded_label, torch.tensor([10, ignore_idx, ignore_idx]) + ) diff --git a/training/tests/torchtune/utils/test_device.py b/training/tests/torchtune/utils/test_device.py new file mode 100644 index 0000000..302b6d1 --- /dev/null +++ b/training/tests/torchtune/utils/test_device.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from unittest import mock +from unittest.mock import patch + +import pytest + +import torch +from torchtune.utils._device import ( + _get_device_type_from_env, + _setup_cuda_device, + get_device, +) + + +class TestDevice: + + cuda_available: bool = torch.cuda.is_available() + + @patch("torch.cuda.is_available", return_value=False) + def test_get_cpu_device(self, mock_cuda): + devices = [None, "cpu", "meta"] + expected_devices = [ + torch.device("cpu"), + torch.device("cpu"), + torch.device("meta"), + ] + for device, expected_device in zip(devices, expected_devices): + device = get_device(device) + assert device == expected_device + assert device.index is None + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + def test_get_gpu_device(self) -> None: + device_idx = torch.cuda.device_count() - 1 + assert device_idx >= 0 + with mock.patch.dict(os.environ, {"LOCAL_RANK": str(device_idx)}, clear=True): + device = get_device() + assert device.type == "cuda" + assert device.index == device_idx + assert device.index == torch.cuda.current_device() + + # Test that we raise an error if the device index is specified on distributed runs + if device_idx > 0: + with pytest.raises( + RuntimeError, + match=f"Device specified is cuda:0 but was assigned cuda:{device_idx}", + ): + device = get_device("cuda:0") + + invalid_device_idx = device_idx + 10 + with mock.patch.dict(os.environ, {"LOCAL_RANK": str(invalid_device_idx)}): + with pytest.raises( + RuntimeError, + match="The local rank is larger than the number of available GPUs", + ): + device = get_device("cuda") + + # Test that we fall back to 0 if LOCAL_RANK is not specified + device = torch.device(_get_device_type_from_env()) + device = _setup_cuda_device(device) + assert device.type == "cuda" + assert device.index == 0 + assert device.index == torch.cuda.current_device() diff --git a/training/tests/torchtune/utils/test_distributed.py b/training/tests/torchtune/utils/test_distributed.py new file mode 100644 index 0000000..38c82f8 --- /dev/null +++ b/training/tests/torchtune/utils/test_distributed.py @@ -0,0 +1,521 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import copy +from itertools import chain + +import pytest +import torch +import torch.nn as nn +from packaging import version +from tests.test_utils import gpu_test, single_box_init +from torch.distributed import launcher + +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from torch.testing._internal.common_fsdp import FSDPTest, MLP +from torchao.dtypes.nf4tensor import NF4Tensor +from torchtune import modules, utils +from torchtune.models.llama2._component_builders import llama2, lora_llama2 +from torchtune.models.llama3._component_builders import llama3 +from torchtune.modules import TransformerDecoderLayer +from torchtune.modules.peft import LoRALinear +from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params + + +class TestDistributed: + def test_init_distributed(self) -> None: + """Integration test to confirm consistency across device initialization utilities.""" + distributed = utils.init_distributed() + assert ( + not distributed + ), "Should return False as there are no distributed environment variables" + + @staticmethod + def _test_worker_fn(init_pg_explicit: bool) -> None: + """ + Integration test to confirm distributed initialization and consistency with process group backend utilities. + """ + if init_pg_explicit: + torch.distributed.init_process_group(backend="gloo") + if not torch.distributed.is_initialized(): + utils.init_distributed(backend="gloo") + if not torch.distributed.is_initialized(): + raise AssertionError("Expected torch.distributed to be initialized") + pg_backend = torch.distributed.get_backend() + assert ( + pg_backend == "gloo" + ), f"Expected 'gloo' backend, but received {pg_backend}" + + @staticmethod + def _test_world_size_with_cpu_device(expected_world_size: int) -> None: + utils.init_distributed(backend="gloo") + world_size, _ = utils.get_world_size_and_rank() + if world_size != expected_world_size: + raise AssertionError( + f"Expected different world size: received {world_size}, expected {expected_world_size}" + ) + + def _test_launch_worker( + self, + get_pet_launch_config, + num_processes: int, + init_pg_explicit: bool, + ) -> None: + lc = get_pet_launch_config(num_processes) + launcher.elastic_launch(lc, entrypoint=self._test_worker_fn)(init_pg_explicit) + + def test_init_from_env_no_dup(self, get_pet_launch_config) -> None: + self._test_launch_worker(get_pet_launch_config, 2, init_pg_explicit=False) + # trivial test case to ensure test passes with no exceptions + assert True + + def test_init_from_env_dup(self, get_pet_launch_config) -> None: + self._test_launch_worker(get_pet_launch_config, 2, init_pg_explicit=True) + # trivial test case to ensure test passes with no exceptions + assert True + + def test_world_size_with_cpu(self, get_pet_launch_config) -> None: + desired_world_size = 4 + lc = get_pet_launch_config(desired_world_size) + launcher.elastic_launch(lc, entrypoint=self._test_world_size_with_cpu_device)( + desired_world_size + ) + + def test_validate_no_params_on_meta_device(self) -> None: + with torch.device("meta"): + model = torch.nn.Linear(3, 3) + + with pytest.raises(RuntimeError, match="Unexpected param or buffer"): + utils.validate_no_params_on_meta_device(model) + + # Test model with only buffer + model = torch.nn.Linear(3, 3) + buffer = torch.ones(1, device="meta") + model.register_buffer("buffer", buffer) + + with pytest.raises(RuntimeError, match="Unexpected param or buffer"): + utils.validate_no_params_on_meta_device(model) + + def test_get_fsdp_wrap_policies(self) -> None: + with single_box_init(): + llama3_policy = utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=True, + modules_to_wrap={modules.TransformerDecoderLayer}, + ) + l3 = llama3( + vocab_size=64, + num_layers=1, + num_heads=4, + num_kv_heads=4, + embed_dim=64, + max_seq_len=128, + ) + wrapped_l3 = FSDP( + l3, auto_wrap_policy=llama3_policy, device_id=torch.device("cpu") + ) + # Ensure embedding, output proj, and transformer decoder blocks are wrapped + assert isinstance(wrapped_l3.tok_embeddings, FSDP) + assert isinstance(wrapped_l3.output, FSDP) + for layer in wrapped_l3.layers: + assert isinstance(layer, FSDP) + + llama2_policy = utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=False, + modules_to_wrap={modules.TransformerDecoderLayer}, + ) + l2 = llama2( + vocab_size=64, + num_layers=1, + num_heads=4, + num_kv_heads=4, + embed_dim=64, + max_seq_len=128, + ) + wrapped_l2 = FSDP( + l2, auto_wrap_policy=llama2_policy, device_id=torch.device("cpu") + ) + # Ensure embedding, output proj, and transformer decoder blocks are not wrapped + assert not isinstance(wrapped_l2.tok_embeddings, FSDP) + assert not isinstance(wrapped_l2.output, FSDP) + # Ensure transformer decoder blocks are wrapped + for layer in wrapped_l2.layers: + assert isinstance(layer, FSDP) + + +N_LAYERS = 3 +IN_DIM = 5 +OUT_DIM = 10 +VOCAB_SIZE = 50 +NUM_HEADS = 4 +NUM_KV_HEADS = 2 +EMBED_DIM = 64 +MAX_SEQ_LEN = 64 + + +def _get_n_lora_and_tformer_layers(model): + num_lora_ab = 0 + num_transformer_layers = 0 + for module in model.modules(): + if isinstance(module, LoRALinear): + num_nested_linears = len( + [m for m in module.modules() if isinstance(m, nn.Linear)] + ) + num_lora_ab += num_nested_linears + if isinstance(module, TransformerDecoderLayer): + num_transformer_layers += 1 + + return num_lora_ab, num_transformer_layers + + +# TODO: figure out a permanent home for FSDP + LoRA code +class TestLoRAFSDP: + def test_lora_fsdp_wrap(self): + with torch.device("meta"): + model = lora_llama2( + lora_attn_modules=["q_proj", "v_proj"], + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + lora_rank=4, + lora_alpha=1.0, + ) + + adapter_params = get_adapter_params(model) + set_trainable_params(model, adapter_params) + num_lora_ab, num_transformer_layers = _get_n_lora_and_tformer_layers(model) + with single_box_init(): + lora_wrap_policy = utils.lora_fsdp_wrap_policy( + modules_to_wrap={TransformerDecoderLayer} + ) + utils.prepare_model_for_fsdp_with_meta_device(model) + wrapped_lora = FSDP( + model, + auto_wrap_policy=lora_wrap_policy, + device_id=torch.device("cpu"), + ) + + # After FSDP wrap, nothing should be left on meta device, and LoRA params + # should be initialized. + for p in chain(wrapped_lora.parameters(), wrapped_lora.buffers()): + assert not p.is_meta + + for m in wrapped_lora.modules(): + if isinstance(m, LoRALinear): + torch.testing.assert_close( + m.lora_b.weight, torch.zeros_like(m.lora_b.weight) + ) + + # Total # FSDP modules should be num_transformer + num_lora_ab + 1 + total_fsdp_submodules = len([m for m in FSDP.fsdp_modules(wrapped_lora)]) + assert total_fsdp_submodules == (num_lora_ab + num_transformer_layers + 1) + # LoRA a & b linears should be individually wrapped. + # And TransformerDecoderLayers should be individually wrapped. + for fsdp_submodule in FSDP.fsdp_modules(wrapped_lora): + if isinstance(fsdp_submodule.module, nn.Linear): + num_lora_ab -= 1 + elif isinstance(fsdp_submodule.module, TransformerDecoderLayer): + num_transformer_layers -= 1 + assert num_lora_ab == 0 + assert num_transformer_layers == 0 + + def test_lora_meta_device_init_fsdp(self): + with torch.device("meta"): + lora = lora_llama2( + lora_attn_modules=["q_proj", "v_proj"], + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + lora_rank=4, + lora_alpha=8, + ) + utils.prepare_model_for_fsdp_with_meta_device(lora) + for m in lora.modules(): + m.to_empty(device=torch.device("cpu"), recurse=False) + m.reset_parameters() + # No params should be left on meta device + for n, p in lora.named_parameters(): + assert not p.is_meta, f"parameter {n} is still on meta device!" + # Neither should buffers + for n, b in lora.named_buffers(): + assert not b.is_meta, f"buffer {n} is still on meta device!" + + +class TestFullyShardState(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", + reason="torch >= 2.4 required", + ) + def test_lora_state_dict(self): + rank = self.rank + is_rank_zero = rank == 0 + mlp_dim = 4 + epochs = 5 + torch.manual_seed(42) + # base_model is simple DDP + with torch.device("cuda"): + base_model = nn.Sequential( + MLP(mlp_dim), + nn.Sequential(MLP(mlp_dim), nn.Linear(mlp_dim, mlp_dim)), + MLP(mlp_dim), + ) + base_optim = torch.optim.Adam( + base_model.parameters(), weight_decay=0.01, lr=0.01 + ) + + fsdp_model_to_save = copy.deepcopy(base_model) + for module in fsdp_model_to_save: + fully_shard(module) + fully_shard(fsdp_model_to_save) + fsdp_optim_to_save = torch.optim.Adam( + fsdp_model_to_save.parameters(), weight_decay=0.01, lr=0.01 + ) + + # inp is different for each rank + torch.manual_seed(42 + rank) + + # test get full state dict + for _ in range(epochs): + inp = torch.randn((2, mlp_dim), device="cuda") + base_model(inp).sum().backward() + for param in base_model.parameters(): + torch.distributed.all_reduce( + param.grad, op=torch.distributed.ReduceOp.AVG + ) + base_optim.step() + base_optim.zero_grad() + fsdp_model_to_save(inp).sum().backward() + fsdp_optim_to_save.step() + fsdp_optim_to_save.zero_grad() + expected_model_sd = base_model.state_dict() + expected_optim_sd = base_optim.state_dict() + model_full_sd = utils.get_full_model_state_dict( + fsdp_model_to_save, is_rank_zero + ) + optim_full_sd = utils.get_full_optimizer_state_dict( + fsdp_optim_to_save, + is_rank_zero, + ) + if is_rank_zero: + self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys())) + for key, value in model_full_sd.items(): + self.assertEqual(value, expected_model_sd[key]) + self.assertEqual(len(optim_full_sd["param_groups"]), 1) + self.assertEqual( + len(optim_full_sd["param_groups"]), + len(expected_optim_sd["param_groups"]), + ) + self.assertEqual( + len(optim_full_sd["param_groups"][0].keys()), + len(expected_optim_sd["param_groups"][0].keys()), + ) + for key, value in optim_full_sd["param_groups"][0].items(): + if key == "params": + self.assertEqual( + len(value), len(expected_optim_sd["param_groups"][0][key]) + ) + else: + self.assertEqual(value, expected_optim_sd["param_groups"][0][key]) + self.assertEqual( + len(optim_full_sd["state"].keys()), + len(expected_optim_sd["state"].keys()), + ) + for actual, expected in zip( + optim_full_sd["state"].values(), expected_optim_sd["state"].values() + ): + self.assertEqual(actual, expected) + else: + self.assertEqual(len(model_full_sd), 0) + self.assertEqual(len(optim_full_sd), 0) + + # test set full state dict + with torch.device("meta"): + fsdp_model_to_load = nn.Sequential( + MLP(mlp_dim), + nn.Sequential(MLP(mlp_dim), nn.Linear(mlp_dim, mlp_dim)), + MLP(mlp_dim), + ) + for module in fsdp_model_to_load: + fully_shard(module) + fully_shard(fsdp_model_to_load) + utils.load_from_full_model_state_dict( + fsdp_model_to_load, + copy.deepcopy(base_model.state_dict()), + torch.device("cuda"), + is_rank_zero, + ) + fsdp_optim_to_load = torch.optim.Adam( + fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01 + ) + utils.load_from_full_optimizer_state_dict( + fsdp_optim_to_load, + # mimic mmap=True where every rank see full SD + copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)), + torch.device("cuda"), + ) + for _ in range(epochs): + inp = torch.randn((2, mlp_dim), device="cuda") + fsdp_model_to_load(inp).sum().backward() + fsdp_model_to_save(inp).sum().backward() + fsdp_optim_to_load.step() + fsdp_optim_to_save.step() + fsdp_optim_to_load.zero_grad() + fsdp_optim_to_save.zero_grad() + sharded_optim_sd = fsdp_optim_to_load.state_dict() + expected_sharded_optim_sd = fsdp_optim_to_save.state_dict() + self.assertEqual( + sharded_optim_sd["param_groups"], + expected_sharded_optim_sd["param_groups"], + ) + self.assertEqual( + set(sharded_optim_sd["state"].keys()), + set(expected_sharded_optim_sd["state"].keys()), + ) + for key, value in sharded_optim_sd["state"].items(): + self.assertEqual(value, expected_sharded_optim_sd["state"][key]) + + sharded_model_sd = fsdp_model_to_load.state_dict() + expected_sharded_model_sd = fsdp_model_to_save.state_dict() + self.assertEqual( + set(sharded_model_sd.keys()), set(expected_sharded_model_sd.keys()) + ) + for key, value in sharded_model_sd.items(): + self.assertEqual(value, expected_sharded_model_sd[key]) + + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", + reason="torch >= 2.4 required", + ) + @gpu_test(gpu_count=2) + def test_qlora_state_dict(self): + self.run_subtests( + { + "enable_activation_checkpointing": [False, True], + }, + self._test_qlora_state_dict, + ) + + def _test_qlora_state_dict(self, enable_activation_checkpointing: bool): + is_rank_zero = self.rank == 0 + torch.manual_seed(42) + kwargs = { + "lora_attn_modules": ["q_proj", "v_proj", "k_proj", "output_proj"], + "apply_lora_to_mlp": True, + "apply_lora_to_output": False, + "vocab_size": 1024, + "num_layers": 3, + "num_heads": 4, + "num_kv_heads": 2, + "embed_dim": 1024, + "max_seq_len": 64, + "lora_rank": 4, + "lora_alpha": 1.0, + "quantize_base": True, + } + # single-device model as groundtruth + with torch.device("cuda"): + base_model = lora_llama2(**kwargs) + set_trainable_params(base_model, get_adapter_params(base_model)) + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + base_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + # fsdp model for saving state dict + fsdp_model_to_save = copy.deepcopy(base_model) + for m in fsdp_model_to_save.modules(): + if enable_activation_checkpointing: + if isinstance(m, CheckpointWrapper): + fully_shard(m) + else: + if isinstance(m, modules.TransformerDecoderLayer): + fully_shard(m) + fully_shard(fsdp_model_to_save) + + # one forward pass for lazy init + torch.manual_seed(42 + self.rank) + inp = torch.randint( + low=0, + high=kwargs["vocab_size"], + size=(2, kwargs["max_seq_len"]), + device="cuda", + ) + base_model(inp) + fsdp_model_to_save(inp) + + expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()} + model_full_sd = utils.get_full_model_state_dict( + fsdp_model_to_save, is_rank_zero + ) + if is_rank_zero: + self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys())) + for key, value in model_full_sd.items(): + self.assertEqual(value, expected_model_sd[key]) + + # fsdp model for loading tate dict + torch.manual_seed(42) + with torch.device("meta"): + fsdp_model_to_load = lora_llama2(**kwargs) + set_trainable_params(fsdp_model_to_load, get_adapter_params(fsdp_model_to_load)) + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + fsdp_model_to_load, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + # init rope since it's not covered in state dict + for m in fsdp_model_to_load.modules(): + if isinstance(m, modules.RotaryPositionalEmbeddings): + m.reset_parameters() + for m in fsdp_model_to_load.modules(): + if enable_activation_checkpointing: + if isinstance(m, CheckpointWrapper): + fully_shard(m) + else: + if isinstance(m, modules.TransformerDecoderLayer): + fully_shard(m) + fully_shard(fsdp_model_to_load) + utils.load_from_full_model_state_dict( + fsdp_model_to_load, expected_model_sd, torch.device("cuda"), is_rank_zero + ) + fsdp_model_to_load(inp) + sharded_model_sd = fsdp_model_to_load.state_dict() + expected_sharded_model_sd = fsdp_model_to_save.state_dict() + self.assertEqual( + set(sharded_model_sd.keys()), set(expected_sharded_model_sd.keys()) + ) + for key, value in sharded_model_sd.items(): + if isinstance(value._local_tensor, NF4Tensor): + self.assertEqual( + value._local_tensor.get_original_weight(), + expected_sharded_model_sd[key]._local_tensor.get_original_weight(), + ) + else: + self.assertEqual(value, expected_sharded_model_sd[key]) + + def _broadcast_full_state_dict(self, full_sd): + result = [] + if torch.distributed.get_rank() == 0: + result.append(full_sd) + else: + result.append(None) + torch.distributed.broadcast_object_list(result, src=0) + return result[0] diff --git a/training/tests/torchtune/utils/test_generation.py b/training/tests/torchtune/utils/test_generation.py new file mode 100644 index 0000000..15c4a33 --- /dev/null +++ b/training/tests/torchtune/utils/test_generation.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from tests.test_utils import fixed_init_model + +from torchtune import utils +from torchtune.models.llama2 import llama2 +from torchtune.utils._generation import sample + + +class TestTextGenerate: + """ + Test class for text generation functionality. + """ + + @pytest.fixture + def generation_model(self, dtype=torch.float32): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.setup_caches(batch_size=1, dtype=dtype) + model.eval() + return model + + @pytest.fixture + def generation_model_no_kv_cache(self, dtype=torch.float32): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.eval() + return model + + @pytest.fixture + def generation_model_batched(self, dtype=torch.float32): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.setup_caches(batch_size=2, dtype=dtype) + model.eval() + return model + + @pytest.fixture + def prompt_tokens(self): + """ + Pytest fixture to create a list of prompt tokens for testing. + """ + return torch.arange(2, 10) + + @pytest.fixture + def prompt_tokens_batched(self): + """ + Pytest fixture to create a list of prompt tokens for testing. + """ + return torch.arange(2, 10).repeat(2, 1) + + def test_sample_consistency(self): + """ + Test token sampling produces the right output. + """ + # set all probabilities except for token_id=100 to 0 + logits = torch.zeros(2000) + logits[100] = 1 + + token = sample(logits, temperature=1, top_k=1) + assert token.item() == 100 + + @pytest.mark.parametrize( + "model1,model2,prompt", + [ + ("generation_model", "generation_model", "prompt_tokens"), + ("generation_model", "generation_model_no_kv_cache", "prompt_tokens"), + ( + "generation_model_batched", + "generation_model_batched", + "prompt_tokens_batched", + ), + ( + "generation_model_batched", + "generation_model_no_kv_cache", + "prompt_tokens_batched", + ), + ], + ) + def test_reproducibility(self, request, model1, model2, prompt): + """ + Test to check if the `generate` function produces the same output when run with the same + inputs and a fixed seed. This should work regardless of batched input or kv cache. + """ + + model1 = request.getfixturevalue(model1) + model2 = request.getfixturevalue(model2) + prompt = request.getfixturevalue(prompt) + + temperature = 0.6 + top_k = 100 + + torch.manual_seed(42) + outputs_first = utils.generate( + model=model1, + prompt=prompt, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.manual_seed(42) + outputs_second = utils.generate( + model=model2, + prompt=prompt, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + assert outputs_first == outputs_second + + def test_batched_generate(self, generation_model_batched, prompt_tokens_batched): + """Test batched generation works as expected.""" + temperature = 0.6 + top_k = 100 + + torch.manual_seed(42) + + output = utils.generate( + model=generation_model_batched, + prompt=prompt_tokens_batched, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + # The numbers here are the first 10 tokens generated by the model + # with constantly initialized weights, a tensor input with range 2 through 10, + # and the manual seed set to 42. They do not correspond to "recognizable" tokens. + expected_output = [ + [ + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 3987, + 3991, + 3953, + 3957, + 3983, + 3964, + 3928, + 3932, + 3986, + 3982, + ], + [ + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 3958, + 3979, + 3934, + 3945, + 3993, + 3904, + 3950, + 3988, + 3948, + 3999, + ], + ] + + assert output == expected_output + + def test_stop_tokens(self, generation_model, prompt_tokens): + """ + Test to check if the `generate` function produces the right output when stop tokens are + provided. + """ + temperature = 0.6 + top_k = 100 + + # This is the first token generated by the model + # so it should stop immediately + stop_tokens = [3987] + + torch.manual_seed(42) + + outputs = utils.generate( + model=generation_model, + prompt=prompt_tokens, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + stop_tokens=stop_tokens, + ) + + expected_output = [[2, 3, 4, 5, 6, 7, 8, 9, 3987]] + + assert outputs == expected_output + + def test_stop_tokens_batched(self, generation_model_batched, prompt_tokens_batched): + """ + Test to check if the `generate` function produces the right output when stop tokens are + provided, but this time in batched format. + """ + temperature = 0.6 + top_k = 100 + + # This is the first token generated by the model + # so it should stop immediately + stop_tokens = [3987, 3958] + + torch.manual_seed(42) + + outputs = utils.generate( + model=generation_model_batched, + prompt=prompt_tokens_batched, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + stop_tokens=stop_tokens, + ) + + expected_output = [ + [2, 3, 4, 5, 6, 7, 8, 9, 3987], + [2, 3, 4, 5, 6, 7, 8, 9, 3958], + ] + + assert outputs == expected_output + + def test_stop_tokens_batched_uneven_stopping( + self, generation_model_batched, prompt_tokens_batched + ): + """ + Test to check if the `generate` function produces the right output when stop tokens are + provided, but this time in batched format. This time, seq 0 should hit a stop token before seq 1. + We expect the output to be the length of seq 1, but the first seq should be truncated. + """ + temperature = 0.6 + top_k = 100 + + # This is the first token generated by the model + # so it should stop immediately + stop_tokens = [3987, 3979] + + torch.manual_seed(42) + + outputs = utils.generate( + model=generation_model_batched, + prompt=prompt_tokens_batched, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + stop_tokens=stop_tokens, + ) + + expected_output = [ + [2, 3, 4, 5, 6, 7, 8, 9, 3987, 0], + [2, 3, 4, 5, 6, 7, 8, 9, 3958, 3979], + ] + + assert outputs == expected_output diff --git a/training/tests/torchtune/utils/test_memory.py b/training/tests/torchtune/utils/test_memory.py new file mode 100644 index 0000000..ba2fbf2 --- /dev/null +++ b/training/tests/torchtune/utils/test_memory.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, +) +from torchtune.utils import set_activation_checkpointing + + +class TestSetActivationCheckpointing: + @pytest.fixture + def model(self) -> int: + return nn.Sequential( + nn.Linear(10, 10), + nn.Linear(10, 10), + nn.Linear(10, 10), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def _verify(self, model): + for submodule in model.modules(): + if isinstance(submodule, CheckpointWrapper): + assert isinstance(submodule._checkpoint_wrapped_module, nn.Linear) + + def test_activation_checkpoint_set_policy(self, model): + set_activation_checkpointing(model=model, auto_wrap_policy={nn.Linear}) + self._verify(model) + + def test_activation_checkpoint_custom_policy(self, model): + def custom_policy(module: nn.Module, recurse: bool, **kwargs) -> bool: + if recurse: + return True + return isinstance(module, nn.Linear) + + set_activation_checkpointing(model=model, auto_wrap_policy=custom_policy) + self._verify(model) diff --git a/training/tests/torchtune/utils/test_metric_logging.py b/training/tests/torchtune/utils/test_metric_logging.py new file mode 100644 index 0000000..4269c9a --- /dev/null +++ b/training/tests/torchtune/utils/test_metric_logging.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import tempfile +from io import StringIO +from typing import cast +from unittest.mock import patch + +import pytest +from omegaconf import OmegaConf +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from tests.test_utils import assert_expected, captured_output + +from torchtune.utils.metric_logging import ( + DiskLogger, + StdoutLogger, + TensorBoardLogger, + WandBLogger, +) + + +class TestDiskLogger: + def test_log(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = DiskLogger(log_dir=log_dir) + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + log_path = logger.path_to_log_file() + assert log_path.exists() + values = open(log_path).readlines() + assert_expected(len(values), 5) + for i in range(5): + assert values[i] == f"Step {i} | test_log:{float(i) ** 2}\n" + + def test_log_dict(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = DiskLogger(log_dir=log_dir) + for i in range(5): + logger.log_dict(step=i, payload={"metric_1": i, "metric_2": i**2}) + logger.close() + + log_path = logger.path_to_log_file() + assert log_path.exists() + values = open(log_path).readlines() + assert_expected(len(values), 5) + for i in range(5): + assert values[i] == f"Step {i} | metric_1:{i} metric_2:{i ** 2} \n" + + +class TestStdoutLogger: + def test_stdout_log(self) -> None: + logger = StdoutLogger() + with captured_output() as (out, _): + logger.log(step=0, name="metric_1", data=1.1) + out = cast(StringIO, out) + assert ( + out.getvalue() == "Step 0 | metric_1:1.1\n" + ), f"Actual output: {out.getvalue()}" + + logger.log(step=1, name="metric_1", data=2.1) + assert ( + out.getvalue() == "Step 0 | metric_1:1.1\nStep 1 | metric_1:2.1\n" + ), f"Actual output: {out.getvalue()}" + + logger.close() + assert ( + out.getvalue() == "Step 0 | metric_1:1.1\nStep 1 | metric_1:2.1\n" + ), f"Actual output: {out.getvalue()}" + + def test_stdout_log_dict(self) -> None: + logger = StdoutLogger() + with captured_output() as (out, _): + logger.log_dict(step=0, payload={"metric_1": 1, "metric_2": 1}) + out = cast(StringIO, out) + assert ( + out.getvalue() == "Step 0 | metric_1:1 metric_2:1 \n" + ), f"Actual output: {out.getvalue()}" + + logger.log_dict( + step=1, payload={"metric_1": 2, "metric_2": 2.2, "metric_3": 2.2344} + ) + assert ( + out.getvalue() + == "Step 0 | metric_1:1 metric_2:1 \nStep 1 | metric_1:2 metric_2:2.2 metric_3:2.2344 \n" + ), f"Actual output: {out.getvalue()}" + + logger.close() + assert ( + out.getvalue() + == "Step 0 | metric_1:1 metric_2:1 \nStep 1 | metric_1:2 metric_2:2.2 metric_3:2.2344 \n" + ), f"Actual output: {out.getvalue()}" + + +class TestTensorBoardLogger: + def test_log(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = TensorBoardLogger(log_dir=log_dir) + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + acc = EventAccumulator(logger.log_dir) + acc.Reload() + for i, event in enumerate(acc.Tensors("test_log")): + assert_expected(event.tensor_proto.float_val[0], float(i) ** 2) + assert_expected(event.step, i) + + def test_log_dict(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = TensorBoardLogger(log_dir=log_dir) + metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)} + logger.log_dict(metric_dict, 1) + logger.close() + + acc = EventAccumulator(logger.log_dir) + acc.Reload() + for i in range(5): + tensor_tag = acc.Tensors(f"log_dict_{i}")[0] + assert_expected(tensor_tag.tensor_proto.float_val[0], float(i) ** 2) + assert_expected(tensor_tag.step, 1) + + +@pytest.mark.skip(reason="This was never running and needs to be fixed") +class TestWandBLogger: + def test_log(self) -> None: + with patch("wandb.init") as mock_init, patch("wandb.log") as mock_log: + logger = WandBLogger(project="test_project") + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + assert mock_log.call_count == 5 + for i in range(5): + mock_log.assert_any_call({"test_log": float(i) ** 2}, step=i) + + def test_log_dict(self) -> None: + with patch("wandb.init") as mock_init, patch("wandb.log") as mock_log: + logger = WandBLogger(project="test_project") + metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)} + logger.log_dict(metric_dict, 1) + logger.close() + + mock_log.assert_called_with(metric_dict, step=1) + + def test_save_config(self) -> None: + with patch("wandb.init") as mock_init, patch( + "wandb.run", create=True + ) as mock_run, patch("OmegaConf.save") as mock_save, patch( + "wandb.save" + ) as mock_wandb_save: + + logger = WandBLogger(project="test_project") + cfg = OmegaConf.create({"a": 1, "b": 2}) + + with patch.object(logger, "_wandb", mock_run): + logger.save_config(cfg) + + expected_config_path = "torchtune_config.yaml" + mock_save.assert_called_once_with(cfg, expected_config_path) + mock_wandb_save.assert_called_once_with(expected_config_path) diff --git a/training/tests/torchtune/utils/test_optim_utils.py b/training/tests/torchtune/utils/test_optim_utils.py new file mode 100644 index 0000000..0d491ac --- /dev/null +++ b/training/tests/torchtune/utils/test_optim_utils.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torchtune.utils import create_optim_in_bwd_wrapper, register_optim_in_bwd_hooks + + +def _run_dummy_step(model, wrapper): + with torch.no_grad(): + for p in model.parameters(): + p.grad = torch.rand_like(p) + for v in wrapper.optim_map.values(): + v.step() + v.zero_grad() + + +def _validate_dicts(d1, d2): + if len(d1) != len(d2): + return False + for k, v in d1.items(): + if k not in d2: + return False + if isinstance(v, dict): + return _validate_dicts(v, d2[k]) + else: + if isinstance(v, torch.Tensor): + if not torch.allclose(v, d2[k]): + return False + elif v != d2[k]: + return False + return True + + +@pytest.fixture +def model(): + return torch.nn.Linear(10, 1) + + +@pytest.fixture +def optim_dict(model): + return {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()} + + +@pytest.fixture +def wrapper(model, optim_dict): + return create_optim_in_bwd_wrapper(model, optim_dict) + + +class TestOptimInBackward: + def test_state_dict_save_load(self, model, wrapper): + # Run a dummy step to create optimizer states + _run_dummy_step(model, wrapper) + + sd = wrapper.state_dict() + new_optim_dict = create_optim_in_bwd_wrapper( + model, {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()} + ) + assert not _validate_dicts(sd, new_optim_dict.state_dict()) + new_optim_dict.load_state_dict(sd) + assert _validate_dicts(sd, new_optim_dict.state_dict()) + + def test_missing_unexpected_param_load_raises(self, model, wrapper): + # Run a dummy step to create optimizer states + _run_dummy_step(model, wrapper) + sd = wrapper.state_dict() + new_optim_dict = create_optim_in_bwd_wrapper( + model, {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()} + ) + with pytest.raises(RuntimeError, match="Expected to load optimizer state"): + sd.pop(next(iter(sd.keys()))) + new_optim_dict.load_state_dict(sd) + + sd = wrapper.state_dict() + sd["new_key"] = 1234 + with pytest.raises(RuntimeError, match="unexpected param"): + new_optim_dict.load_state_dict(sd) + + +class TestRegisterOptimHooks: + def test_register_optim_in_bwd_hooks(self, model, optim_dict): + register_optim_in_bwd_hooks(model, optim_dict) + # Ensure backward() updates the parameters and sets grads to None + orig_params = [p.clone().detach() for p in model.parameters()] + model(torch.rand(2, 10)).sum().backward() + for p, orig_p in zip(model.parameters(), orig_params): + assert not p.grad + assert not torch.allclose(p, orig_p) diff --git a/training/tests/torchtune/utils/test_pooling.py b/training/tests/torchtune/utils/test_pooling.py new file mode 100644 index 0000000..223bcca --- /dev/null +++ b/training/tests/torchtune/utils/test_pooling.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from torchtune.utils.pooling import get_unmasked_sequence_lengths + + +class TestGetLastUnmaskedTokenIdx: + def test_get_last_unmasked_token_idx_multi_batch(self): + """ + Tests that the last non-padding tokens are correctly selected for a multi-batch input. + """ + padding_token_idx = 0 + tokens = torch.tensor([[1, 3, 4, 9], [4, 5, 6, 0], [1, 0, 0, 0], [0, 0, 0, 0]]) + expected_output = torch.tensor([3, 2, 0, 0]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + torch.testing.assert_close(idxs, expected_output) + + def test_get_last_unmasked_token_idx_single_batch(self): + """ + Tests that the last non-padding tokens are correctly selected for a single-batch input. + """ + padding_token_idx = 0 + tokens = torch.tensor([[1, 3, 4, 9, 0]]) + expected_output = torch.tensor([3]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + + torch.testing.assert_close(idxs, expected_output) + + def test_get_last_unmasked_token_idx_multi_batch_all_full(self): + """ + Tests that the last non-padding tokens are correctly selected for multi-batch input, + where none of the sequences have padding tokens. + """ + padding_token_idx = 0 + tokens = torch.tensor( + [[1, 3, 4, 9], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] + ) + expected_output = torch.tensor([3, 3, 3, 3]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + + torch.testing.assert_close(idxs, expected_output) + + def test_get_last_unmasked_token_idx_multi_batch_all_empty(self): + """ + Tests that the last non-padding tokens are correctly selected for multi-batch input, + where none of the sequences have any non-padding tokens. + """ + padding_token_idx = 0 + tokens = torch.zeros((4, 4), dtype=torch.long) + expected_output = torch.tensor([0, 0, 0, 0]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + + torch.testing.assert_close(idxs, expected_output) diff --git a/training/tests/torchtune/utils/test_precision.py b/training/tests/torchtune/utils/test_precision.py new file mode 100644 index 0000000..0f5e629 --- /dev/null +++ b/training/tests/torchtune/utils/test_precision.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +from unittest import mock + +import pytest +import torch + +from torchtune.utils.precision import ( + _set_float32_precision, + get_dtype, + PRECISION_STR_TO_DTYPE, + set_default_dtype, + validate_expected_param_dtype, + verify_bf16_support, +) + + +class TestPrecisionUtils: + + cuda_available: bool = torch.cuda.is_available() + + def test_get_dtype(self): + """ + Tests that the correct dtype is returned based on the input string. + """ + dtypes = [None, torch.half] + list(PRECISION_STR_TO_DTYPE.keys()) + expected_dtypes = [ + torch.float32, + torch.float16, + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ] + for dtype, expected_dtype in zip(dtypes, expected_dtypes): + if dtype == "bf16" and not verify_bf16_support(): + continue # skip bf16 tests if not supported. + assert ( + get_dtype(dtype) == expected_dtype + ), f"{dtype} should return {expected_dtype}" + + @mock.patch("torchtune.utils.precision.verify_bf16_support", return_value=False) + def test_error_bf16_unsupported(self, mock_verify): + """ + Tests that an error is raised if bf16 is specified but not supported. + """ + with pytest.raises( + RuntimeError, match="bf16 precision was requested but not available" + ): + get_dtype(torch.bfloat16) + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + def test_set_float32_precision(self) -> None: + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", True + ) + _set_float32_precision("highest") + assert torch.get_float32_matmul_precision() == "highest" + assert not torch.backends.cudnn.allow_tf32 + assert not torch.backends.cuda.matmul.allow_tf32 + + _set_float32_precision("high") + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", False + ) + assert torch.get_float32_matmul_precision() == "high" + assert torch.backends.cudnn.allow_tf32 + assert torch.backends.cuda.matmul.allow_tf32 + + def test_set_default_dtype(self): + dtype = torch.bfloat16 + prev_dtype = torch.get_default_dtype() + with set_default_dtype(dtype): + assert torch.get_default_dtype() == dtype + + assert torch.get_default_dtype() == prev_dtype + + def test_validate_expected_param_dtype(self): + """ + Tests that we raise if any model param has a different dtype than the expected dtype. + """ + m = torch.nn.Linear(10, 10) + with pytest.raises(ValueError, match=f"has dtype {next(m.parameters()).dtype}"): + validate_expected_param_dtype(m.named_parameters(), dtype=torch.float16) diff --git a/training/tests/torchtune/utils/test_seed.py b/training/tests/torchtune/utils/test_seed.py new file mode 100644 index 0000000..55b79de --- /dev/null +++ b/training/tests/torchtune/utils/test_seed.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import os + +import numpy as np +import pytest +import torch +from torchtune.utils.seed import set_seed + + +class TestSeed: + def test_seed_range(self) -> None: + """ + Verify that exceptions are raised on input values + """ + with pytest.raises(ValueError, match="Invalid seed value provided"): + set_seed(-1) + + invalid_max = np.iinfo(np.uint64).max + with pytest.raises(ValueError, match="Invalid seed value provided"): + set_seed(invalid_max) + + # should not raise any exceptions + set_seed(42) + + def test_deterministic_true(self) -> None: + for det_debug_mode, det_debug_mode_str in [(1, "warn"), (2, "error")]: + warn_only = det_debug_mode == 1 + for debug_mode in (det_debug_mode, det_debug_mode_str): + # torch/testing/_internal/common_utils.py calls `disable_global_flags()` + # workaround RuntimeError: not allowed to set ... after disable_global_flags + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", True + ) + set_seed(42, debug_mode=debug_mode) + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", False + ) + assert torch.backends.cudnn.deterministic + assert not torch.backends.cudnn.benchmark + assert det_debug_mode == torch.get_deterministic_debug_mode() + assert torch.are_deterministic_algorithms_enabled() + assert ( + warn_only == torch.is_deterministic_algorithms_warn_only_enabled() + ) + assert os.environ["CUBLAS_WORKSPACE_CONFIG"] == ":4096:8" + + def test_deterministic_false(self) -> None: + for debug_mode in ("default", 0): + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", True + ) + set_seed(42, debug_mode=debug_mode) + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", False + ) + assert not torch.backends.cudnn.deterministic + assert torch.backends.cudnn.benchmark + assert 0 == torch.get_deterministic_debug_mode() + assert not torch.are_deterministic_algorithms_enabled() + assert not torch.is_deterministic_algorithms_warn_only_enabled() + + def test_deterministic_unset(self) -> None: + det = torch.backends.cudnn.deterministic + benchmark = torch.backends.cudnn.benchmark + det_debug_mode = torch.get_deterministic_debug_mode() + det_algo_enabled = torch.are_deterministic_algorithms_enabled() + det_algo_warn_only_enabled = ( + torch.is_deterministic_algorithms_warn_only_enabled() + ) + set_seed(42, debug_mode=None) + assert det == torch.backends.cudnn.deterministic + assert benchmark == torch.backends.cudnn.benchmark + assert det_debug_mode == torch.get_deterministic_debug_mode() + assert det_algo_enabled == torch.are_deterministic_algorithms_enabled() + assert ( + det_algo_warn_only_enabled + == torch.is_deterministic_algorithms_warn_only_enabled() + ) diff --git a/training/torchtune/__init__.py b/training/torchtune/__init__.py new file mode 100644 index 0000000..1c7de70 --- /dev/null +++ b/training/torchtune/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +__version__ = "" + +from torchtune import datasets, models, modules, utils + +__all__ = [datasets, models, modules, utils] diff --git a/training/torchtune/_cli/__init__.py b/training/torchtune/_cli/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/torchtune/_cli/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/torchtune/_cli/cp.py b/training/torchtune/_cli/cp.py new file mode 100644 index 0000000..8c8d466 --- /dev/null +++ b/training/torchtune/_cli/cp.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import shutil +import textwrap +from pathlib import Path + +import torchtune +from torchtune._cli.subcommand import Subcommand +from torchtune._recipe_registry import get_all_recipes + +ROOT = Path(torchtune.__file__).parent.parent + + +class Copy(Subcommand): + """Holds all the logic for the `tune cp` subcommand.""" + + def __init__(self, subparsers): + super().__init__() + self._parser = subparsers.add_parser( + "cp", + prog="tune cp", + usage="tune cp destination [OPTIONS]", + help="Copy a built-in recipe or config to a local path.", + description="Copy a built-in recipe or config to a local path.", + epilog=textwrap.dedent( + """\ + examples: + $ tune cp lora_finetune_distributed . + Copied file to ./lora_finetune_distributed.py + + $ tune cp llama2/7B_full ./new_dir/my_custom_lora.yaml --make-parents + Copyied file to ./new_dir/my_custom_lora.yaml + + Need to see all possible recipes/configs to copy? Try running `tune ls`. + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self._parser.set_defaults(func=self._cp_cmd) + + def _add_arguments(self) -> None: + """Add arguments to the parser.""" + self._parser.add_argument( + "file", + type=str, + help="Recipe/config to copy. For a list of all possible options, run `tune ls`", + ) + self._parser.add_argument( + "destination", + type=Path, + help="Location to copy the file to", + ) + self._parser.add_argument( + "-n", + "--no-clobber", + action="store_true", + help="Do not overwrite destination if it already exists", + default=False, + ) + self._parser.add_argument( + "--make-parents", + action="store_true", + help="Create parent directories for destination if they do not exist. " + "If not set to True, will error if parent directories do not exist", + default=False, + ) + + def _cp_cmd(self, args: argparse.Namespace): + """Copy a recipe or config to a new location.""" + destination: Path = args.destination + src = None + + # Iterate through all recipes and configs + for recipe in get_all_recipes(): + if recipe.name == args.file: + src = ROOT / "recipes" / recipe.file_path + proper_suffix = ".py" + break + for config in recipe.configs: + if config.name == args.file: + src = ROOT / "recipes" / "configs" / config.file_path + proper_suffix = ".yaml" + break + + # Fail if no file exists + if src is None: + self._parser.error( + f"Invalid file name: {args.file}. Try `tune ls` to see all available files to copy." + ) + + # Attach proper suffix if needed + if destination.name != "" and destination.suffix != proper_suffix: + destination = destination.with_suffix(proper_suffix) + + # Copy file + try: + if args.no_clobber and destination.exists(): + print( + f"File already exists at {destination.absolute()}, not overwriting." + ) + else: + if args.make_parents: + destination.parent.mkdir(parents=True, exist_ok=True) + output = shutil.copy(src, destination) + print(f"Copied file to {output}") + except FileNotFoundError: + self._parser.error( + f"Cannot create regular file: '{destination}'. No such file or directory. " + "If the specified destination's parent directory does not exist and you would " + "like to create it on-the-fly, use the --make-parents flag." + ) diff --git a/training/torchtune/_cli/download.py b/training/torchtune/_cli/download.py new file mode 100644 index 0000000..b35b81c --- /dev/null +++ b/training/torchtune/_cli/download.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import textwrap + +from pathlib import Path +from typing import Literal, Union + +from huggingface_hub import snapshot_download +from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError +from torchtune._cli.subcommand import Subcommand + + +class Download(Subcommand): + """Holds all the logic for the `tune download` subcommand.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self._parser = subparsers.add_parser( + "download", + prog="tune download", + usage="tune download [OPTIONS]", + help="Download a model from the Hugging Face Hub.", + description="Download a model from the Hugging Face Hub.", + epilog=textwrap.dedent( + """\ + examples: + # Download a model from the Hugging Face Hub with a Hugging Face API token + $ tune download meta-llama/Llama-2-7b-hf --hf-token + Successfully downloaded model repo and wrote to the following locations: + /tmp/Llama-2-7b-hf/config.json + /tmp/Llama-2-7b-hf/README.md + /tmp/Llama-2-7b-hf/consolidated.00.pth + ... + + # Download an ungated model from the Hugging Face Hub + $ tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/model + Successfully downloaded model repo and wrote to the following locations: + /tmp/model/config.json + /tmp/model/README.md + /tmp/model/model-00001-of-00002.bin + ... + + For a list of all models, visit the Hugging Face Hub https://huggingface.co/models. + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self._parser.set_defaults(func=self._download_cmd) + + def _add_arguments(self) -> None: + """Add arguments to the parser.""" + self._parser.add_argument( + "repo_id", + type=str, + help="Name of the repository on Hugging Face Hub.", + ) + self._parser.add_argument( + "--output-dir", + type=Path, + required=False, + default=None, + help="Directory in which to save the model. Defaults to `/tmp/`.", + ) + self._parser.add_argument( + "--output-dir-use-symlinks", + type=str, + required=False, + default="auto", + help=( + "To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be" + " either duplicated or symlinked to the local directory depending on its size. It set to `True`, a" + " symlink will be created, no matter the file size. If set to `False`, the file will either be" + " duplicated from cache (if already exists) or downloaded from the Hub and not cached." + ), + ) + self._parser.add_argument( + "--hf-token", + type=str, + required=False, + default=os.getenv("HF_TOKEN", None), + help="Hugging Face API token. Needed for gated models like Llama2.", + ) + self._parser.add_argument( + "--ignore-patterns", + type=str, + required=False, + default="*.safetensors", + help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring " + "safetensors files to avoid downloading duplicate weights.", + ) + + def _download_cmd(self, args: argparse.Namespace) -> None: + """Downloads a model from the Hugging Face Hub.""" + # Download the tokenizer and PyTorch model files + + # Default output_dir is `/tmp/` + output_dir = args.output_dir + if output_dir is None: + model_name = args.repo_id.split("/")[-1] + output_dir = Path("/tmp") / model_name + + # Raise if local_dir_use_symlinks is invalid + output_dir_use_symlinks: Union[Literal["auto"], bool] + use_symlinks_lowercase = args.output_dir_use_symlinks.lower() + if use_symlinks_lowercase == "true": + output_dir_use_symlinks = True + elif use_symlinks_lowercase == "false": + output_dir_use_symlinks = False + elif use_symlinks_lowercase == "auto": + output_dir_use_symlinks = "auto" + else: + self._parser.error( + f"'{args.output_dir_use_symlinks}' is not a valid value for `--output-dir-use-symlinks`. It must be either" + " 'auto', 'True' or 'False'." + ) + + print(f"Ignoring files matching the following patterns: {args.ignore_patterns}") + try: + true_output_dir = snapshot_download( + args.repo_id, + local_dir=output_dir, + local_dir_use_symlinks=output_dir_use_symlinks, + ignore_patterns=args.ignore_patterns, + token=args.hf_token, + ) + except GatedRepoError: + self._parser.error( + "It looks like you are trying to access a gated repository. Please ensure you " + "have access to the repository and have provided the proper Hugging Face API token " + "using the option `--hf-token` or by running `huggingface-cli login`." + "You can find your token by visiting https://huggingface.co/settings/tokens" + ) + except RepositoryNotFoundError: + self._parser.error( + f"Repository '{args.repo_id}' not found on the Hugging Face Hub." + ) + except Exception as e: + import traceback + + tb = traceback.format_exc() + msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}" + self._parser.error(msg) + + print( + "Successfully downloaded model repo and wrote to the following locations:", + *list(Path(true_output_dir).iterdir()), + sep="\n", + ) diff --git a/training/torchtune/_cli/ls.py b/training/torchtune/_cli/ls.py new file mode 100644 index 0000000..9f0abdb --- /dev/null +++ b/training/torchtune/_cli/ls.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import textwrap + +from torchtune._cli.subcommand import Subcommand + +from torchtune._recipe_registry import get_all_recipes + + +class List(Subcommand): + """Holds all the logic for the `tune ls` subcommand.""" + + NULL_VALUE = "<>" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self._parser = subparsers.add_parser( + "ls", + prog="tune ls", + help="List all built-in recipes and configs", + description="List all built-in recipes and configs", + epilog=textwrap.dedent( + """\ + examples: + $ tune ls + RECIPE CONFIG + full_finetune_single_device llama2/7B_full_single_device + full_finetune_distributed llama2/7B_full + llama2/13B_full + ... + + To run one of these recipes: + $ tune run full_finetune_single_device --config full_finetune_single_device + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._parser.set_defaults(func=self._ls_cmd) + + def _ls_cmd(self, args: argparse.Namespace) -> None: + """List all available recipes and configs.""" + # Print table header + header = f"{'RECIPE':<40} {'CONFIG':<40}" + print(header) + + # Print recipe/config pairs + for recipe in get_all_recipes(): + # If there are no configs for a recipe, print a blank config + recipe_str = recipe.name + if len(recipe.configs) == 0: + row = f"{recipe_str:<40} {self.NULL_VALUE:<40}" + print(row) + for i, config in enumerate(recipe.configs): + # If there are multiple configs for a single recipe, omit the recipe name + # on latter configs + if i > 0: + recipe_str = "" + row = f"{recipe_str:<40} {config.name:<40}" + print(row) diff --git a/training/torchtune/_cli/run.py b/training/torchtune/_cli/run.py new file mode 100644 index 0000000..1aaee2f --- /dev/null +++ b/training/torchtune/_cli/run.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import runpy +import sys +import textwrap + +from pathlib import Path +from typing import Optional + +import torchtune + +from torch.distributed.run import get_args_parser as get_torchrun_args_parser, run +from torchtune._cli.subcommand import Subcommand +from torchtune._recipe_registry import Config, get_all_recipes, Recipe + +ROOT = Path(torchtune.__file__).parent.parent + + +class Run(Subcommand): + """Holds all the logic for the `tune run` subcommand.""" + + def __init__(self, subparsers): + super().__init__() + self._parser = subparsers.add_parser( + "run", + prog="tune run", + help="Run a recipe. For distributed recipes, this supports all torchrun arguments.", + description="Run a recipe. For distributed recipes, this supports all torchrun arguments.", + usage="tune run [TORCHRUN-OPTIONS] --config [RECIPE-OPTIONS]", + epilog=textwrap.dedent( + """\ + examples: + + # Run a finetuning recipe on a single device w/ default values + $ tune run lora_finetune_single_device --config llama2/7B_lora_single_device + + # Run a finetuning recipe in a distributed fashion using torchrun w/ default values + $ tune run --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full_finetune_distributed + + # Override a parameter in the config file and specify a number of GPUs for torchrun + $ tune run --nproc_per_node 2 \ + lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ + model.lora_rank=16 \ + + Remember, you can use `tune cp` to copy a default recipe/config to your local dir and modify the values. + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self._parser.set_defaults(func=self._run_cmd) + + def _add_arguments(self) -> None: + """Add arguments to the parser. + + This is a bit hacky since we need to add the torchrun arguments to our parser. + This grabs the argparser from torchrun, iterates over it's actions, and adds them + to our parser. We rename the training_script and training_script_args to recipe and recipe_args + respectively. In addition, we leave out the help argument since we add it manually to ours. + """ + torchrun_argparser = get_torchrun_args_parser() + for action in torchrun_argparser._actions: + if action.dest == "training_script": + action.dest = "recipe" + action.help = """Name or path to recipe to be launched followed by args. +For a list of all possible recipes, run `tune ls`.""" + elif action.dest == "training_script_args": + action.dest = "recipe_args" + action.help = "Args to be passed to the recipe." + elif action.dest == "help": + continue + self._parser._add_action(action) + + def _run_distributed(self, args: argparse.Namespace): + """Run a recipe with torchrun.""" + # TODO (rohan-varma): Add check that nproc_per_node <= cuda device count. Currently, + # we don't do this since we test on CPUs for distributed. Will update once multi GPU CI is supported. + print("Running with torchrun...") + # Have to reset the argv so that the recipe can be run with the correct arguments + args.training_script = args.recipe + args.training_script_args = args.recipe_args + run(args) + + def _run_single_device(self, args: argparse.Namespace): + """Run a recipe on a single device.""" + sys.argv = [str(args.recipe)] + args.recipe_args + runpy.run_path(str(args.recipe), run_name="__main__") + + def _is_distributed_args(self, args: argparse.Namespace): + """Check if the user is trying to run a distributed recipe.""" + total = len(sys.argv) - 2 # total args minus "tune run" + script_args = len(args.recipe_args) + 1 # script args + 1 for script name + return total > script_args + + def _get_recipe(self, recipe_str: str) -> Optional[Recipe]: + """Get a recipe from the name or path. + + Args: + recipe_str (str): The name or path of the recipe. + + Returns: + The recipe if it's found in built-in recipes, otherwise None. + """ + for recipe in get_all_recipes(): + if recipe.name == recipe_str: + return recipe + + def _get_config( + self, config_str: str, specific_recipe: Optional[Recipe] + ) -> Optional[Config]: + """Get a config from the name or path. + + Args: + config_str (str): The name or path of the config. + specific_recipe (Optional[Recipe]): The specific recipe to search through. + + Returns: + The config if it's found in built-in configs, otherwise None. + """ + # If a specific recipe is provided, search through it + if specific_recipe is not None: + for config in specific_recipe.configs: + if config.name == config_str: + return config + + # If not, search through all recipes + for recipe in get_all_recipes(): + for config in recipe.configs: + if config.name == config_str: + return config + + def _run_cmd(self, args: argparse.Namespace): + """Run a recipe.""" + # We have to assume that the recipe supports distributed training + supports_distributed = True + recipe_path, config_path = None, None + + # Try to find config string in args + try: + config_idx = args.recipe_args.index("--config") + 1 + config_str = args.recipe_args[config_idx] + except ValueError: + self._parser.error("The '--config' argument is required.") + + # Get recipe path + recipe = self._get_recipe(args.recipe) + if recipe is None: + recipe_path = args.recipe + else: + recipe_path = str(ROOT / "recipes" / recipe.file_path) + supports_distributed = recipe.supports_distributed + + # Get config path + config = self._get_config(config_str, recipe) + if config is None: + config_path = config_str + else: + config_path = str(ROOT / "recipes" / "configs" / config.file_path) + + # Prepare args + args.recipe = recipe_path + args.recipe_args[config_idx] = config_path + + # Execute recipe + if self._is_distributed_args(args): + if not supports_distributed: + self._parser.error( + f"Recipe {recipe.name} does not support distributed training." + "Please run without torchrun commands." + ) + self._run_distributed(args) + else: + self._run_single_device(args) diff --git a/training/torchtune/_cli/subcommand.py b/training/torchtune/_cli/subcommand.py new file mode 100644 index 0000000..db298a0 --- /dev/null +++ b/training/torchtune/_cli/subcommand.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +class Subcommand: + def __init__(self, *args, **kwargs): + pass + + @classmethod + def create(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def _add_arguments(self): + pass diff --git a/training/torchtune/_cli/tune.py b/training/torchtune/_cli/tune.py new file mode 100644 index 0000000..ea9c58b --- /dev/null +++ b/training/torchtune/_cli/tune.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +from torchtune._cli.cp import Copy +from torchtune._cli.download import Download +from torchtune._cli.ls import List +from torchtune._cli.run import Run +from torchtune._cli.validate import Validate + + +class TuneCLIParser: + """Holds all information related to running the CLI""" + + def __init__(self): + # Initialize the top-level parser + self._parser = argparse.ArgumentParser( + prog="tune", + description="Welcome to the torchtune CLI!", + add_help=True, + ) + # Default command is to print help + self._parser.set_defaults(func=lambda args: self._parser.print_help()) + + # Add subcommands + subparsers = self._parser.add_subparsers(title="subcommands") + Download.create(subparsers) + List.create(subparsers) + Copy.create(subparsers) + Run.create(subparsers) + Validate.create(subparsers) + + def parse_args(self) -> argparse.Namespace: + """Parse CLI arguments""" + return self._parser.parse_args() + + def run(self, args: argparse.Namespace) -> None: + """Execute CLI""" + args.func(args) + + +def main(): + parser = TuneCLIParser() + args = parser.parse_args() + parser.run(args) + + +if __name__ == "__main__": + main() diff --git a/training/torchtune/_cli/validate.py b/training/torchtune/_cli/validate.py new file mode 100644 index 0000000..7efac13 --- /dev/null +++ b/training/torchtune/_cli/validate.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import textwrap +from pathlib import Path + +from omegaconf import OmegaConf + +from torchtune import config +from torchtune._cli.subcommand import Subcommand +from torchtune.config._errors import ConfigError + + +class Validate(Subcommand): + """Holds all the logic for the `tune validate` subcommand.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self._parser = subparsers.add_parser( + "validate", + prog="tune validate", + help="Validate a config and ensure that it is well-formed.", + description="Validate a config and ensure that it is well-formed.", + usage="tune validate ", + epilog=textwrap.dedent( + """\ + examples: + + $ tune validate recipes/configs/full_finetune_distributed.yaml + Config is well-formed! + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self._parser.set_defaults(func=self._validate_cmd) + + def _add_arguments(self) -> None: + """Add arguments to the parser.""" + self._parser.add_argument( + "config", + type=Path, + help="Path to a config to validate.", + ) + + def _validate_cmd(self, args: argparse.Namespace): + """Validate a config file.""" + cfg = OmegaConf.load(args.config) + + try: + config.validate(cfg) + except ConfigError as e: + self._parser.error(str(e)) + + print("Config is well-formed!") diff --git a/training/torchtune/_recipe_registry.py b/training/torchtune/_recipe_registry.py new file mode 100644 index 0000000..fd53a75 --- /dev/null +++ b/training/torchtune/_recipe_registry.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List + + +@dataclass +class Config: + name: str + file_path: str + + +@dataclass +class Recipe: + name: str + file_path: str + configs: List[Config] + supports_distributed: bool + + +_ALL_RECIPES = [ + Recipe( + name="full_finetune_single_device", + file_path="full_finetune_single_device.py", + configs=[ + Config( + name="llama2/7B_full_low_memory", + file_path="llama2/7B_full_low_memory.yaml", + ), + Config( + name="code_llama2/7B_full_low_memory", + file_path="code_llama2/7B_full_low_memory.yaml", + ), + Config( + name="llama3/8B_full_single_device", + file_path="llama3/8B_full_single_device.yaml", + ), + Config( + name="llama3_1/8B_full_single_device", + file_path="llama3_1/8B_full_single_device.yaml", + ), + Config( + name="mistral/7B_full_low_memory", + file_path="mistral/7B_full_low_memory.yaml", + ), + Config( + name="phi3/mini_full_low_memory", + file_path="phi3/mini_full_low_memory.yaml", + ), + Config( + name="qwen2/7B_full_single_device", + file_path="qwen2/7B_full_single_device.yaml", + ), + Config( + name="qwen2/0.5B_full_single_device", + file_path="qwen2/0.5B_full_single_device.yaml", + ), + Config( + name="qwen2/1.5B_full_single_device", + file_path="qwen2/1.5B_full_single_device.yaml", + ), + ], + supports_distributed=False, + ), + Recipe( + name="full_finetune_distributed", + file_path="full_finetune_distributed.py", + configs=[ + Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"), + Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"), + Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"), + Config(name="llama3_1/8B_full", file_path="llama3_1/8B_full.yaml"), + Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"), + Config(name="llama3_1/70B_full", file_path="llama3_1/70B_full.yaml"), + Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"), + Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"), + Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"), + Config(name="phi3/mini_full", file_path="phi3/mini_full.yaml"), + Config(name="qwen2/7B_full", file_path="qwen2/7B_full.yaml"), + Config(name="qwen2/0.5B_full", file_path="qwen2/0.5B_full.yaml"), + Config(name="qwen2/1.5B_full", file_path="qwen2/1.5B_full.yaml"), + ], + supports_distributed=True, + ), + Recipe( + name="lora_finetune_single_device", + file_path="lora_finetune_single_device.py", + configs=[ + Config( + name="llama2/7B_lora_single_device", + file_path="llama2/7B_lora_single_device.yaml", + ), + Config( + name="llama2/7B_qlora_single_device", + file_path="llama2/7B_qlora_single_device.yaml", + ), + Config( + name="code_llama2/7B_lora_single_device", + file_path="code_llama2/7B_lora_single_device.yaml", + ), + Config( + name="code_llama2/7B_qlora_single_device", + file_path="code_llama2/7B_qlora_single_device.yaml", + ), + Config( + name="llama3/8B_lora_single_device", + file_path="llama3/8B_lora_single_device.yaml", + ), + Config( + name="llama3_1/8B_lora_single_device", + file_path="llama3_1/8B_lora_single_device.yaml", + ), + Config( + name="llama3/8B_qlora_single_device", + file_path="llama3/8B_qlora_single_device.yaml", + ), + Config( + name="llama3_1/8B_qlora_single_device", + file_path="llama3_1/8B_qlora_single_device.yaml", + ), + Config( + name="llama2/13B_qlora_single_device", + file_path="llama2/13B_qlora_single_device.yaml", + ), + Config( + name="mistral/7B_lora_single_device", + file_path="mistral/7B_lora_single_device.yaml", + ), + Config( + name="mistral/7B_qlora_single_device", + file_path="mistral/7B_qlora_single_device.yaml", + ), + Config( + name="gemma/2B_lora_single_device", + file_path="gemma/2B_lora_single_device.yaml", + ), + Config( + name="gemma/2B_qlora_single_device", + file_path="gemma/2B_qlora_single_device.yaml", + ), + Config( + name="gemma/7B_lora_single_device", + file_path="gemma/7B_lora_single_device.yaml", + ), + Config( + name="gemma/7B_qlora_single_device", + file_path="gemma/7B_qlora_single_device.yaml", + ), + Config( + name="phi3/mini_lora_single_device", + file_path="phi3/mini_lora_single_device.yaml", + ), + Config( + name="phi3/mini_qlora_single_device", + file_path="phi3/mini_qlora_single_device.yaml", + ), + Config( + name="qwen2/7B_lora_single_device", + file_path="qwen2/7B_lora_single_device.yaml", + ), + Config( + name="qwen2/0.5B_lora_single_device", + file_path="qwen2/0.5B_lora_single_device.yaml", + ), + Config( + name="qwen2/1.5B_lora_single_device", + file_path="qwen2/1.5B_lora_single_device.yaml", + ), + ], + supports_distributed=False, + ), + Recipe( + name="lora_dpo_single_device", + file_path="lora_dpo_single_device.py", + configs=[ + Config( + name="llama2/7B_lora_dpo_single_device", + file_path="llama2/7B_lora_dpo_single_device.yaml", + ), + ], + supports_distributed=False, + ), + Recipe( + name="lora_dpo_distributed", + file_path="lora_dpo_distributed.py", + configs=[ + Config( + name="llama2/7B_lora_dpo", + file_path="llama2/7B_lora_dpo.yaml", + ), + ], + supports_distributed=True, + ), + Recipe( + name="ppo_full_finetune_single_device", + file_path="ppo_full_finetune_single_device.py", + configs=[ + Config( + name="mistral/7B_full_ppo_low_memory", + file_path="mistral/7B_full_ppo_low_memory.yaml", + ), + ], + supports_distributed=False, + ), + Recipe( + name="lora_finetune_distributed", + file_path="lora_finetune_distributed.py", + configs=[ + Config(name="llama2/7B_lora", file_path="llama2/7B_lora.yaml"), + Config(name="llama2/13B_lora", file_path="llama2/13B_lora.yaml"), + Config(name="llama2/70B_lora", file_path="llama2/70B_lora.yaml"), + Config(name="llama3/70B_lora", file_path="llama3/70B_lora.yaml"), + Config(name="llama3_1/70B_lora", file_path="llama3_1/70B_lora.yaml"), + Config(name="llama3/8B_lora", file_path="llama3/8B_lora.yaml"), + Config(name="llama3_1/8B_lora", file_path="llama3_1/8B_lora.yaml"), + Config(name="mistral/7B_lora", file_path="mistral/7B_lora.yaml"), + Config(name="gemma/2B_lora", file_path="gemma/2B_lora.yaml"), + Config(name="gemma/7B_lora", file_path="gemma/7B_lora.yaml"), + Config(name="phi3/mini_lora", file_path="phi3/mini_lora.yaml"), + Config(name="qwen2/7B_lora", file_path="qwen2/7B_lora.yaml"), + Config(name="qwen2/0.5B_lora", file_path="qwen2/0.5B_lora.yaml"), + Config(name="qwen2/1.5B_lora", file_path="qwen2/1.5B_lora.yaml"), + ], + supports_distributed=True, + ), + Recipe( + name="lora_finetune_fsdp2", + file_path="dev/lora_finetune_fsdp2.py", + configs=[ + Config(name="llama2/7B_lora", file_path="dev/llama2/7B_lora_fsdp2.yaml"), + Config(name="llama2/13B_lora", file_path="dev/llama2/13B_lora_fsdp2.yaml"), + Config(name="llama2/70B_lora", file_path="dev/llama2/70B_lora_fsdp2.yaml"), + Config( + name="llama2/7B_qlora", + file_path="dev/llama2/7B_qlora_fsdp2.yaml", + ), + Config( + name="llama2/70B_qlora", + file_path="dev/llama2/70B_qlora_fsdp2.yaml", + ), + ], + supports_distributed=True, + ), + Recipe( + name="generate", + file_path="generate.py", + configs=[ + Config(name="generation", file_path="generation.yaml"), + ], + supports_distributed=False, + ), + Recipe( + name="eleuther_eval", + file_path="eleuther_eval.py", + configs=[ + Config(name="eleuther_evaluation", file_path="eleuther_evaluation.yaml"), + ], + supports_distributed=False, + ), + Recipe( + name="quantize", + file_path="quantize.py", + configs=[ + Config(name="quantization", file_path="quantization.yaml"), + ], + supports_distributed=False, + ), + Recipe( + name="qat_distributed", + file_path="qat_distributed.py", + configs=[ + Config(name="llama2/7B_qat_full", file_path="llama2/7B_qat_full.yaml"), + Config(name="llama3/8B_qat_full", file_path="llama3/8B_qat_full.yaml"), + ], + supports_distributed=True, + ), +] + + +def get_all_recipes(): + """List of recipes available from the CLI.""" + return _ALL_RECIPES diff --git a/training/torchtune/config/__init__.py b/training/torchtune/config/__init__.py new file mode 100644 index 0000000..131caea --- /dev/null +++ b/training/torchtune/config/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._instantiate import instantiate +from ._parse import parse +from ._utils import log_config +from ._validate import validate + +__all__ = [ + "instantiate", + "parse", + "log_config", + "validate", +] diff --git a/training/torchtune/config/_errors.py b/training/torchtune/config/_errors.py new file mode 100644 index 0000000..35ecbab --- /dev/null +++ b/training/torchtune/config/_errors.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + + +class InstantiationError(Exception): + """ + Raised when a `_component_` field in a config is unable to be instantiated. + """ + + pass + + +class ConfigError(Exception): + """ + Raised when the yaml config is not well-formed. Prints all the collected + errors at once. + + Args: + errors (List[Exception]): exceptions found when validating `_component_` + fields in the config + """ + + def __init__(self, errors: List[Exception]): + self.errors = errors + + def __str__(self): + error_messages = [f"{type(e).__name__}: {str(e)}" for e in self.errors] + return "Config is not well-formed, found the following errors: \n" + "\n".join( + error_messages + ) diff --git a/training/torchtune/config/_instantiate.py b/training/torchtune/config/_instantiate.py new file mode 100644 index 0000000..9d17ddd --- /dev/null +++ b/training/torchtune/config/_instantiate.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +from typing import Any, Callable, Dict, Tuple + +from omegaconf import DictConfig, OmegaConf +from torchtune.config._errors import InstantiationError +from torchtune.config._utils import _get_component_from_path, _has_component + + +def _create_component( + _component_: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +): + return _component_(*args, **kwargs) + + +def _instantiate_node(node: DictConfig, *args: Tuple[Any, ...]): + """ + Creates the object specified in _component_ field with provided positional args + and kwargs already merged. Raises an InstantiationError if _component_ is not specified. + """ + if _has_component(node): + _component_ = _get_component_from_path(node.get("_component_")) + kwargs = {k: v for k, v in node.items() if k != "_component_"} + return _create_component(_component_, args, kwargs) + else: + raise InstantiationError( + "Cannot instantiate specified object." + + "\nMake sure you've specified a _component_ field with a valid dotpath." + ) + + +def instantiate( + config: DictConfig, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Any: + """ + Given a DictConfig with a _component_ field specifying the object to instantiate and + additional fields for keyword arguments, create an instance of the specified object. + You can use this function to create the exact instance of a torchtune object you want + to use in your recipe using the specification from the config. + + This function also supports passing in positional args and keyword args within the + function call. These are automatically merged with the provided config, with keyword + args taking precedence. + + Based on Hydra's `instantiate` utility from Facebook Research: + https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/instantiate/_instantiate2.py#L148 + + Args: + config (DictConfig): a single field in the OmegaConf object parsed from the yaml file. + This is expected to have a _component_ field specifying the path of the object + to instantiate. + *args (Tuple[Any, ...]): positional arguments to pass to the object to instantiate. + **kwargs (Dict[str, Any]): keyword arguments to pass to the object to instantiate. + + Examples: + >>> config.yaml: + >>> model: + >>> _component_: torchtune.models.llama2 + >>> num_layers: 32 + >>> num_heads: 32 + >>> num_kv_heads: 32 + + >>> from torchtune import config + >>> vocab_size = 32000 + >>> # Pass in vocab size as positional argument. Since it is positioned first + >>> # in llama2(), it must be specified first. Pass in other arguments as kwargs. + >>> # This will return an nn.Module directly for llama2 with specified args. + >>> model = config.instantiate(parsed_yaml.model, vocab_size, max_seq_len=4096, embed_dim=4096) + + Returns: + Any: the instantiated object. + + Raises: + ValueError: if config is not a DictConfig. + """ + + # Return None if config is None + if config is None: + return None + if not OmegaConf.is_dict(config): + raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}") + + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + if kwargs: + # This overwrites any repeated fields in the config with kwargs + config = OmegaConf.merge(config, kwargs) + + # Resolve all interpolations, or references to other fields within the same config + OmegaConf.resolve(config) + + return _instantiate_node(config, *args) diff --git a/training/torchtune/config/_parse.py b/training/torchtune/config/_parse.py new file mode 100644 index 0000000..d3875c0 --- /dev/null +++ b/training/torchtune/config/_parse.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import functools +import sys +from typing import Any, Callable + +from omegaconf import DictConfig +from torchtune.config._utils import _merge_yaml_and_cli_args +from torchtune.utils.argparse import TuneRecipeArgumentParser + + +Recipe = Callable[[DictConfig], Any] + + +def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]: + """ + Decorator that handles parsing the config file and CLI overrides + for a recipe. Use it on the recipe's main function. + + Args: + recipe_main (Recipe): The main method that initializes + and runs the recipe + + Examples: + >>> @parse + >>> def main(cfg: DictConfig): + >>> ... + + >>> # With the decorator, the parameters will be parsed into cfg when run as: + >>> tune my_recipe --config config.yaml foo=bar + + Returns: + Callable[[Recipe], Any]: the decorated main + """ + + @functools.wraps(recipe_main) + def wrapper(*args: Any, **kwargs: Any) -> Any: + parser = TuneRecipeArgumentParser( + description=recipe_main.__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + # Get user-specified args from config and CLI and create params for recipe + yaml_args, cli_args = parser.parse_known_args() + conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + + sys.exit(recipe_main(conf)) + + return wrapper diff --git a/training/torchtune/config/_utils.py b/training/torchtune/config/_utils.py new file mode 100644 index 0000000..c25a85f --- /dev/null +++ b/training/torchtune/config/_utils.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace +from importlib import import_module +from types import ModuleType +from typing import Any, Dict, List, Union + +from omegaconf import DictConfig, OmegaConf + +from torchtune.config._errors import InstantiationError +from torchtune.utils import get_logger, get_world_size_and_rank + + +def log_config(recipe_name: str, cfg: DictConfig) -> None: + """ + Logs the resolved config (merged YAML file and CLI overrides) to rank zero. + + Args: + recipe_name (str): name of the recipe to display + cfg (DictConfig): parsed config object + """ + # Log the config only on rank 0 + _, rank = get_world_size_and_rank() + if rank != 0: + return + + logger = get_logger("DEBUG") + cfg_str = OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True) + logger.info(msg=f"Running {recipe_name} with resolved config:\n\n{cfg_str}") + + +def _has_component(node: Union[Dict[str, Any], DictConfig]) -> bool: + return (OmegaConf.is_dict(node) or isinstance(node, dict)) and "_component_" in node + + +def _get_component_from_path(path: str) -> Any: + """ + Return an object by name or dotted path, importing as necessary. + The base functionality relies on ``getattr()`` and handles all + possible exceptions accordingly. + + Based on Hydra's `_locate` from Facebook Research: + https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/utils.py#L614 + + Args: + path (str): Dotted path of the object + + Returns: + Any: The object + + Raises: + InstantiationError: If there is an exception loading the + object from the provided path + ValueError: If a relative or invalid dotpath is passed in + """ + if path == "": + raise ValueError("Empty path") + + parts = [part for part in path.split(".")] + for part in parts: + # If a relative path is passed in, the first part will be empty + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." + ) + # First module requires trying to import to validate + part0 = parts[0] + try: + obj = import_module(part0) + except ImportError as exc_import: + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that module '{part0}' is installed?" + ) from exc_import + # Subsequent components can be checked via getattr() on first module + # It can either be an attribute that we can return or a submodule that we + # can import and continue searching + for m in range(1, len(parts)): + part = parts[m] + try: + obj = getattr(obj, part) + # If getattr fails, check to see if it's a module we can import and + # continue down the path + except AttributeError as exc_attr: + parent_dotpath = ".".join(parts[:m]) + if isinstance(obj, ModuleType): + mod = ".".join(parts[: m + 1]) + try: + obj = import_module(mod) + continue + except ModuleNotFoundError as exc_import: + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?" + ) from exc_import + # Any other error trying to import module can be raised as + # InstantiationError + except Exception as exc_import: + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_import)}" + ) from exc_import + # If the component is not an attribute nor a module, it doesn't exist + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?" + ) from exc_attr + return obj + + +def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictConfig: + """ + Takes the direct output of argparse's parse_known_args which returns known + args as a Namespace and unknown args as a dotlist (in our case, yaml args and + cli args, respectively) and merges them into a single OmegaConf DictConfig. + + If a cli arg overrides a yaml arg with a _component_ field, the cli arg can + be specified with the parent field directly, e.g., model=torchtune.models.lora_llama2_7b + instead of model._component_=torchtune.models.lora_llama2_7b. Nested fields within the + component should be specified with dot notation, e.g., model.lora_rank=16. + + Example: + >>> config.yaml: + >>> a: 1 + >>> b: + >>> _component_: torchtune.models.my_model + >>> c: 3 + + >>> tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4 + >>> yaml_args, cli_args = parser.parse_known_args() + >>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + >>> print(conf) + >>> {"a": 1, "b": {"_component_": "torchtune.models.other_model", "c": 4}} + + Args: + yaml_args (Namespace): Namespace containing args from yaml file, components + should have _component_ fields + cli_args (List[str]): List of key=value strings + + Returns: + DictConfig: OmegaConf DictConfig containing merged args + + Raises: + ValueError: If a cli override is not in the form of key=value + """ + # Convert Namespace to simple dict + yaml_kwargs = vars(yaml_args) + cli_dotlist = [] + for arg in cli_args: + # If CLI override uses the remove flag (~), remove the key from the yaml config + if arg.startswith("~"): + dotpath = arg[1:].split("=")[0] + if "_component_" in dotpath: + raise ValueError( + f"Removing components from CLI is not supported: ~{dotpath}" + ) + try: + _remove_key_by_dotpath(yaml_kwargs, dotpath) + except (KeyError, ValueError): + raise ValueError( + f"Could not find key {dotpath} in yaml config to remove" + ) from None + continue + # Get other overrides that should be specified as key=value + try: + k, v = arg.split("=") + except ValueError: + raise ValueError( + f"Command-line overrides must be in the form of key=value, got {arg}" + ) from None + # If a cli arg overrides a yaml arg with a _component_ field, update the + # key string to reflect this + if k in yaml_kwargs and _has_component(yaml_kwargs[k]): + k += "._component_" + cli_dotlist.append(f"{k}={v}") + + # Merge the args + cli_conf = OmegaConf.from_dotlist(cli_dotlist) + yaml_conf = OmegaConf.create(yaml_kwargs) + + # CLI takes precedence over yaml args + return OmegaConf.merge(yaml_conf, cli_conf) + + +def _remove_key_by_dotpath(nested_dict: Dict[str, Any], dotpath: str) -> None: + """ + Removes a key specified by dotpath from a nested dict. Errors should handled by + the calling function. + + Args: + nested_dict (Dict[str, Any]): Dict to remove key from + dotpath (str): dotpath of key to remove, e.g., "a.b.c" + """ + path = dotpath.split(".") + + def delete_non_component(d: Dict[str, Any], key: str) -> None: + if _has_component(d[key]): + raise ValueError( + f"Removing components from CLI is not supported: ~{dotpath}" + ) + del d[key] + + def recurse_and_delete(d: Dict[str, Any], path: List[str]) -> None: + if len(path) == 1: + delete_non_component(d, path[0]) + else: + recurse_and_delete(d[path[0]], path[1:]) + if not d[path[0]]: + delete_non_component(d, path[0]) + + recurse_and_delete(nested_dict, path) diff --git a/training/torchtune/config/_validate.py b/training/torchtune/config/_validate.py new file mode 100644 index 0000000..6d2943e --- /dev/null +++ b/training/torchtune/config/_validate.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect + +from omegaconf import DictConfig +from torchtune.config._errors import ConfigError +from torchtune.config._utils import _get_component_from_path, _has_component + + +def validate(cfg: DictConfig) -> None: + """ + Ensure that all components in the config can be instantiated correctly + + Args: + cfg (DictConfig): The config to validate + + Raises: + ConfigError: If any component cannot be instantiated + """ + + errors = [] + for node, nodedict in cfg.items(): + if _has_component(nodedict): + try: + _component_ = _get_component_from_path(nodedict.get("_component_")) + kwargs = {k: v for k, v in nodedict.items() if k != "_component_"} + sig = inspect.signature(_component_) + sig.bind(**kwargs) + # Some objects require other objects as arguments, like optimizers, + # lr_schedulers, datasets, etc. Try doing partial instantiation + except TypeError as e: + if "missing a required argument" in str(e): + sig.bind_partial(**kwargs) + else: + # inspect.signature does not retain the function name in the + # exception, so we manually add it back in + e = TypeError(f"{_component_.__name__} {str(e)}") + errors.append(e) + + if errors: + raise ConfigError(errors) diff --git a/training/torchtune/data/__init__.py b/training/torchtune/data/__init__.py new file mode 100644 index 0000000..c740a6d --- /dev/null +++ b/training/torchtune/data/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.data._chat_formats import ( + ChatFormat, + ChatMLFormat, + Llama2ChatFormat, + MistralChatFormat, + Llama3ChatFormat +) +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.data._converters import get_openai_messages, get_sharegpt_messages +from torchtune.data._instruct_templates import ( + AlpacaInstructTemplate, + InstructTemplate, + StackExchangedPairedTemplate, +) +from torchtune.data._messages import ( + InputOutputToMessages, + JSONToMessages, + Message, + Role, + ShareGPTToMessages, +) +from torchtune.data._prompt_templates import ( + ChatMLTemplate, + GrammarErrorCorrectionTemplate, + PromptTemplate, + PromptTemplateInterface, + SummarizeTemplate, +) +from torchtune.data._utils import truncate, validate_messages + +__all__ = [ + "AlpacaInstructTemplate", + "ChatFormat", + "CROSS_ENTROPY_IGNORE_IDX", + "GrammarErrorCorrectionTemplate", + "InstructTemplate", + "SummarizeTemplate", + "Llama2ChatFormat", + "MistralChatFormat", + "ChatMLFormat", + "JSONToMessages", + "ShareGPTToMessages", + "truncate", + "Message", + "validate_messages", + "StackExchangedPairedTemplate", + "Role", + "PromptTemplateInterface", + "PromptTemplate", + "InputOutputToMessages", + "ChatMLTemplate", + "get_openai_messages", + "get_sharegpt_messages", +] diff --git a/training/torchtune/data/_chat_formats.py b/training/torchtune/data/_chat_formats.py new file mode 100644 index 0000000..8e696b1 --- /dev/null +++ b/training/torchtune/data/_chat_formats.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple + +from torchtune.data._messages import Message, Role + + +class ChatFormat(ABC): + """ + Interface for chat formats. Each chat format should include tags for system, + user, and assistant roles that are prepended or appended to the message + content. + """ + + # Template should map role to a tuple containing the tag to prepend to the text + # and tag to append to the text. Leave as empty strings to not prepend or append + template: Dict[Role, Tuple[str, str]] + + @classmethod + @abstractmethod + def format( + cls, + sample: List[Message], + ) -> List[Message]: + """ + Format each role's message(s) according to the chat format + + Args: + sample (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + """ + pass + + + +class Llama3ChatFormat(ChatFormat): + @classmethod + def format(cls, sample: List[Message]) -> List[Message]: + # bp() + return sample + + +class Llama2ChatFormat(ChatFormat): + """ + Chat format that formats human and system prompts with appropriate tags + used in Llama2 pre-training. Taken from Meta's official `Llama inference + repository `_. + + .. code-block:: text + + "[INST] <> + You are a helpful, respectful and honest assistant. + <>" + + I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture..." + + + """ + + template = { + "system": ("<>\n", "\n<>\n\n"), + "user": ("[INST] ", " [/INST] "), + "assistant": ("", ""), + "ipython": ("", ""), + } + + @classmethod + def format( + cls, + sample: List[Message], + ) -> List[Message]: + """ + Format user and system messages with appropriate tags. + + Args: + sample (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + """ + system_message = [] + formatted_dialogue = [] + for message in sample: + if message.role == "system": + system_message = ( + [{"type": "text", "content": cls.template["system"][0]}] + + message.content + + [{"type": "text", "content": cls.template["system"][1]}] + ) + # Incorporate the system message in the user message - Llama2 only + # looks for the <> tags and not the explicit role so this will + # be treated the same as an actual system message. We do this because + # of the nesting of the system prompt in the user message. + continue + elif message.role == "user": + content = ( + [{"type": "text", "content": cls.template["user"][0]}] + + system_message + + message.content + + [{"type": "text", "content": cls.template["user"][1]}] + ) + elif message.role == "assistant": + # No special formatting needed for assistant message + content = message.content + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue + + +class MistralChatFormat(ChatFormat): + """ + Formats according to `Mistral's instruct model `_. + + It is identical to :class:`Llama2ChatFormat`, except it does not support system + prompts. + + .. code-block:: text + + "[INST] I am going to Paris, what should I see? [/INST] Paris, the capital + of France, is known for its stunning architecture..." + + """ + + template = { + "system": None, + "user": ("[INST] ", " [/INST] "), + "assistant": ("", ""), + "ipython": ("", ""), + } + + @classmethod + def format( + cls, + sample: List[Message], + ) -> List[Message]: + """ + Format user and system messages with appropriate tags. + + Args: + sample (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + + Raises: + ValueError: If system prompts are provided + """ + formatted_dialogue = [] + for message in sample: + if message.role == "system": + raise ValueError( + "System prompts are not supported in MistralChatFormat" + ) + else: + content = ( + [{"type": "text", "content": cls.template[message.role][0]}] + + message.content + + [{"type": "text", "content": cls.template[message.role][1]}] + ) + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue + + +class ChatMLFormat(ChatFormat): + """ + OpenAI's `Chat Markup Language + `_ + used by their chat models. + + It is the default chat format used by Hugging Face models. + + .. code-block:: text + + <|im_start|>system + Provide some context and/or instructions to the model.<|im_end|> + <|im_start|>user + The user’s message goes here<|im_end|> + <|im_start|>assistant + The assistant’s response goes here<|im_end|> + + """ + + template = { + "system": ("<|im_start|>system\n", "<|im_end|>\n"), + "user": ("<|im_start|>user\n", "<|im_end|>\n"), + "assistant": ("<|im_start|>assistant\n", "<|im_end|>"), + "ipython": ("", ""), + } + + @classmethod + def format( + cls, + sample: List[Message], + ) -> List[Message]: + """ + Format user and system messages with appropriate tags. + + Args: + sample (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + """ + formatted_dialogue = [] + for message in sample: + content = ( + [{"type": "text", "content": cls.template[message.role][0]}] + + message.content + + [{"type": "text", "content": cls.template[message.role][1]}] + ) + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue diff --git a/training/torchtune/data/_common.py b/training/torchtune/data/_common.py new file mode 100644 index 0000000..3f8c460 --- /dev/null +++ b/training/torchtune/data/_common.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CROSS_ENTROPY_IGNORE_IDX = -100 diff --git a/training/torchtune/data/_converters.py b/training/torchtune/data/_converters.py new file mode 100644 index 0000000..160aa2a --- /dev/null +++ b/training/torchtune/data/_converters.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping + +from torchtune.data._messages import Message + + +def get_sharegpt_messages( + sample: Mapping[str, Any], train_on_input: bool = False +) -> List[Message]: + """ + Convert a chat sample adhering to the ShareGPT json structure to torchtune's :class:`~torchtune.data.Message` + structure. + + ShareGPT follows:: + + { + "conversations": [ + { + "from": , + "value": , + }, + ... + ] + } + + :class:`~torchtune.data.Message` follows:: + + [ + { + "role": , + "content": , + }, + ... + ] + + Args: + sample (Mapping[str, Any]): a single data sample with "conversations" field pointing + to a list of dict messages. + train_on_input (bool): whether the prompt should remain unmasked. Default: False + + Returns: + List[Message]: A list of messages with "role" and "content" fields. + """ + role_map = {"system": "system", "human": "user", "gpt": "assistant"} + conversations = sample["conversations"] + + messages = [] + for message in conversations: + role = role_map[message["from"]] + content = message["value"] + masked = (role != "assistant") and (not train_on_input) + messages.append( + Message( + role=role, content=[{"type": "text", "content": content}], masked=masked + ) + ) + return messages + + +def get_openai_messages( + sample: Mapping[str, Any], + train_on_input: bool = False, +) -> List[Message]: + """ + Convert a chat sample adhering to the OpenAI API json structure to torchtune's :class:`~torchtune.data.Message` + structure. + + OpenAI API `standard chat format `_ follows:: + + { + # key could be "messages" OR "conversations" + "messages": [ + { + "role": , + "content": , + }, + ... + ] + } + + :class:`~torchtune.data.Message` follows:: + + [ + { + "role": , + "content": , + }, + ... + ] + + Args: + sample (Mapping[str, Any]): a single data sample with "conversations" field pointing + to a list of dict messages. + train_on_input (bool): whether the prompt should remain unmasked. Default: False + + Raises: + ValueError: If the sample does not contain "messages" or "conversations" key. + + Returns: + List[Message]: A list of messages with "role" and "content" fields. + """ + if "messages" in sample: + messages_key = "messages" + elif "conversations" in sample: + messages_key = "conversations" + else: + raise ValueError( + f"Sample does not contain 'messages' or 'conversations' key. Existing keys: {sample.keys()}" + ) + conversations = sample[messages_key] + + messages = [] + for message in conversations: + message["masked"] = (message["role"] != "assistant") and (not train_on_input) + messages.append(Message.from_dict(message)) + return messages diff --git a/training/torchtune/data/_instruct_templates.py b/training/torchtune/data/_instruct_templates.py new file mode 100644 index 0000000..49f52c2 --- /dev/null +++ b/training/torchtune/data/_instruct_templates.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any, Dict, Mapping, Optional + + +class InstructTemplate(ABC): + """ + Interface for instruction templates. Each template should include the template + prompt with placeholders for the data inputs. + """ + + template = "" + + @classmethod + @abstractmethod + def format( + cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None + ) -> str: + """ + Format the prompt template with the given arguments. + + Args: + sample (Mapping[str, Any]): a single data sample with various fields + column_map (Optional[Dict[str, str]]): a mapping from the expected + placeholder names in the template to the column names in the sample. + If None, assume these are identical. Note: if the sample output is not named + as "output" in the dataset, you always need to map it to "output" in column_map. + + Returns: + The formatted prompt + """ + pass + + +class AlpacaInstructTemplate(InstructTemplate): + """ + Prompt template for Alpaca-style datasets. Template prompt changes slightly depending + on if there's an instruction + input or just an instruction. + + .. code-block:: text + + Below is an instruction that describes a task, paired with an input that provides further context. + Write a response that appropriately completes the request. + + ### Instruction: + + + ### Input: + + + ### Response: + + + Or without 'input' + + .. code-block:: text + + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + + + ### Response: + + + """ + + template = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:\n" + ), + } + + @classmethod + def format( + cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None + ) -> str: + """ + Generate prompt from instruction and input. + + Args: + sample (Mapping[str, Any]): a single data sample with instruction + column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names + in the template to the column names in the sample. If None, assume these are identical. + + Examples: + >>> # Simple instruction + >>> AlpacaInstructTemplate.format(sample={"instruction": "Write a poem"}) + Below is an instruction that describes a task, paired with an input that provides further context. + Write a response that appropriately completes the request.\\n\\n### Instruction:\\nWrite a poem\\n\\n### Response:\\n + + >>> # Instruction with input + >>> AlpacaInstructTemplate.format(sample={"instruction": "Write a poem", "input": "The poem should be 5 lines long"}) + Below is an instruction that describes a task, paired with an input that provides further context. + Write a response that appropriately completes the request.\\n\\n### Instruction:\\nWrite a poem\\n\\n### Input:\\n + The poem should be 5 lines long\\n\\n### Response:\\n + + >>> # Instruction with column map where the 'instruction' key is actually named 'prompt' in the given sample + >>> AlpacaInstructTemplate.format(sample={"prompt": "Write me a poem"}, column_map={"instruction": "prompt"}) + Below is an instruction that describes a task, paired with an input that provides further context. + Write a response that appropriately completes the request.\\n\\n### Instruction:\\nWrite a poem\\n\\n### Response:\\n + + Returns: + The formatted prompt + """ + column_map = column_map or {} + key_input = column_map.get("input", "input") + key_instruction = column_map.get("instruction", "instruction") + + if key_input in sample and sample[key_input]: + prompt = cls.template["prompt_input"].format( + instruction=sample[key_instruction], input=sample[key_input] + ) + else: + prompt = cls.template["prompt_no_input"].format( + instruction=sample[key_instruction] + ) + return prompt + + +class StackExchangedPairedTemplate(InstructTemplate): + """ + Prompt template for preference datasets similar to StackExchangedPaired. + + .. code-block:: text + + Question: + + Answer: + """ + + template = "Question: {question}\n\nAnswer: " + + @classmethod + def format( + cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None + ) -> str: + """ + Generate prompt from instruction and input. + + Args: + sample (Mapping[str, Any]): a single data sample with instruction + column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names + in the template to the column names in the sample. If None, assume these are identical. + + Examples: + >>> # Simple question + >>> StackExchangedPairedTemplate.format(sample={"question": "What is the capital of France?"}) + Question: What is the capital of France?\\n\\nAnswer: + + >>> # Question with column map where the 'question' key is actually named 'prompt' in the given sample + >>> StackExchangedPairedTemplate.format( + ... sample={"prompt": "What is the capital of France?"}, + ... column_map={"question": "prompt"} + ... ) + Question: What is the capital of France?\\n\\nAnswer: + + Returns: + The formatted prompt + """ + column_map = column_map or {} + key_prompt = column_map.get("prompt", "prompt") + prompt = cls.template.format(question=sample[key_prompt]) + + return prompt diff --git a/training/torchtune/data/_messages.py b/training/torchtune/data/_messages.py new file mode 100644 index 0000000..5a34fc3 --- /dev/null +++ b/training/torchtune/data/_messages.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Literal, Mapping, Optional, Union + +from torchtune.modules.transforms import Transform + +Role = Literal[ + "system", # Origin is system prompt + "user", # Origin is user + "assistant", # Origin is the model output + "ipython", # Origin is return from a tool call +] + + +class Message: + """ + This class represents individual messages in a fine-tuning dataset. It supports + text-only content, text with interleaved images, and tool calls. The :class:`~torchtune.modules.tokenizers.ModelTokenizer` + will tokenize the content of the message using ``tokenize_messages`` and attach + the appropriate special tokens based on the flags set in this class. + + Args: + role (Role): role of the message writer. Can be "system" for system prompts, + "user" for human prompts, "assistant" for model responses, or "ipython" + for tool call returns. + content (Union[str, List[Dict[str, str]]]): content of the message. If it is text only content, + you can pass in a string. If it is multimodal content, pass in a list of dictionaries formatted + as follows:: + + [ + {"type": "image"} + {"type": "text", "content": "hello"}, + {"type": "image"} + {"type": "text", "content": "world"}, + ] + + masked (bool): whether the message is masked in the sample. If True, do not use + in loss calculation. Default: False + ipython (bool): whether the message is a tool call. Default: False + eot (bool): whether the message corresponds to the end of a turn. Should be true + except in the case of multiple consecutive assistant messages (i.e., tool calls + by assistant). Default: True + """ + + def __init__( + self, + role: Role, + content: Union[str, List[Dict[str, str]]], + masked: bool = False, + ipython: bool = False, + eot: bool = True, + ): + self.role = role + self.content = ( + [{"type": "text", "content": content}] + if isinstance(content, str) + else content + ) + self.masked = masked + self.ipython = ipython + self.eot = eot + + self._validate_message() + + @classmethod + def from_dict(cls, d: dict) -> "Message": + """ + Construct a Message from a dictionary. + + Args: + d (dict): dictionary containing the fields of the Message. + + Returns: + Message: constructed Message. + """ + return cls( + role=d["role"], + content=d["content"], + masked=d.get("masked", False), + ipython=d.get("ipython", False), + eot=d.get("eot", True), + ) + + @property + def contains_media(self) -> bool: + """ + Returns True if message contains non-text content. + """ + return any(content["type"] != "text" for content in self.content) + + @property + def text_content(self) -> str: + """ + Returns text-only content of the message. + """ + return "".join( + content["content"] for content in self.content if content["type"] == "text" + ) + + def _validate_message(self) -> None: + if self.ipython and self.contains_media: + raise ValueError( + f"Media tokens in tool calls are not supported. Both are set in message: {self.text_content}" + ) + if self.ipython and self.role != "assistant": + raise ValueError( + f"Only assistant messages can be tool calls. Found role {self.role} in message: {self.text_content}" + ) + + +class InputOutputToMessages(Transform): + """ + Message transform class that converts a sample with "input" and "output" fields, + (or equivalent fields specified in column_map) to user and assistant messages, + respectively. This is useful for datasets that have two columns, one containing + the user prompt and the other containing the model response. + + Args: + train_on_input (bool): Whether the model is trained on the user prompt or not. + Default is False. + column_map (Optional[Dict[str, str]]): a mapping to change the expected "input" + and "output" column names to the actual column names in the dataset. Default is None, + keeping the default "input" and "output" column names. + """ + + def __init__( + self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self.column_map = column_map + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + column_map = self.column_map or {} + key_input = column_map.get("input", "input") + key_output = column_map.get("output", "output") + messages = [ + Message( + role="user", + content=sample[key_input], + masked=not self.train_on_input, + eot=False, + ), + Message( + role="assistant", + content=sample[key_output], + masked=False, + eot=True, + ), + ] + return {"messages": messages} + + +class ShareGPTToMessages(Transform): + """ + Convert a chat sample adhering to the ShareGPT json structure to torchtune's :class:`~torchtune.data.Message` + structure. + + ShareGPT follows:: + + { + "conversations": [ + { + "from": , + "value": , + }, + ... + ] + } + + :class:`~torchtune.data.Message` follows:: + + [ + { + "role": , + "content": , + }, + ... + ] + + Args: + train_on_input (bool): whether the prompt should remain unmasked. Default: False + column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("conversations") + to the new column names in the dataset. If None, assume these are identical. + Default is None. + """ + + def __init__( + self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self.column_map = column_map + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Return a list of Message objects from the provided sample dict. + + Args: + sample (Mapping[str, Any]): a single data sample with "messages" field pointing + to a list of dict messages. + + Returns: + List[Message]: A list of messages with "role" and "content" fields. + """ + role_map = {"system": "system", "human": "user", "gpt": "assistant"} + + messages = [] + for message in sample["conversations"]: + role = role_map[message["from"]] + content = message["value"] + masked = (role != "assistant") and (not self.train_on_input) + messages.append(Message(role=role, content=content, masked=masked)) + + return {"messages": messages} + + +class JSONToMessages(Transform): + """ + Convert a chat sample with identical json structure to torchtune's :class:`~torchtune.data.Message` + structure. This transform simply creates Message dataclasses from the provided jsons. + + For example:: + + { + "messages": [ + { + "role": , + "content": , + }, + ... + ] + } + + :class:`~torchtune.data.Message` follows:: + + [ + { + "role": , + "content": , + }, + ... + ] + + Args: + train_on_input (bool): whether the prompt should remain unmasked. Default: False + column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("messages") + to the new column names in the dataset. If None, assume these are identical. + Default is None. + """ + + def __init__( + self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self.column_map = column_map + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Return a list of Message objects from the provided sample dict. + + Args: + sample (Mapping[str, Any]): a single data sample with "messages" field pointing + to a list of dict messages. + + Returns: + List[Message]: A list of messages with "role" and "content" fields. + """ + updated_messages = [] + for message in sample["messages"]: + message["masked"] = (message["role"] != "assistant") and ( + not self.train_on_input + ) + updated_messages.append(Message.from_dict(message)) + + return {"messages": updated_messages} diff --git a/training/torchtune/data/_prompt_templates.py b/training/torchtune/data/_prompt_templates.py new file mode 100644 index 0000000..20a2211 --- /dev/null +++ b/training/torchtune/data/_prompt_templates.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from functools import partial +from typing import Dict, List, Protocol, Tuple + +from torchtune.data import Message, Role + + +class PromptTemplateInterface(Protocol): + """ + Interface for prompt templates. Each prompt template can include structured + text for system, user, and assistant roles that are prepended or appended to + the message content. + """ + + # Template should map role to a tuple containing the tag to prepend to the text + # and tag to append to the text. Leave as empty strings to not prepend or append + template: Dict[Role, Tuple[str, str]] + + def __call__( + self, + messages: List[Message], + ) -> List[Message]: + """ + Format each role's message(s) according to the prompt template + + Args: + messages (List[Message]): a single conversation, structured as a list + of :class:`~torchtune.data.Message` objects + + Returns: + The formatted list of messages + """ + pass + + +class PromptTemplate(PromptTemplateInterface): + """ + Quickly define a custom prompt template by passing in a dictionary mapping role to + the prepend and append tags. For example, to achieve the following prompt + template:: + + System: {content}\\n + User: {content}\\n + Assistant: {content}\\n + Tool: {content}\\n + + You need to pass in a tuple for each role, where ``PREPEND_TAG`` is the string + added before the text content and ``APPEND_TAG`` is the string added after:: + + template = {role: (PREPEND_TAG, APPEND_TAG)} + + Thus, the template would be defined as follows:: + + template = { + "system": ("System: ", "\\n"), + "user": ("User: ", "\\n"), + "assistant": ("Assistant: ", "\\n"), + "ipython": ("Tool: ", "\\n"), + } + + Once instantiated, you must call the prompt template on a list of messages. It + will return the same list of messages updated with the template. + + Note: + Any tags prepended/appended to the assistant message will be included + in the loss calculation. All other prepend/append tags for other roles + (system, user, ipython) are, in most cases, not included in loss. Consider using + the append tags for user messages for tags that need to come before the + assistant message but should not be included in loss. For more custom masking + and prompt templating, you can create your own class based off the + :class:`~torchtune.data.PromptTemplate` interface. + + Args: + template (Dict[Role, Tuple[str, str]]): a dictionary mapping role to the + prepend and append tags + """ + + def __init__( + self, + template: Dict[Role, Tuple[str, str]], + ): + self.template = template + + def __call__(self, messages: List[Message]) -> List[Message]: + """ + Format each role's message(s) according to the prompt template by prepending + and appending the defined tags. + + Args: + messages (List[Message]): list of messages to apply the template to + + Returns: + List[Message]: The formatted list of messages + """ + formatted_dialogue = [] + for message in messages: + if message.role in self.template: + prepend_tag = self.template[message.role][0] + append_tag = self.template[message.role][1] + content = ( + [{"type": "text", "content": prepend_tag}] + + message.content + + [{"type": "text", "content": append_tag}] + ) + else: + content = message.content + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue + + +class ChatMLTemplate(PromptTemplateInterface): + """ + OpenAI's `Chat Markup Language + `_ + used by their chat models. + + It is the default chat template used by Hugging Face models. + + .. code-block:: text + + <|im_start|>system + Provide some context and/or instructions to the model.<|im_end|> + <|im_start|>user + The user’s message goes here<|im_end|> + <|im_start|>assistant + The assistant’s response goes here<|im_end|> + + """ + + template = { + "system": ("<|im_start|>system\n", "<|im_end|>\n"), + "user": ("<|im_start|>user\n", "<|im_end|>\n"), + "assistant": ("<|im_start|>assistant\n", "<|im_end|>"), + "ipython": ("", ""), + } + + def __call__( + self, + messages: List[Message], + ) -> List[Message]: + """ + Format user, assistant, and system messages with appropriate tags. + + Args: + messages (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + """ + formatted_dialogue = [] + for message in messages: + content = ( + [{"type": "text", "content": self.template[message.role][0]}] + + message.content + + [{"type": "text", "content": self.template[message.role][1]}] + ) + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue + + +GrammarErrorCorrectionTemplate = partial( + PromptTemplate, + template={ + "user": ("Correct this to standard English: ", "\n---\nCorrected: "), + }, +) +GrammarErrorCorrectionTemplate.__doc__ = """ +A prompt template for grammar error correction tasks:: + + Correct this to standard English: {user_message} + --- + Corrected: {assistant_message} + +Please see :class:`~torchtune.data.PromptTemplate` for full API arguments. +""" +SummarizeTemplate = partial( + PromptTemplate, + template={ + "user": ("Summarize this dialogue:\n", "\n---\nSummary:\n"), + }, +) +SummarizeTemplate.__doc__ = """ +A prompt template for summarization tasks:: + + Summarize this dialogue: + {user_message} + --- + Summary: + {assistant_message} + +Please see :class:`~torchtune.data.PromptTemplate` for full API arguments. +""" diff --git a/training/torchtune/data/_utils.py b/training/torchtune/data/_utils.py new file mode 100644 index 0000000..36dd187 --- /dev/null +++ b/training/torchtune/data/_utils.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Optional + +from torchtune.data._messages import Message + + +def truncate( + tokens: List[Any], + max_seq_len: int, + eos_id: Optional[Any] = None, +) -> List[Any]: + """ + Truncate a list of tokens to a maximum length. If eos_id is provided, the last + token will be replaced with eos_id. + + Args: + tokens (List[Any]): list of tokens to truncate + max_seq_len (int): maximum length of the list + eos_id (Optional[Any]): token to replace the last token with. If None, the + last token will not be replaced. Default is None. + + Returns: + List[Any]: truncated list of tokens + """ + tokens_truncated = tokens[:max_seq_len] + if eos_id is not None and tokens_truncated[-1] != eos_id: + tokens_truncated[-1] = eos_id + return tokens_truncated + + +def validate_messages( + messages: List[Message], +) -> None: + """ + Given a list of messages, ensure that messages form a valid + back-and-forth conversation. An error will be raised if: + + - There is a system message that's not the first message + - There are two consecutive user messages + - An assistant message comes before the first user message + - The message is empty + - Messages are shorter than length of 2 (min. one user-assistant turn) + + + Args: + messages (List[Message]): the messages to validate. + + Raises: + ValueError: If the messages are invalid. + """ + if len(messages) < 2: + raise ValueError( + f"Messages must be at least length 2, but got {len(messages)} messages" + ) + + last_turn = "assistant" + for i, message in enumerate(messages): + if message.role == "assistant" and last_turn != "user": + raise ValueError( + f"Assistant message before expected user message at index {i} in messages" + ) + if message.role == "user" and last_turn == "user": + raise ValueError( + f"Two consecutive user messages at index {i} and {i - 1} in messages" + ) + if message.role == "system" and i > 0: + raise ValueError( + f"System message at index {i} in messages, but system messages must come first" + ) + last_turn = message.role diff --git a/training/torchtune/datasets/__init__.py b/training/torchtune/datasets/__init__.py new file mode 100644 index 0000000..9774958 --- /dev/null +++ b/training/torchtune/datasets/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset +from torchtune.datasets._chat import chat_dataset, ChatDataset +from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset +from torchtune.datasets._concat import ConcatDataset +from torchtune.datasets._grammar import grammar_dataset +from torchtune.datasets._instruct import instruct_dataset, InstructDataset +from torchtune.datasets._packed import PackedDataset +from torchtune.datasets._preference import PreferenceDataset +from torchtune.datasets._samsum import samsum_dataset +from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._slimorca import slimorca_dataset +from torchtune.datasets._stack_exchanged_paired import stack_exchanged_paired_dataset +from torchtune.datasets._text_completion import ( + text_completion_dataset, + TextCompletionDataset, +) +from torchtune.datasets._wikitext import wikitext_dataset + +__all__ = [ + "alpaca_dataset", + "alpaca_cleaned_dataset", + "grammar_dataset", + "samsum_dataset", + "stack_exchanged_paired_dataset", + "InstructDataset", + "slimorca_dataset", + "ChatDataset", + "instruct_dataset", + "chat_dataset", + "text_completion_dataset", + "TextCompletionDataset", + "cnn_dailymail_articles_dataset", + "PackedDataset", + "ConcatDataset", + "wikitext_dataset", + "PreferenceDataset", + "SFTDataset", +] diff --git a/training/torchtune/datasets/_alpaca.py b/training/torchtune/datasets/_alpaca.py new file mode 100644 index 0000000..a705f59 --- /dev/null +++ b/training/torchtune/datasets/_alpaca.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +from torchtune.datasets._instruct import instruct_dataset, InstructDataset + +from torchtune.modules.tokenizers import ModelTokenizer + + +def alpaca_dataset( + tokenizer: ModelTokenizer, + *, + source: str = "tatsu-lab/alpaca", + train_on_input: bool = True, + max_seq_len: int = 512, + packed: bool = False, + split: str = "train", +) -> InstructDataset: + """ + Support for family of Alpaca-style datasets from Hugging Face Datasets using + the `data input format `_ + and `prompt template `_ + from the original alpaca codebase, where ``instruction``, ``input``, and ``output`` + are fields from the dataset. + + Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is + set to ``True`` by `default `_ + - If ``train_on_input`` is True, the prompt is used during training and + contributes to the loss. + - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is True. + max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. + Default is 512, but we recommend setting this to the highest you can fit in memory and + is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + Returns: + InstructDataset: dataset configured with source data and template + + + Example: + >>> alpaca_ds = alpaca_dataset(tokenizer=tokenizer) + >>> for batch in Dataloader(alpaca_ds, batch_size=8): + >>> print(f"Batch size: {len(batch)}") + >>> Batch size: 8 + """ + + return instruct_dataset( + tokenizer=tokenizer, + source=source, + template="torchtune.data.AlpacaInstructTemplate", + train_on_input=train_on_input, + max_seq_len=max_seq_len, + packed=packed, + split=split, + ) + + +alpaca_cleaned_dataset = partial(alpaca_dataset, source="yahma/alpaca-cleaned") +alpaca_cleaned_dataset.__doc__ = """ +Builder for a variant of Alpaca-style datasets with the cleaned version of the +original Alpaca dataset, `yahma/alpaca-cleaned `_. +See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details. +""" diff --git a/training/torchtune/datasets/_chat.py b/training/torchtune/datasets/_chat.py new file mode 100644 index 0000000..323f1df --- /dev/null +++ b/training/torchtune/datasets/_chat.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, List, Mapping, Optional + +import numpy as np + +from datasets import load_dataset +from torch.utils.data import Dataset +from torchtune.config._utils import _get_component_from_path +from torchtune.data import ( + ChatFormat, + CROSS_ENTROPY_IGNORE_IDX, + get_openai_messages, + get_sharegpt_messages, + Message, + validate_messages, +) +from torchtune.datasets._packed import PackedDataset +from torchtune.modules.tokenizers import ModelTokenizer + + +class ChatDataset(Dataset): + """ + Class that supports any custom dataset with multiturn conversations. + + The general flow from loading a sample to tokenized prompt is: + load sample -> apply transform -> foreach turn{format into template -> tokenize} + + Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat format + and roles:: + + [ + Message( + role=, + content=, + ), + ... + ] + + This class supports multi-turn conversations. If a tokenizer sample with multiple + turns does not fit within ``max_seq_len`` then it is truncated. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample + and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys + chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual + messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not + as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed, + unless you want to structure messages in a particular way for inference. + max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, + such as ``data_files`` or ``split``. + + Raises: + ValueError: if ``chat_format`` is not an instance of :class:`torchtune.data.ChatFormat`. + """ + + def __init__( + self, + *, + tokenizer: ModelTokenizer, + source: str, + convert_to_messages: Callable[[Mapping[str, Any]], List[Message]], + chat_format: Optional[ChatFormat] = None, + max_seq_len: int, + train_on_input: bool = False, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + if chat_format is not None and not isinstance(chat_format(), ChatFormat): + raise ValueError( + f"chat_format must be a ChatFormat class, not {type(chat_format())}" + ) + + self._tokenizer = tokenizer + self._data = load_dataset(source, **load_dataset_kwargs) + self._convert_to_messages = convert_to_messages + self.chat_format = chat_format + self.max_seq_len = max_seq_len + self.train_on_input = train_on_input + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, List[int]]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: + messages = self._convert_to_messages(sample, self.train_on_input) + if self.chat_format is not None: + messages = self.chat_format.format(messages) + validate_messages(messages) + tokens, mask = self._tokenizer.tokenize_messages( + messages, + ) + # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens + labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens)) + assert len(tokens) == len(labels) + + return {"tokens": tokens, "labels": labels} + + +def chat_dataset( + tokenizer: ModelTokenizer, + *, + source: str, + conversation_style: str, + chat_format: Optional[str] = None, + max_seq_len: int, + train_on_input: bool = False, + packed: bool = False, + **load_dataset_kwargs: Dict[str, Any], +) -> ChatDataset: + """ + Build a configurable dataset with conversations. This method should be + used to configure a custom chat dataset from the yaml config instead of + using :class:`~torchtune.datasets.ChatDataset` directly, as it is made to be config friendly. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + conversation_style (str): string specifying expected style of conversations in the dataset + for automatic conversion to the :class:`~torchtune.data.Message` structure. Supported styles are: "sharegpt", "openai" + chat_format (Optional[str]): full import path of :class:`~torchtune.data.ChatFormat` class used to format the messages. + See the description in :class:`~torchtune.datasets.ChatDataset` for more details. For a list of all + possible chat formats, check out :ref:`chat_formats`. Default: None. + max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, + such as ``data_files`` or ``split``. + + Examples: + >>> from torchtune.datasets import chat_dataset + >>> dataset = chat_dataset( + ... tokenizer=tokenizer, + ... source="HuggingFaceH4/no_robots", + ... conversation_style="sharegpt", + ... chat_format="torchtune.data.ChatMLFormat", + ... max_seq_len=2096, + ... train_on_input=True + ... ) + + This can also be accomplished via the yaml config:: + + dataset: + _component_: torchtune.datasets.chat_dataset + source: HuggingFaceH4/no_robots + conversation_style: sharegpt + chat_format: torchtune.data.ChatMLFormat + max_seq_len: 2096 + train_on_input: True + + Returns: + ChatDataset or PackedDataset: the configured :class:`~torchtune.datasets.ChatDataset` + or :class:`~torchtune.datasets.PackedDataset` if ``packed=True`` + + Raises: + ValueError: if the conversation format is not supported + """ + if conversation_style == "sharegpt": + convert_to_messages = get_sharegpt_messages + elif conversation_style == "openai": + convert_to_messages = get_openai_messages + else: + raise ValueError(f"Unsupported conversation style: {conversation_style}") + + ds = ChatDataset( + tokenizer=tokenizer, + source=source, + convert_to_messages=convert_to_messages, + chat_format=_get_component_from_path(chat_format) + if chat_format is not None + else None, + max_seq_len=max_seq_len, + train_on_input=train_on_input, + **load_dataset_kwargs, + ) + return ( + PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id) + if packed + else ds + ) diff --git a/training/torchtune/datasets/_cnn_dailymail.py b/training/torchtune/datasets/_cnn_dailymail.py new file mode 100644 index 0000000..75ac146 --- /dev/null +++ b/training/torchtune/datasets/_cnn_dailymail.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional + +from torchtune.datasets._text_completion import TextCompletionDataset + +from torchtune.modules.tokenizers import ModelTokenizer + + +def cnn_dailymail_articles_dataset( + tokenizer: ModelTokenizer, + source: str = "ccdv/cnn_dailymail", + max_seq_len: Optional[int] = None, + split: str = "train", + **load_dataset_kwargs: Dict[str, Any], +) -> TextCompletionDataset: + """ + Support for family of datasets similar to `CNN / DailyMail `_, + a corpus of news articles. This builder only extracts the articles and not the highlights for + general text completion tasks. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. + + Returns: + TextCompletionDataset: the configured TextCompletionDataset + """ + + return TextCompletionDataset( + tokenizer=tokenizer, + source=source, + column="article", + max_seq_len=max_seq_len, + split=split, + # This is used to specify the version of the dataset, a required argument + # by the cnn_dailymail dataset builder: + # https://huggingface.co/datasets/ccdv/cnn_dailymail/blob/main/cnn_dailymail.py#L80 + name="3.0.0", + **load_dataset_kwargs, + ) diff --git a/training/torchtune/datasets/_concat.py b/training/torchtune/datasets/_concat.py new file mode 100644 index 0000000..2a76602 --- /dev/null +++ b/training/torchtune/datasets/_concat.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +from torch.utils.data import Dataset + +from torchtune import utils + +log = utils.get_logger("DEBUG") + + +class ConcatDataset(Dataset): + """ + A dataset class for concatenating multiple sub-datasets into a single dataset. This class enables the + unified handling of different datasets as if they were a single dataset, simplifying tasks such as + training models on multiple sources of data simultaneously. + + The class internally manages the aggregation of different datasets and allows transparent indexing across them. + However, it requires all constituent datasets to be fully loaded into memory, which might not be optimal for + very large datasets. + + Upon initialization, this class computes the cumulative length of all datasets and maintains an internal mapping + of indices to the respective datasets. This approach allows the :class:`~torchtune.datasets.ConcatDataset` + to delegate data retrieval to the appropriate sub-dataset transparently when a particular index is accessed. + + Note: + Using this class with very large datasets can lead to high memory consumption, as it requires all datasets to + be loaded into memory. For large-scale scenarios, consider other strategies that might stream data on demand. + + Args: + datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class + derived from :class:`~torch.utils.data.Dataset`. + + Examples: + >>> dataset1 = MyCustomDataset(params1) + >>> dataset2 = MyCustomDataset(params2) + >>> concat_dataset = ConcatDataset([dataset1, dataset2]) + >>> print(len(concat_dataset)) # Total length of both datasets + >>> data_point = concat_dataset[1500] # Accesses an element from the appropriate dataset + + This can also be accomplished by passing in a list of datasets to the YAML config:: + + dataset: + - _component_: torchtune.datasets.instruct_dataset + source: vicgalle/alpaca-gpt4 + template: torchtune.data.AlpacaInstructTemplate + split: train + train_on_input: True + - _component_: torchtune.datasets.instruct_dataset + source: samsum + template: torchtune.data.SummarizeTemplate + column_map: {"output": "summary"} + output: summary + split: train + train_on_input: False + + This class primarily focuses on providing a unified interface to access elements from multiple datasets, + enhancing the flexibility in handling diverse data sources for training machine learning models. + """ + + def __init__(self, datasets: List[Dataset]): + self._datasets: List[Dataset] = datasets + self._len: int = sum(len(dataset) for dataset in datasets) + self._indexes: List[Tuple[int, int, int]] = [] + + # Calculate distribution of indexes in all datasets + cumulative_index = 0 + for idx, dataset in enumerate(datasets): + next_cumulative_index = cumulative_index + len(dataset) + self._indexes.append((cumulative_index, next_cumulative_index, idx)) + cumulative_index = next_cumulative_index + + def __getitem__(self, index: int) -> Tuple[List[int], List[int]]: + for start, stop, dataset_index in self._indexes: + if start <= index < stop: + dataset = self._datasets[dataset_index] + return dataset[index - start] + + def __len__(self) -> int: + return self._len diff --git a/training/torchtune/datasets/_grammar.py b/training/torchtune/datasets/_grammar.py new file mode 100644 index 0000000..d42ff34 --- /dev/null +++ b/training/torchtune/datasets/_grammar.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Dict, Optional, Union + +from torchtune.data import InputOutputToMessages +from torchtune.data._prompt_templates import ( + GrammarErrorCorrectionTemplate, + PromptTemplate, +) +from torchtune.datasets._packed import PackedDataset +from torchtune.datasets._sft import SFTDataset +from torchtune.modules.transforms import Transform + + +def grammar_dataset( + model_transform: Transform, + *, + source: str = "liweili/c4_200m", + column_map: Optional[Dict[str, str]] = None, + prompt_template: Optional[PromptTemplate] = GrammarErrorCorrectionTemplate(), + train_on_input: bool = False, + packed: bool = False, + split: str = "train", +) -> Union[SFTDataset, PackedDataset]: + """ + Support for grammar correction datasets and their variants from Hugging Face Datasets. + Here is an `example `_ of a grammar correction dataset. + + The prompt template mirrors what is used in the `llama_recipes codebase + `_ + + where ``input`` and ``output`` are fields from the dataset. + + Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is + set to ``False`` by default + - If ``train_on_input`` is True, the prompt is used during training and + contributes to the loss. + - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) + + Args: + model_transform (Transform): model specific transform to convert a list of messages + output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See `Hugging Face's + `_ + ``load_dataset`` for more details. Default is ``liweili/c4_200m``. + column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template + to the new column names in the dataset. If None, assume these are identical. + prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default + is :class:`~torchtune.data.GrammarErrorCorrectionTemplate`. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + + Returns: + Union[SFTDataset, PackedDataset]: dataset configured with source data and template + + + Example: + >>> grammar_ds = grammar_dataset(tokenizer=tokenizer) + >>> for batch in Dataloader(grammar_ds, batch_size=8): + >>> print(f"Batch size: {len(batch)}") + >>> Batch size: 8 + """ + + message_transform = InputOutputToMessages( + train_on_input=train_on_input, column_map=column_map + ) + ds = SFTDataset( + source=source, + message_transform=message_transform, + model_transform=model_transform, + prompt_template=prompt_template, + split=split, + ) + return PackedDataset(ds) if packed else ds diff --git a/training/torchtune/datasets/_instruct.py b/training/torchtune/datasets/_instruct.py new file mode 100644 index 0000000..cb00066 --- /dev/null +++ b/training/torchtune/datasets/_instruct.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, List, Mapping, Optional + +import numpy as np +from datasets import load_dataset +from torch.utils.data import Dataset +from torchtune.config._utils import _get_component_from_path +from torchtune.data import ( + CROSS_ENTROPY_IGNORE_IDX, + InstructTemplate, + Message, + validate_messages, +) +from torchtune.datasets._packed import PackedDataset +from torchtune.modules.tokenizers import ModelTokenizer + + +class InstructDataset(Dataset): + """ + Class that supports any custom dataset with instruction-based prompts and a + configurable template. + + The general flow from loading a sample to tokenized prompt is: + load sample -> apply transform -> format into template -> tokenize + + If the column/key names differ from the expected names in the :class:`~torchtune.data.InstructTemplate`, + then the ``column_map`` argument can be used to provide this mapping. + + Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is + set to ``False`` by default. + - If ``train_on_input`` is True, the prompt is used during training and + contributes to the loss. + - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + template (InstructTemplate): template used to format the prompt. If the placeholder variable + names in the template do not match the column/key names in the dataset, use ``column_map`` to map them. + transform (Optional[Callable]): transform to apply to the sample before formatting to the template. + Default is None. + column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template + to the column/key names in the sample. If None, assume these are identical. + The output column can be indicated using the ``output`` key mapping. + If no placeholder for the ``output`` column is provided in ``column_map`` it is assumed to be ``output``. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, + such as ``data_files`` or ``split``. + Raises: + ValueError: If ``template`` is not an instance of :class:`torchtune.data.InstructTemplate` + """ + + def __init__( + self, + tokenizer: ModelTokenizer, + source: str, + template: InstructTemplate, + transform: Optional[Callable] = None, + column_map: Optional[Dict[str, str]] = None, + train_on_input: bool = False, + max_seq_len: Optional[int] = None, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + if not isinstance(template(), InstructTemplate): + raise ValueError( + f"template must be an InstructTemplate class, not {type(template())}" + ) + + self._tokenizer = tokenizer + self._data = load_dataset(source, **load_dataset_kwargs) + self.template = template + self._transform = transform + self._column_map = column_map + self.train_on_input = train_on_input + self.max_seq_len = max_seq_len + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, List[int]]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: + transformed_sample = self._transform(sample) if self._transform else sample + + prompt = self.template.format(transformed_sample, self._column_map) + key_output = ( + self._column_map["output"] + if self._column_map and "output" in self._column_map + else "output" + ) + messages = [ + Message(role="user", content=prompt, masked=(not self.train_on_input)), + Message(role="assistant", content=transformed_sample[key_output]), + ] + + validate_messages(messages) + + tokens, mask = self._tokenizer.tokenize_messages( + messages, + ) + + # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens + labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens)) + assert len(tokens) == len(labels) + + return {"tokens": tokens, "labels": labels} + + +def instruct_dataset( + tokenizer: ModelTokenizer, + *, + source: str, + template: str, + column_map: Optional[Dict[str, str]] = None, + train_on_input: bool = False, + max_seq_len: Optional[int] = None, + packed: bool = False, + **load_dataset_kwargs: Dict[str, Any], +) -> InstructDataset: + """ + Build a configurable dataset with instruction prompts. This method should be + used to configure a custom instruct dataset from the yaml config instead of + using :class:`~torchtune.datasets.InstructDataset` directly, as it is made to be config friendly. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + template (str): full import path of class used to format the prompt. If the placeholder variable + names in the template do not match the column/key names in the dataset, use ``column_map`` to map them. + column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template + to the column/key names in the sample. If None, assume these are identical. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, + such as ``data_files`` or ``split``. + + Examples: + >>> from torchtune.datasets import instruct_dataset + >>> dataset = instruct_dataset( + ... tokenizer=tokenizer, + ... source="yahma/alpaca_cleaned", + ... template="torchtune.data.AlpacaInstructTemplate", + ... max_seq_len=2096, + ... train_on_input=True, + ... packed=True, + ... ) + + This can also be accomplished via the yaml config:: + + dataset: + _component_: torchtune.datasets.instruct_dataset + source: yahma/alpaca_cleaned + template: torchtune.data.AlpacaInstructTemplate + max_seq_len: 2096 + train_on_input: True + packed: True + + Returns: + InstructDataset or PackedDataset: the configured :class:`~torchtune.datasets.InstructDataset` + or :class:`~torchtune.datasets.PackedDataset` if ``packed=True`` + """ + ds = InstructDataset( + tokenizer=tokenizer, + source=source, + template=_get_component_from_path(template), + column_map=column_map, + train_on_input=train_on_input, + max_seq_len=max_seq_len, + **load_dataset_kwargs, + ) + return ( + PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id) + if packed + else ds + ) diff --git a/training/torchtune/datasets/_packed.py b/training/torchtune/datasets/_packed.py new file mode 100644 index 0000000..fb41b98 --- /dev/null +++ b/training/torchtune/datasets/_packed.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Union + +import torch +from torch.nn import functional as F + +from torch.utils.data import Dataset +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.utils import get_world_size_and_rank +from tqdm import tqdm + +PACK_TYPE = Dict[str, Union[torch.Tensor, List[int]]] + + +class PackedDataset(Dataset): + """ + Performs greedy sample packing on a provided dataset. This is done as a single + preprocessing step before training begins. Shuffling is done outside of this + class on packed samples with a ``Sampler`` as part of the dataloader. Currently, + this only supports in-memory map-style datasets. + + The class loads, tokenizes, and packs examples on initialization - no tokenization is done during training. + + The general flow on initialization is: load tokenized sample -> add to buffer -> + when buffer is long enough, add to ``self.packs``. + + During training, returns self.packs[idx] as input, label, attention mask, and + position ids. The attention mask is a lower triangular block mask to prevent + samples from cross-attending within a pack. The position ids indicate the position + of each token relative to its sample within a pack. These are all padded to max + sequence length, so a batch-wise collator is not needed. + + A packed sample is made up of individual smaller sequence length samples jammed together + within ``max_seq_len``. For example, if max_seq_len is 6 and there are varied + length samples:: + + tokens = [ + [S1, S1, S1, S2, S2, pad], + [S3, S3, S4, S4, pad, pad], + ..., + ] + + To prevent cross-contamination, the following mask would be returned for the + first pack in the example:: + + mask = [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ] + + The position ids would be:: + + input_pos = [ + [0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 2, 3], + ..., + ] + + The identity matrix is used in the mask for pad tokens instead of a causal mask. + For position ids for pad tokens, we simply continue to increment from the previous + sample normally. + + Args: + ds (Dataset): dataset to sample pack. This should return a dictionary with field + "tokens" and "labels" containing the tokenized and label samples. + max_seq_len (int): Maximum number of tokens to pack + padding_idx (int): padding index for the tokenizer. Default is 0. + max_packs (Optional[int]): Maximum number of packs. Default is None, which will create as many + packs as possible. + split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``, + split the sample into the next pack, or move it entirely to the beginning of the next pack. + For pre-training, typically this is set to True for general text completion. For + fine-tuning, typically this is set to False to avoid truncating sentences in instruct + tuning. Default is False. + """ + + def __init__( + self, + ds: Dataset, + *, + max_seq_len: int, + padding_idx: int = 0, + max_packs: Optional[int] = None, + split_across_pack: bool = False, + ) -> None: + self.ds = ds + self.max_seq_len = max_seq_len + self.padding_idx = padding_idx + self.max_packs = max_packs + self.split_across_pack = split_across_pack + # Where final samples will be held + self.packs: List[PACK_TYPE] = [] + self.previous_sample_boundary: int = 0 + self._pack() + + def _pack(self) -> None: + """Iterate through the dataset. Use a buffer to hold samples until max_seq_len, + then append the buffer to self.packs as a single "packed" sample. Continue + until max_packs or end of dataset.""" + # Buffer to hold samples until they are long enough to be added to self.packs + current_pack = { + "tokens": [], + "labels": [], + "input_pos": [], + "seq_lens": [], + } + + # Only show progress bar on rank 0 + _, rank = get_world_size_and_rank() + if rank == 0: + pbar = tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True) + + for sample in self.ds: + tokens, labels = sample["tokens"], sample["labels"] + + # If the dataset outputs samples that are larger than the specified + # max_seq_len and we're unable to split it, user needs to modify + # one of the two parameters + seq_len = len(tokens) + if seq_len > self.max_seq_len and not self.split_across_pack: + raise ValueError( + f"Dataset sample is too long ({seq_len} > {self.max_seq_len}). " + "Please set `split_across_pack=True` or increase `max_seq_len`." + ) + + # Update the current pack + current_pack["tokens"] += tokens + current_pack["labels"] += labels + current_pack["input_pos"] += list(range(seq_len)) + current_pack["seq_lens"] += [seq_len] + + # If the current pack is over the max_seq_len, add it to self.packs and + # retain any truncated or bumped samples for next pack + if len(current_pack["tokens"]) > self.max_seq_len: + current_pack = self._split_and_add_pack(current_pack) + + if rank == 0: + pbar.update() + + # Keep track of previous sample boundary + self.previous_sample_boundary = len(current_pack["tokens"]) + + # If max packs is set, stop packing when we reach that number + if self.max_packs is not None and len(self.packs) == self.max_packs: + break + + # Handle the last pack if there's leftover and we haven't filled up the max packs + if len(current_pack["tokens"]) > 0 and ( + self.max_packs is None or len(self.packs) < self.max_packs + ): + # No need to handle splitting at this point so we can just add the current pack + self._add_pack(current_pack) + + def _split_and_add_pack(self, current_pack: PACK_TYPE) -> PACK_TYPE: + """Splits the current pack at the boundary, processes it, adds it to ``self.packs`` and + returns the start of the next pack.""" + + if self.split_across_pack: + boundary = self.max_seq_len + # The last elem in ``seq_lens`` ensures that ``sum(seq_lens) == self.max_seq_len`` + seq_len_padding = [self.max_seq_len - sum(current_pack["seq_lens"][:-1])] + else: + boundary = self.previous_sample_boundary + # If we aren't splitting across packs, we leave out the last sample b/c + # it will go into the next pack + seq_len_padding = [] + + pack = { + "tokens": current_pack["tokens"][:boundary], + "labels": current_pack["labels"][:boundary], + "input_pos": current_pack["input_pos"][:boundary], + "seq_lens": current_pack["seq_lens"][:-1] + seq_len_padding, + } + + # Process and add the pack + self._add_pack(pack) + + # Return the length of the first sample in next pack if we are splitting across packs, + # otherwise return the length of the last sample in the current pack + next_seq_len = ( + len(current_pack["tokens"][boundary:]) + if self.split_across_pack + else current_pack["seq_lens"][-1] + ) + + return { + "tokens": current_pack["tokens"][boundary:], + "labels": current_pack["labels"][boundary:], + "input_pos": current_pack["input_pos"][boundary:], + "seq_lens": [next_seq_len], + } + + def _add_pack(self, pack: PACK_TYPE) -> None: + """Processes, pads and adds a pack to ``self.packs``.""" + pack = self._convert_to_tensors(pack) + pack = self._pad_pack(pack, padding_idx=self.padding_idx) + self.packs.append(pack) + + def _convert_to_tensors(self, pack: PACK_TYPE) -> PACK_TYPE: + """Converts a pack into tensors. Pack comes in as a dict of lists and is converted to tensors. + The only key that does not get converted is ``seq_lens``. + """ + return { + "tokens": torch.tensor(pack["tokens"]), + "labels": torch.tensor(pack["labels"]), + "input_pos": torch.tensor(pack["input_pos"]), + "seq_lens": pack["seq_lens"], + } + + def _pad_pack(self, pack: PACK_TYPE, padding_idx: int) -> PACK_TYPE: + """Pads a pack to ``self.max_seq_len``.""" + # Pad tokens + padded_tokens = F.pad( + pack["tokens"], + (0, self.max_seq_len - len(pack["tokens"])), + value=padding_idx, + ) + + # Pad labels + padded_labels = F.pad( + pack["labels"], + (0, self.max_seq_len - len(pack["labels"])), + value=CROSS_ENTROPY_IGNORE_IDX, + ) + + # Pad input_pos continuing the sequence from last value + # in input_pos + # e.g. [0 1 2] -> [0 1 2 3 4 5] for self.max_seq_len = 6 + num_range = torch.arange( + pack["input_pos"][-1] + 1, + pack["input_pos"][-1] + self.max_seq_len - len(pack["input_pos"]) + 1, + ) + # Clamp to max_seq_len - 1 to avoid out of bounds error + clamped_num_range = torch.clamp(num_range, 0, self.max_seq_len - 1) + padded_input_pos = torch.cat([pack["input_pos"], clamped_num_range]) + + return { + "tokens": padded_tokens, + "labels": padded_labels, + "input_pos": padded_input_pos, + "seq_lens": pack["seq_lens"], # seq_len is untouched + } + + def __len__(self) -> int: + return len(self.packs) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Constructs the attention mask on-the-fly and returns whole sample.""" + current_pack = self.packs[idx] + + num_samples_in_pack = len(current_pack["seq_lens"]) + total_seq_len = 0 + + block_attn_masks = [] + + for i, seq_len in enumerate(current_pack["seq_lens"]): + total_seq_len += seq_len + + # Append lower triangular matrix for causal mask + block_attn_masks.append( + torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)) + ) + + # If we're at the last sample and the total seq len is less than the max seq len, + # we need to pad with identity matrix for the remainder + if i == num_samples_in_pack - 1 and total_seq_len < self.max_seq_len: + block_attn_masks.append( + torch.eye( + self.max_seq_len - total_seq_len, + self.max_seq_len - total_seq_len, + dtype=torch.bool, + ) + ) + + return { + "tokens": current_pack["tokens"], + "labels": current_pack["labels"], + "input_pos": current_pack["input_pos"], + # Assemble the mask into a block causal matrix + "mask": torch.block_diag(*block_attn_masks), + } diff --git a/training/torchtune/datasets/_preference.py b/training/torchtune/datasets/_preference.py new file mode 100644 index 0000000..476b043 --- /dev/null +++ b/training/torchtune/datasets/_preference.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, List, Mapping, Optional + +import numpy as np +from datasets import load_dataset +from torch.utils.data import Dataset + +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, InstructTemplate, Message + +from torchtune.modules.tokenizers import ModelTokenizer + + +class PreferenceDataset(Dataset): + """ + Class that supports any custom dataset with instruction-based prompts and a + configurable template. + + The general flow from loading a sample to tokenized prompt is: + load sample -> apply transform -> format into template -> tokenize + + If the column/key names differ from the expected names in the :class:`~torchtune.data.InstructTemplate`, + then the ``column_map`` argument can be used to provide this mapping. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + template (InstructTemplate): template used to format the prompt. If the placeholder variable + names in the template do not match the column/key names in the dataset, use ``column_map`` to map them. + transform (Optional[Callable]): transform to apply to the sample before formatting to the template. + Default is None. + column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template + to the column/key names in the sample. If None, assume these are identical. + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, + such as ``data_files`` or ``split``. + """ + + def __init__( + self, + tokenizer: ModelTokenizer, + source: str, + template: InstructTemplate, + transform: Optional[Callable] = None, + column_map: Optional[Dict[str, str]] = None, + max_seq_len: Optional[int] = None, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + self._tokenizer = tokenizer + self._data = load_dataset(source, **load_dataset_kwargs) + self.template = template + self._transform = transform + self._column_map = column_map + self.max_seq_len = max_seq_len + self._data = self._data.filter( + lambda x: len(x[column_map["prompt"]]) + len(x[column_map["chosen"]]) + <= max_seq_len + and len(x[column_map["prompt"]]) + len(x[column_map["rejected"]]) + <= max_seq_len + ) + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, List[int]]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: + transformed_sample = self._transform(sample) if self._transform else sample + prompt = self.template.format(transformed_sample, self._column_map) + + column_map = self._column_map or {} + key_chosen = column_map.get("chosen", "chosen") + key_rejected = column_map.get("rejected", "rejected") + + chosen_message = [ + Message(role="user", content=prompt, masked=True), + Message(role="assistant", content=transformed_sample[key_chosen]), + ] + + rejected_message = [ + Message(role="user", content=prompt, masked=True), + Message(role="assistant", content=transformed_sample[key_rejected]), + ] + + # TODO: Trunction differs from original DPO repo + # in DPO: first truncate prompts, then responses + chosen_input_ids, c_masks = self._tokenizer.tokenize_messages( + chosen_message, + ) + chosen_labels = list( + np.where(c_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids) + ) + + rejected_input_ids, r_masks = self._tokenizer.tokenize_messages( + rejected_message, + ) + rejected_labels = list( + np.where(r_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids) + ) + + assert len(chosen_input_ids) == len(chosen_labels) + assert len(rejected_input_ids) == len(rejected_labels) + + batch = dict( + chosen_input_ids=chosen_input_ids, + chosen_labels=chosen_labels, + rejected_input_ids=rejected_input_ids, + rejected_labels=rejected_labels, + ) + + return batch diff --git a/training/torchtune/datasets/_samsum.py b/training/torchtune/datasets/_samsum.py new file mode 100644 index 0000000..52608a7 --- /dev/null +++ b/training/torchtune/datasets/_samsum.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Dict, Optional, Union + +from torchtune.data import InputOutputToMessages +from torchtune.data._prompt_templates import PromptTemplate, SummarizeTemplate +from torchtune.datasets._packed import PackedDataset +from torchtune.datasets._sft import SFTDataset +from torchtune.modules.transforms import Transform + + +def samsum_dataset( + model_transform: Transform, + *, + source: str = "Samsung/samsum", + column_map: Optional[Dict[str, str]] = None, + prompt_template: Optional[PromptTemplate] = SummarizeTemplate(), + train_on_input: bool = False, + packed: bool = False, + split: str = "train", +) -> Union[SFTDataset, PackedDataset]: + """ + Support for summarization datasets and their variants from Hugging Face Datasets. + An example is the `SAMsum dataset `_. + + The prompt template mirrors what is used in the llama_recipes `codebase + `_ + + where ``dialogue`` and ``summary`` are fields from the dataset. + + Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is + set to ``False`` by default + - If ``train_on_input`` is True, the prompt is used during training and + contributes to the loss. + - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) + + Args: + model_transform (Transform): model specific transform to convert a list of messages + output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See `Hugging Face's + `_ + ``load_dataset`` for more details. Default is ``Samsung/samsum``. + column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template + to the new column names in the dataset. If None, assume these are identical. + prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default + is :class:`~torchtune.data.GrammarErrorCorrectionTemplate`. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + + Returns: + Union[SFTDataset, PackedDataset]: dataset configured with source data and template + + + Example: + >>> samsum_ds = samsum_dataset(tokenizer=tokenizer) + >>> for batch in Dataloader(samsum_ds, batch_size=8): + >>> print(f"Batch size: {len(batch)}") + >>> Batch size: 8 + """ + column_map = column_map or {"input": "dialogue", "output": "summary"} + message_transform = InputOutputToMessages( + train_on_input=train_on_input, column_map=column_map + ) + ds = SFTDataset( + source=source, + message_transform=message_transform, + model_transform=model_transform, + prompt_template=prompt_template, + split=split, + ) + return PackedDataset(ds) if packed else ds diff --git a/training/torchtune/datasets/_sft.py b/training/torchtune/datasets/_sft.py new file mode 100644 index 0000000..cf8157a --- /dev/null +++ b/training/torchtune/datasets/_sft.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, Mapping, Optional + +import numpy as np + +from datasets import load_dataset +from torch.utils.data import Dataset +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, PromptTemplate +from torchtune.modules.transforms import Transform + + +class SFTDataset(Dataset): + """ + Primary class for creating any dataset for supervised fine-tuning either from + Hugging Face Hub, local files, or remote files. This class supports instruct, + chat, tool, or multimodal data for fine-tuning. At a high level, this class + will load the data from source and apply the following pre-processing steps + when a sample is retrieved: + + 1. Dataset-specific transform. This is typically unique to each dataset and extracts + the necessary columns into torchtune's :class:`~torchtune.data.Message` format, + a standardized API for all model tokenizers. + 2. If specified, apply a prompt template for the task you are fine-tuning for. + 3. Model-specific transform or tokenization + + + All datasets are formatted into a list of :class:`~torchtune.data.Message` + because for fine-tuning, datasets can be considered as "conversations" with the model, + or AI assistant. Thus, we can standardize all text content as messages in a conversation assigned to + a role: + + - ``"system"`` messages contain the system prompt + - ``"user"`` messages contain the input prompt into the model + - ``"assistant"`` messages are the response of the model and what you actually want + to train for and compute loss directly against + - ``"ipython"`` messages are the return from a tool call + + Chat datasets are multiple rounds of user-assistant messages. Instruct datasets + are typically a single round involving a specific instruction and the model's response. + Tool datasets are a type of chat dataset that includes ipython messages. Multimodal + datasets are a type of chat dataset that incorporates media into the user messages. + + The :class:`~torchtune.data.Message` forms the core data unit that all tokenizer + APIs expect. The key component of this class that ensures any dataset is transformed + into this format is the ``message_transform``. This is a callable class that takes + in a sample dictionary - typically a single row from the source dataset - that + processes the sample in any configurable way to output a list of messages:: + + [ + Message( + role=, + content=, + ), + ... + ] + + For any custom dataset, use the ``message_transform`` to contain all pre-processing to + return the list of messages. + + Any model-specific pre-processing that needs to happen can be configured with the ``model_transform`` + parameter. This is another callable class that contains any custom logic tied to the + model you are fine-tuning and will carry over to inference. For example, text + image + multimodal datasets requires processing the images in a way specific to the vision + encoder being used by the model and is agnostic to the specific dataset. + + Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer` + can be treated as a ``model_transform`` since it uses the model-specific tokenizer to + transform the list of messages outputted from the ``message_transform`` into tokens + used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer` + into ``model_transform``. + + Args: + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See `Hugging Face's + `_ + ``load_dataset`` for more details. + message_transform (Transform): callable that keys into the desired fields in the sample + and converts text content to a list of :class:`~torchtune.data.Message`. It is expected that the final list + of messages are stored in the ``"messages"`` key. + model_transform (Transform): callable that applies model-specific pre-processing to the sample after the list of + messages is created from ``message_transform``. This includes tokenization and any modality-specific + transforms. It is expected to return at minimum ``"tokens"`` and ``"mask"`` keys. + prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used + to add structured text around the actual messages. The structured text is used in three scenarios: + + - Task-specific templates to gear models for a particular task that it will expect after training + - Model-specific templates that are required whenever the model is prompted, such as the [INST] + tags in Llama2 and in Mistral + - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate` + + The extra text will still get tokenized as normal text, not as special tokens. + filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See + the Hugging Face `docs `_ for more + details. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging + Face's `API ref `_ + for more details. + """ + + def __init__( + self, + *, + source: str, + message_transform: Transform, + model_transform: Transform, + prompt_template: Optional[PromptTemplate] = None, + filter_fn: Optional[Callable] = None, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + self._message_transform = message_transform + self._prompt_template = prompt_template + self._model_transform = model_transform + + self._data = load_dataset(source, **load_dataset_kwargs) + if filter_fn is not None: + self._data = self._data.filter(filter_fn) + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, Any]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: + transformed_sample = self._message_transform(sample) + if self._prompt_template is not None: + transformed_sample["messages"] = self._prompt_template( + transformed_sample["messages"] + ) + tokenized_dict = self._model_transform(transformed_sample) + + # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens + tokenized_dict["labels"] = list( + np.where( + tokenized_dict["mask"], + CROSS_ENTROPY_IGNORE_IDX, + tokenized_dict["tokens"], + ) + ) + assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + + return tokenized_dict diff --git a/training/torchtune/datasets/_slimorca.py b/training/torchtune/datasets/_slimorca.py new file mode 100644 index 0000000..cc08340 --- /dev/null +++ b/training/torchtune/datasets/_slimorca.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional, Union + +from torchtune.data import PromptTemplate, ShareGPTToMessages +from torchtune.datasets._packed import PackedDataset + +from torchtune.datasets._sft import SFTDataset +from torchtune.modules.transforms import Transform + + +def slimorca_dataset( + model_transform: Transform, + *, + source: str = "Open-Orca/SlimOrca-Dedup", + column_map: Optional[Dict[str, str]] = None, + prompt_template: Optional[PromptTemplate] = None, + train_on_input: bool = False, + packed: bool = False, + split: str = "train", +) -> Union[SFTDataset, PackedDataset]: + """ + Support for `SlimOrca-style `_ + family of conversational datasets. + + Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is + set to ``False`` by default + - If ``train_on_input`` is True, the prompt is used during training and + contributes to the loss. + - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) + + Args: + model_transform (Transform): model specific transform to convert a list of messages + output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See `Hugging Face's + `_ + ``load_dataset`` for more details. Default is ``Open-Orca/SlimOrca-Dedup``. + column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template + to the new column names in the dataset. If None, assume these are identical. + prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default + is None. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + + Returns: + Union[SFTDataset, PackedDataset]: dataset configured with SlimOrca source data + + Example: + >>> ds = slimorca_dataset(tokenizer=tokenizer, max_seq_len=10) + >>> for input, label in ds: + >>> print(input) + >>> print(label) + >>> + >>> Sample Output: + >>> [1, 351, 82, 391, 221, 220, 193, 12, 471, ..., 2] + >>> [-100, -100, -100, -100, -100, -100, -100, -100, 471, ..., 2] + """ + + message_transform = ShareGPTToMessages( + train_on_input=train_on_input, column_map=column_map + ) + ds = SFTDataset( + source=source, + message_transform=message_transform, + model_transform=model_transform, + prompt_template=prompt_template, + split=split, + ) + return PackedDataset(ds) if packed else ds diff --git a/training/torchtune/datasets/_stack_exchanged_paired.py b/training/torchtune/datasets/_stack_exchanged_paired.py new file mode 100644 index 0000000..a53e575 --- /dev/null +++ b/training/torchtune/datasets/_stack_exchanged_paired.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.data import StackExchangedPairedTemplate +from torchtune.datasets._preference import PreferenceDataset +from torchtune.modules.tokenizers import ModelTokenizer + + +def stack_exchanged_paired_dataset( + tokenizer: ModelTokenizer, + *, + source: str = "lvwerra/stack-exchange-paired", + max_seq_len: int = 1024, + split: str = "train", +) -> PreferenceDataset: + """ + Family of preference datasets similar to `StackExchangePaired data + `_. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`. + max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. + Default is 1024. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + + Returns: + PreferenceDataset: The preference dataset built from source paired data. + """ + return PreferenceDataset( + tokenizer=tokenizer, + source=source, + template=StackExchangedPairedTemplate(), + column_map={ + "prompt": "question", + "chosen": "response_j", + "rejected": "response_k", + }, + max_seq_len=max_seq_len, + split=split, + data_dir="data/rl", + ) diff --git a/training/torchtune/datasets/_text_completion.py b/training/torchtune/datasets/_text_completion.py new file mode 100644 index 0000000..b3c4cf4 --- /dev/null +++ b/training/torchtune/datasets/_text_completion.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Mapping, Optional, Union + +from datasets import load_dataset +from torch.utils.data import Dataset +from torchtune.data import truncate +from torchtune.datasets._packed import PackedDataset +from torchtune.modules.tokenizers import ModelTokenizer + + +class TextCompletionDataset(Dataset): + """ + Freeform dataset for any unstructured text corpus. Quickly load any dataset + from Hugging Face or local disk and tokenize it for your model. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + column (str): name of column in the sample that contains the text data. This is typically required + for Hugging Face datasets or tabular data. For local datasets with a single column + (e.g. unstructured txt files), use the default "text" which is used by Hugging Face datasets + when loaded into memory. Default is "text". + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + add_eos (bool): Whether to add an EOS token to the end of the sequence. Default is True. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, + such as ``data_files`` or ``split``. + """ + + def __init__( + self, + tokenizer: ModelTokenizer, + source: str, + column: str = "text", + max_seq_len: Optional[int] = None, + add_eos: bool = True, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + self._tokenizer = tokenizer + self._data = load_dataset(source, **load_dataset_kwargs) + self.max_seq_len = max_seq_len + self._column = column + self.add_eos = add_eos + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, List[int]]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: + prompt = sample[self._column] + tokens = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=self.add_eos) + + # Truncate if needed, but don't coerce EOS id + if self.max_seq_len is not None: + tokens = truncate(tokens, self.max_seq_len - 1) + + # No need to offset labels by 1 - happens in the recipe + labels = tokens.copy() + + return {"tokens": tokens, "labels": labels} + + +def text_completion_dataset( + tokenizer: ModelTokenizer, + source: str, + column: str = "text", + max_seq_len: Optional[int] = None, + add_eos: bool = True, + packed: bool = False, + split_across_pack: bool = True, + **load_dataset_kwargs: Dict[str, Any], +) -> Union[TextCompletionDataset, PackedDataset]: + """ + Build a configurable dataset from a freeform, unstructured text corpus similar + to datasets used in pre-training. This method should be + used to configure a custom text dataset from the yaml config instead of + using :class:`~torchtune.datasets.TextCompletionDataset` directly, as it is made to be config friendly. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + column (str): name of column in the sample that contains the text data. This is typically required + for Hugging Face datasets or tabular data. For local datasets with a single column + (e.g. unstructured txt files), use the default "text" which is used by Hugging Face datasets + when loaded into memory. Default is "text". + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + add_eos (bool): Whether to add an EOS token to the end of the sequence. Default is True. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``, + split the sample into the next pack, or move it entirely to the beginning of the next pack. + For pre-training, typically this is set to True for general text completion. For + fine-tuning, typically this is set to False to avoid truncating sentences in instruct + tuning. This argument is ignored if ``packed=False``. Default is True. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. + + Examples: + >>> from torchtune.datasets import text_completion_dataset + >>> dataset = text_completion_dataset( + ... tokenizer=tokenizer, + ... source="allenai/c4", + ... column="text", + ... max_seq_len=2096, + ... data_dir="realnewslike", + ... packed=False, + ... ) + + This can also be accomplished via the yaml config:: + + dataset: + _component_: torchtune.datasets.text_completion_dataset + source: allenai/c4 + column: text + max_seq_len: 2096 + data_dir: realnewslike + packed: False + + Returns: + Union[TextCompletionDataset, PackedDataset]: the configured :class:`~torchtune.datasets.TextCompletionDataset` + or :class:`~torchtune.datasets.PackedDataset` if ``packed=True`` + """ + ds = TextCompletionDataset( + tokenizer=tokenizer, + source=source, + column=column, + max_seq_len=max_seq_len, + add_eos=add_eos, + **load_dataset_kwargs, + ) + return ( + PackedDataset( + ds, + max_seq_len=max_seq_len, + padding_idx=tokenizer.pad_id, + split_across_pack=split_across_pack, + ) + if packed + else ds + ) diff --git a/training/torchtune/datasets/_wikitext.py b/training/torchtune/datasets/_wikitext.py new file mode 100644 index 0000000..aea8609 --- /dev/null +++ b/training/torchtune/datasets/_wikitext.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Union + +from torchtune.datasets._packed import PackedDataset + +from torchtune.datasets._text_completion import ( + text_completion_dataset, + TextCompletionDataset, +) + +from torchtune.modules.tokenizers import ModelTokenizer + + +def wikitext_dataset( + tokenizer: ModelTokenizer, + source: str = "EleutherAI/wikitext_document_level", + subset: str = "wikitext-103-v1", + max_seq_len: Optional[int] = None, + packed: bool = False, + split: str = "train", + **load_dataset_kwargs: Dict[str, Any], +) -> Union[TextCompletionDataset, PackedDataset]: + """ + Support for family of datasets similar to `wikitext + `_, + an unstructured text corpus consisting of fulls articles from Wikipedia. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. + subset (str): name of subset of data to use, see the `wikitext page + `_ + for available subsets. Default is ``"wikitext-103-v1"``. + max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. + Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory + and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. + + Returns: + Union[TextCompletionDataset, PackedDataset]: the configured :class:`~torchtune.datasets.TextCompletionDataset` + or :class:`~torchtune.datasets.PackedDataset` if ``packed=True`` + """ + + return text_completion_dataset( + tokenizer=tokenizer, + source=source, + column="page", + max_seq_len=max_seq_len, + name=subset, + split=split, + **load_dataset_kwargs, + ) diff --git a/training/torchtune/models/__init__.py b/training/torchtune/models/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/torchtune/models/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/torchtune/models/clip/__init__.py b/training/torchtune/models/clip/__init__.py new file mode 100644 index 0000000..ca7ecf2 --- /dev/null +++ b/training/torchtune/models/clip/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import clip_vision_encoder + +from ._position_embeddings import ( + TiledTokenPositionalEmbedding, + TilePositionalEmbedding, + TokenPositionalEmbedding, +) + +__all__ = [ + "clip_vision_encoder", + "TokenPositionalEmbedding", + "TiledTokenPositionalEmbedding", + "TilePositionalEmbedding", +] diff --git a/training/torchtune/models/clip/_component_builders.py b/training/torchtune/models/clip/_component_builders.py new file mode 100644 index 0000000..3e61110 --- /dev/null +++ b/training/torchtune/models/clip/_component_builders.py @@ -0,0 +1,101 @@ +from typing import List, Optional + +import torch +from torchtune.modules.vision_transformer import VisionTransformer, CLSProjection +from torchtune.models.clip._position_embeddings import TokenPositionalEmbedding, TiledTokenPositionalEmbedding, TilePositionalEmbedding + +import logging + +logger = logging.getLogger(__name__) + +def clip_vision_encoder( + tile_size: int, + patch_size: int, + embed_dim: int, + num_layers: int, + num_heads: int, + cls_output_dim: int = 512, + out_indices: Optional[List[int]] = None, + output_cls_projection: bool = False, + max_num_tiles: int = 4, + in_channels: int = 3, +) -> VisionTransformer: + """ + Builds the vision encoder associated with the clip model. This includes: + + - num_layers TransformerEncoderLayers + - positional embeddings + - CLS projection (optional) + + For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + + Args: + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + embed_dim (int): The dimensionality of each patch embedding (token). + num_layers (int): The number of transformer layers. + num_heads (int): The number of attention heads in each transformer layer. + cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module. + out_indices (Optional[List[int]]): The indices of hidden layers to return. + If provided, it will return the intermediate results of the transformer layers + before they go through a next layer. For example, ``out_indices=[0,3]`` will + return the tokens before they go through the first and fourth layers. + output_cls_projection (bool): If True, only the CLS token projection will be outputted, + instead of all tokens. Defaults to False. + max_num_tiles (int): The maximum number of tiles that can be processed. This is used to + determine the size of the positional embeddings. + in_channels (int): The number of image input channels. + + Returns: + A `VisionTransformer` object. + """ + + cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None + + # TODO (Felipe): Replace with torchtune native encoder module + mlp_ratio = 4.0 + transformer_layer = torch.nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=num_heads, + dim_feedforward=int(mlp_ratio * embed_dim), + dropout=0.0, + activation=torch.nn.SiLU(), + layer_norm_eps=1e-5, + batch_first=True, + norm_first=True, + bias=True) + + # position embeddings + if max_num_tiles == 1: + pre_tile_pos_embed = None + post_tile_pos_embed = None + token_pos_embedding = TokenPositionalEmbedding( + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size) + else: + pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + token_pos_embedding = TiledTokenPositionalEmbedding( + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size) + + return VisionTransformer( + num_layers=num_layers, + layer=transformer_layer, + token_pos_embedding=token_pos_embedding, + pre_tile_pos_embed=pre_tile_pos_embed, + post_tile_pos_embed=post_tile_pos_embed, + cls_projection=cls_projection, + out_indices=out_indices, + tile_size=tile_size, + patch_size=patch_size, + embed_dim=embed_dim, + in_channels=in_channels, + ) diff --git a/training/torchtune/models/clip/_model_builders.py b/training/torchtune/models/clip/_model_builders.py new file mode 100644 index 0000000..becef9e --- /dev/null +++ b/training/torchtune/models/clip/_model_builders.py @@ -0,0 +1,14 @@ +from torchtune.models.clip._transforms import CLIPImageTransform + +def _clip_vit_224_transform(): + image_transform = CLIPImageTransform( + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + tile_size=224, + possible_resolutions=None, + max_num_tiles=1, + resample="bilinear", + resize_to_max_canvas=True, + ) + + return image_transform diff --git a/training/torchtune/models/clip/_position_embeddings.py b/training/torchtune/models/clip/_position_embeddings.py new file mode 100644 index 0000000..05897aa --- /dev/null +++ b/training/torchtune/models/clip/_position_embeddings.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +import torch +from torch import nn + +# TODO (@Felipe): add load hooks + interpolation on positional encodings, +# so max_num_tiles can be variable and a trained model can be adapted to a +# new value. + + +class TokenPositionalEmbedding(nn.Module): + """ + Token positional embedding for images, different for every token in an image. + + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + + Args: + embed_dim (int): The dimensionality of each token embedding. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + """ + + def __init__(self, embed_dim: int, tile_size: int, patch_size: int) -> None: + super().__init__() + patch_grid_size = tile_size // patch_size + scale = embed_dim**-0.5 + self.positional_embedding = nn.Parameter( + scale + * torch.randn((patch_grid_size**2 + 1, embed_dim)) # +1 for CLS token + ) + + def forward(self, x: torch.Tensor, *args: Tuple[Any]) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor with shape (..., n_tokens, embed_dim) + *args (Tuple[Any]): Optional args. + + Returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + return x + self.positional_embedding + + +class TiledTokenPositionalEmbedding(nn.Module): + """ + + Token positional embedding for tiled images. There are two positional embeddings in this module: + + * local_token_positional_embedding: same for every tile, different for every token. Equivalent \ + to :class:`torchtune.models.clip._position_embeddings.TokenPositionalEmbedding`, but gated. + * global_token_positional_embedding: different for every tile, different for every token. + + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + embed_dim (int): The dimensionality of each token embedding. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + """ + + def __init__( + self, max_num_tiles: int, embed_dim: int, tile_size: int, patch_size: int + ) -> None: + super().__init__() + patch_grid_size = tile_size // patch_size + self.n_tokens_per_tile = patch_grid_size**2 + 1 # +1 for cls token + scale = embed_dim**-0.5 + + # different for every token, same for every tile + self.local_token_positional_embedding = nn.Parameter( + scale + * torch.randn((patch_grid_size**2 + 1, embed_dim)) # +1 for CLS token + ) + + # different for every token, different for every tile + self.global_token_positional_embedding = nn.Parameter( + scale + * torch.randn( + max_num_tiles, + max_num_tiles, + self.n_tokens_per_tile, + embed_dim, + ) + ) + + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2), + where aspect_ratio[k] represents the aspect ratio of the k^th image + of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1). + Returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + + # apply local position embedding (same for every tile) + x = x + (self.local_token_positional_embedding * (1 - self.gate.tanh())) + + # apply global positional embedding (different for every tile) + x = x.view(bsz_and_n_imgs, n_tiles, n_tokens, embed_dim) + for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. n_tiles_h, n_tiles_w. + pos_embed = self.global_token_positional_embedding[ + :n_tiles_h, :n_tiles_w, :, : + ] + + # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.reshape( + n_non_padded_tiles, self.n_tokens_per_tile, embed_dim + ) + pos_embed = pos_embed * self.gate.tanh() + x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed + + return x + + +class TilePositionalEmbedding(nn.Module): + """ + Positional embedding for tiles, different for every tile, same for every token within a tile. + + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + embed_dim (int): The dimensionality of each tile embedding. + """ + + def __init__( + self, + max_num_tiles: int, + embed_dim: int, + ): + super().__init__() + self.max_num_tiles = max_num_tiles + self.embed_dim = embed_dim + + scale = embed_dim**-0.5 + self.embedding = nn.Parameter( + scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim) + ) + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: + """ + args: + x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2), + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). + returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + + for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. n_tiles_h, n_tiles_w. + pos_embed = self.embedding[:n_tiles_h, :n_tiles_w, :, :] + + # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim) + x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() + + return x diff --git a/training/torchtune/models/clip/_transforms.py b/training/torchtune/models/clip/_transforms.py new file mode 100644 index 0000000..781e7ab --- /dev/null +++ b/training/torchtune/models/clip/_transforms.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any, List, Mapping, Optional, Tuple + +import torch +import torchvision +from PIL import Image + +from torchtune.modules.transforms import ( + find_supported_resolutions, + get_canvas_best_fit, + resize_with_pad, + tile_crop, +) + +from torchvision.transforms.v2 import functional as F + +logger = logging.getLogger(__name__) + + +class CLIPImageTransform: + """ + This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it + based on the image aspect ratio and the number of image tiles we allow. + + The algorithm will NOT distort the image to fit a certain aspect ratio, because + that leads to a significant degradation in image quality. + + The user can choose if they want to allow upscaling by using the flag ``resize_to_max_canvas``. + + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image tiles, with side 224px, then: + + If ``resize_to_max_canvas=False``, then: + best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling + image is NOT resized + image is padded (300, 800) -> 448,896 + Image is tiled 2x4, for a final output shape of (8, 3, 224, 224) + + If ``resize_to_max_canvas=True``, then: + best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles + image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize + image is padded (448, 1194) -> (448, 1344) + Image is tiled 2x5, for a final output shape of (10, 3, 224, 224) + + Args: + image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). + where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. + If None, this will be calculated using max_num_tiles and tile_size. Default None. + tile_size (int): Size of the tiles to divide the image into. Default 224. + max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given. + Maximum number of tiles to break an image into. + This will be used to generate possible_resolutions, + e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. + Default 4. + resample (str): Resampling method used when resizing images. Supports any enum of + ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". + Default 'bilinear'. + resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible + resolution from possible_resolutions. + If False, it will pick the resolution that minimizes downscaling, including no downscaling at all. + In this case, the image will only be upscaled if it's size < tile_size. Default False. + + Examples: + >>> image_transform = CLIPImageTransform( + ... image_mean=None, + ... image_std=None, + ... tile_size=224, + ... possible_resolutions=None, + ... max_num_tiles=4, + ... resample="bilinear", + ... resize_to_max_canvas=True, + ...) + >>> # create random image + >>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8) + >>> image = PIL.Image.fromarray(image) + >>> output = image_transform(image) + >>> output['image'].shape # [num_tiles, num_channels, tile_size, tile_size] + torch.Size([2, 3, 224, 224]) + >>> output['ar'] # image best fits the canvas 224x448 + torch.tensor([1,2]) + """ + + def __init__( + self, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + possible_resolutions: Optional[List[Tuple[int, int]]] = None, + tile_size: int = 224, + max_num_tiles: Optional[int] = 4, + resample: str = "bilinear", + resize_to_max_canvas: bool = False, + ) -> None: + + # get_canvas_best_fit + assert ( + possible_resolutions is not None or max_num_tiles is not None + ), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions=} and {max_num_tiles=}" + + # If possible_resolutions are not given, then calculate possible ones based on max_num_tiles + if not possible_resolutions and max_num_tiles: + possible_resolutions = find_supported_resolutions( + max_num_tiles=max_num_tiles, tile_size=tile_size + ) + else: + possible_resolutions = possible_resolutions + + self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2) + logger.info( + f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit." + ) + + self.resize_to_max_canvas = resize_to_max_canvas + + # normalize + assert (image_mean is None) == ( + image_std is None + ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}" + self.image_mean = image_mean + self.image_std = image_std + + # resize_with_pad + self.max_upscaling_size = None if resize_to_max_canvas else tile_size + self.resample = torchvision.transforms.InterpolationMode[resample.upper()] + + # tile_crop + self.tile_size = tile_size + + def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]: + + assert isinstance(image, Image.Image), "Input image must be a PIL image." + + # Make image torch.tensor((3, H, W), dtype='float32'), 0<=values<=1 + image_tensor = F.to_dtype( + F.grayscale_to_rgb_image(F.to_image(image)), scale=True + ) + + # Find the best canvas to fit the image without distortion + best_resolution = get_canvas_best_fit( + image=image_tensor, + possible_resolutions=self.possible_resolutions, + resize_to_max_canvas=self.resize_to_max_canvas, + ) + + # resize without distortion + pad to fit best_resolution + image_tensor = resize_with_pad( + image=image_tensor, + target_size=best_resolution, + resample=self.resample, + max_upscaling_size=self.max_upscaling_size, + ) + + # Normalize + if self.image_mean and self.image_std: + image_tensor = F.normalize( + image_tensor, mean=self.image_mean, std=self.image_std + ) + + # Divide the image into equally sized tiles + image_tensor = tile_crop(image=image_tensor, tile_size=self.tile_size) + + aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size + + kwargs.update( + { + "image": image_tensor, + "aspect_ratio": aspect_ratio, + } + ) + + return kwargs diff --git a/training/torchtune/models/code_llama2/__init__.py b/training/torchtune/models/code_llama2/__init__.py new file mode 100644 index 0000000..b334f6f --- /dev/null +++ b/training/torchtune/models/code_llama2/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._model_builders import ( # noqa + code_llama2_13b, + code_llama2_70b, + code_llama2_7b, + lora_code_llama2_13b, + lora_code_llama2_70b, + lora_code_llama2_7b, + qlora_code_llama2_13b, + qlora_code_llama2_70b, + qlora_code_llama2_7b, +) + +__all__ = [ + "code_llama2_13b", + "code_llama2_70b", + "code_llama2_7b", + "lora_code_llama2_13b", + "lora_code_llama2_70b", + "lora_code_llama2_7b", + "qlora_code_llama2_13b", + "qlora_code_llama2_70b", + "qlora_code_llama2_7b", +] diff --git a/training/torchtune/models/code_llama2/_model_builders.py b/training/torchtune/models/code_llama2/_model_builders.py new file mode 100644 index 0000000..fca1736 --- /dev/null +++ b/training/torchtune/models/code_llama2/_model_builders.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +from functools import partial + +from torchtune.models.llama2._component_builders import llama2, lora_llama2 + +from torchtune.modules import TransformerDecoder +from torchtune.modules.peft import LORA_ATTN_MODULES + + +def code_llama2_7b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model + """ + return llama2( + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 7B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True) + +qlora_code_llama2_7b.__doc__ = """ +Builder for creating a Code-Llama2 7B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_7b` for full API arguments. +""" + + +def code_llama2_13b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 13B model + """ + return llama2( + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_13b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 13B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 13B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_13b = partial(lora_code_llama2_13b, quantize_base=True) + +qlora_code_llama2_13b.__doc__ = """ +Builder for creating a Code-Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_13b` for full API arguments. +""" + + +def code_llama2_70b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 70B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 70B model + """ + return llama2( + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 70B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_70b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 70B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_70b = partial(lora_code_llama2_70b, quantize_base=True) + +qlora_code_llama2_70b.__doc__ = """ +Builder for creating a Code-Llama2 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_70b` for full API arguments. +""" diff --git a/training/torchtune/models/convert_weights.py b/training/torchtune/models/convert_weights.py new file mode 100644 index 0000000..461090b --- /dev/null +++ b/training/torchtune/models/convert_weights.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re + +from typing import Any, Dict + +import torch + + +# state dict key mappings from Meta's format to torchtune's format +_FROM_META = { + "tok_embeddings.weight": "tok_embeddings.weight", + "norm.weight": "norm.scale", + "output.weight": "output.weight", + "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", + "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", + "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", + "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", + "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", + "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", + "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", + "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", + "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", +} + +# state dict key mappings from HF's format to torchtune's format +_FROM_HF = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.norm.weight": "norm.scale", + "lm_head.weight": "output.weight", +} + + +def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: + try: + if "layers" in key: + # Replace layer number with "{}" to create key for lookup + abstract_key = re.sub(r"(\.\d+)", ".{}", key) + layer_num = re.search(r"\d+", key).group(0) + new_key = mapping_dict[abstract_key] + new_key = new_key.format(layer_num) + else: + new_key = mapping_dict[key] + except KeyError as e: + raise Exception( + f'Error converting the state dict. Found unexpected key: "{key}". ' + "Please make sure you're loading a checkpoint with the right format. " + ) from e + + return new_key + + +def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from Meta's format to torchtune's format. State dicts + from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + + Eg of Meta-format state dict can be found in the ``meta-llama/Llama-2-7b`` + repo in HF (https://huggingface.co/meta-llama/Llama-2-7b). + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. + + Returns: + Dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + for key, value in state_dict.items(): + if key not in ["rope.freqs"]: # Skip loading the position embeddings + new_key = get_mapped_key(key, _FROM_META) + converted_state_dict[new_key] = value + + return converted_state_dict + + +def tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_META.items()} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + return converted_state_dict + + +def hf_to_tune( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to torchtune's format. State dicts + from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + + Eg of HF-format state dict can be found in the ``meta-llama/Llama-2-7b-hf`` + repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf). + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in HF's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of the head. If not provided, it will be calculated + as dim // num_heads. + + Returns: + Dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, 2, head_dim // 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings + new_key = get_mapped_key(key, _FROM_HF) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + + converted_state_dict[new_key] = value + return converted_state_dict + + +def tune_to_hf( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +): + """ + Convert a state dict from torchtune's format to HF's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of model attention heads. Default None. + + Returns: + Dict[str, torch.Tensor]: State dict in HF's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()} + + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, head_dim // 2, 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + converted_state_dict[new_key] = value + + return converted_state_dict + + +# Mapping from torchtune LoRA module names to PEFT LoRA module names +_TO_PEFT_KEYS = { + "lora_a": "lora_A", + "lora_b": "lora_B", +} + +# Mapping from torchtune module names to target modules for PEFT adapter config +_TO_PEFT_TARGET_MODULES = { + "q_proj": "q_proj", + "k_proj": "k_proj", + "v_proj": "v_proj", + "output_proj": "o_proj", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "output": "lm_head", +} + +# Keys expected in PEFT's adapter_config.json +_PEFT_CONFIG_EXPECTED_KEYS = ["target_modules", "r", "lora_alpha"] + + +def tune_to_peft_adapter_config( + adapter_config: Dict[str, Any], +): + if not all([x in adapter_config.keys() for x in _PEFT_CONFIG_EXPECTED_KEYS]): + raise ValueError( + f"PEFT adapter config requires {_PEFT_CONFIG_EXPECTED_KEYS}, found {adapter_config.keys()}" + ) + + for k in adapter_config["target_modules"]: + if k not in _TO_PEFT_TARGET_MODULES: + raise ValueError(f"Unknown target module {k}") + adapter_config["target_modules"] = list( + map(_TO_PEFT_TARGET_MODULES.get, adapter_config["target_modules"]) + ) + + return adapter_config + + +def tune_to_peft_adapter_weights( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +): + converted_state_dict = {} + full_mapping = {} + # Rather than recreate a separate mapping for LoRA adapter weights, we just + # re-use the _FROM_HF mapping for base model weights. We iterate over it twice: + # once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices. + for k, v in _TO_PEFT_KEYS.items(): + full_mapping.update( + { + vv.replace(".weight", f".{k}.weight"): kk.replace( + ".weight", f".{v}.weight" + ) + for kk, vv in _FROM_HF.items() + if vv is not None + } + ) + + if head_dim is None: + head_dim = dim // num_heads + + def _permute_lora_matrix(t, n_heads): + rank = t.shape[-1] + return ( + t.view(n_heads, head_dim // 2, 2, rank) + .transpose(1, 2) + .reshape((head_dim * n_heads), rank) + ) + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, full_mapping) + if "q_proj" in new_key and "lora_B" in new_key: + value = _permute_lora_matrix(value, num_heads) + elif "k_proj" in new_key and "lora_B" in new_key: + value = _permute_lora_matrix(value, num_kv_heads) + converted_state_dict["base_model.model." + new_key] = value + return converted_state_dict diff --git a/training/torchtune/models/gemma/__init__.py b/training/torchtune/models/gemma/__init__.py new file mode 100644 index 0000000..48e4e84 --- /dev/null +++ b/training/torchtune/models/gemma/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import gemma, lora_gemma # noqa +from ._model_builders import ( # noqa + gemma_2b, + gemma_7b, + gemma_tokenizer, + lora_gemma_2b, + lora_gemma_7b, + qlora_gemma_2b, + qlora_gemma_7b, +) +from ._tokenizer import GemmaTokenizer # noqa + +__all__ = [ + "GemmaTokenizer", + "gemma", + "gemma_2b", + "gemma_7b", + "gemma_tokenizer", + "lora_gemma", + "lora_gemma_2b", + "lora_gemma_7b", + "qlora_gemma_2b", + "qlora_gemma_7b", + "gemma_hf_to_tune", + "gemma_tune_to_hf", +] diff --git a/training/torchtune/models/gemma/_component_builders.py b/training/torchtune/models/gemma/_component_builders.py new file mode 100644 index 0000000..36e8ec4 --- /dev/null +++ b/training/torchtune/models/gemma/_component_builders.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn +from typing import List +from functools import partial +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + FrozenNF4Linear, + RotaryPositionalEmbeddings, + TransformerDecoderLayer, +) +from torchtune.models.gemma.rms_norm import GemmaRMSNorm +from torchtune.models.gemma.transformer import GemmaTransformerDecoder + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Gemma 2B models and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +def gemma( + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-6, + rope_base: int = 10_000, + norm_embeddings: bool = True, +) -> GemmaTransformerDecoder: + """ + Build the decoder associated with the gemma model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + This does NOT currently include inference-time optimizations such as + sliding-window attention + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + head_dim (int): dimension of head + num_kv_heads (int): number of key and value heads. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms Default: 1e-6 + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + norm_embeddings (bool): whether to apply layer norm before the self-attention + and mlp layers. Default: True + + Returns: + GemmaTransformerDecoder: Instantiation of gemma model. + """ + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_att = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_att, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + model = GemmaTransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + norm_embeddings=norm_embeddings, + ) + return model + + +def gemma_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: + """ + Build the MLP layer associated with the Gemma model. + + Args: + dim (int): input dimension to the MLP + hidden_dim (int): hidden dimension of the MLP + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + activation = nn.GELU(approximate="tanh") + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation) + + +def lora_gemma( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + *, + # gemma args + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-6, + rope_base: int = 10_000, + norm_embeddings: bool = True, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> GemmaTransformerDecoder: + """ + Return a version of Gemma with LoRA applied based on the passed in configuration. + Note: output projection lora is not supported because it is tied to token embeddings + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + head_dim (int): dimension of head + num_kv_heads (int): number of key and value heads. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms Default: 1e-6 + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + norm_embeddings (bool): whether to apply layer norm before the self-attention + and mlp layers. Default: True + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + GemmaTransformerDecoder: Instantiation of Gemma model with LoRA applied to + a subset of the attention projections in each layer. + """ + self_attn = lora_gemma_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + head_dim=head_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_gemma_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + model = GemmaTransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + norm_embeddings=norm_embeddings, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model + + +def lora_gemma_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 10_000, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + LoRALinear( + num_heads * head_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(num_heads * head_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False) + ) + ) + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_gemma_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + activation = nn.GELU(approximate="tanh") + + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation) diff --git a/training/torchtune/models/gemma/_model_builders.py b/training/torchtune/models/gemma/_model_builders.py new file mode 100644 index 0000000..43f0054 --- /dev/null +++ b/training/torchtune/models/gemma/_model_builders.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional + +from torchtune.models.gemma._component_builders import gemma, lora_gemma +from torchtune.models.gemma.transformer import GemmaTransformerDecoder + +from torchtune.models.gemma._tokenizer import GemmaTokenizer +from torchtune.modules.peft import LORA_ATTN_MODULES + +from functools import partial + +""" +Model builders build specific instantiations using component builders. For example +the ``gemma_2b`` model builder uses the ``gemma`` component builder. +""" + + +def gemma_2b() -> GemmaTransformerDecoder: + """ + Builder for creating a Gemma 2B model initialized w/ the default 2b parameter values + from: https://blog.google/technology/developers/gemma-open-models/ + + Returns: + GemmaTransformerDecoder: Instantiation of Gemma 2B model + """ + return gemma( + vocab_size=256_000, + num_layers=18, + num_heads=8, + head_dim=256, + num_kv_heads=1, + embed_dim=2048, + intermediate_dim=16384, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + ) + + +def gemma_tokenizer(path: str, max_seq_len: Optional[int] = None) -> GemmaTokenizer: + """ + Tokenizer for Gemma. + + Args: + path (str): path to the tokenizer + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + + Returns: + GemmaTokenizer: Instantiation of the Gemma tokenizer + """ + return GemmaTokenizer(path=path, max_seq_len=max_seq_len) + + +def lora_gemma_2b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> GemmaTransformerDecoder: + """ + Builder for creating a Gemma 2B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_2b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + GemmaTransformerDecoder: Instantiation of Gemma 2B model with LoRA applied + """ + return lora_gemma( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=18, + num_heads=8, + head_dim=256, + num_kv_heads=1, + embed_dim=2048, + intermediate_dim=16384, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + +qlora_gemma_2b = partial(lora_gemma_2b, quantize_base=True) + +qlora_gemma_2b.__doc__ = """ +Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemma_2b` for full API arguments. +""" + + + +def gemma_7b() -> GemmaTransformerDecoder: + """ + Builder for creating a Gemma 7B model initialized w/ the default 7b parameter values + from: https://blog.google/technology/developers/gemma-open-models/ + + Returns: + GemmaTransformerDecoder: Instantiation of Gemma 7B model + """ + return gemma( + vocab_size=256_000, + num_layers=28, + num_heads=16, + head_dim=256, + num_kv_heads=16, + embed_dim=3072, + intermediate_dim=24576, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + ) + + +def lora_gemma_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> GemmaTransformerDecoder: + """ + Builder for creating a Gemma 7B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + GemmaTransformerDecoder: Instantiation of Gemma 7B model with LoRA applied + """ + return lora_gemma( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=28, + num_heads=16, + head_dim=256, + num_kv_heads=16, + embed_dim=3072, + intermediate_dim=24576, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + +qlora_gemma_7b = partial(lora_gemma_7b, quantize_base=True) + +qlora_gemma_7b.__doc__ = """ +Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemma_7b` for full API arguments. +""" diff --git a/training/torchtune/models/gemma/_tokenizer.py b/training/torchtune/models/gemma/_tokenizer.py new file mode 100644 index 0000000..b2b3cd7 --- /dev/null +++ b/training/torchtune/models/gemma/_tokenizer.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +from torchtune.data import Message +from torchtune.modules.tokenizers import ( + ModelTokenizer, + SentencePieceBaseTokenizer, + tokenize_messages_no_special_tokens, +) +from torchtune.modules.transforms import Transform + +WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] + + +class GemmaTokenizer(ModelTokenizer, Transform): + """ + Gemma's implementation of the SentencePiece tokenizer + + Args: + path (str): Path to pretrained tokenizer file. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + + Examples: + >>> tokenizer = GemmaTokenizer("/path/to/spm_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + max_seq_len: Optional[int] = None, + ): + self._spm_model = SentencePieceBaseTokenizer(path) + + # Original tokenizer has no pad_id, which causes indexing errors when batch training + self._spm_model.pad_id = 0 + + # During generation, stop when eos_id is encountered + self.stop_tokens = [self.eos_id] + + self.max_seq_len = max_seq_len + + @property + def eos_id(self): + return self._spm_model.eos_id + + @property + def bos_id(self): + return self._spm_model.bos_id + + @property + def pad_id(self): + return self._spm_model.pad_id + + @property + def vocab_size(self): + return self._spm_model.vocab_size + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + trim_leading_whitespace: bool = False, + ) -> List[int]: + return self._spm_model.encode( + text, + add_bos=add_bos, + add_eos=add_eos, + trim_leading_whitespace=trim_leading_whitespace, + ) + + def decode( + self, + token_ids: List[int], + ) -> str: + return self._spm_model.decode(token_ids) + + def tokenize_messages( + self, + messages: List[Message], + ) -> Tuple[List[int], List[bool]]: + r"""Tokenize a list of messages one at a time then concatenate them, + returning a list of tokens and a list of masks. + + + Example: + >>> tokenizer = GemmaTokenizer(tokenizer_path, max_seq_len) + >>> messages = [ + Message(role="system", content="system message\n", masked=True), + Message(role="user", content="user prompt\n", masked=True), + Message(role="assistant", content="assistant response\n"), + ] + + >>> # tokenize_messages encodes messages separately and concats + >>> tokenizer.tokenize_messages(messages)[0] + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + >>> # Same result as encoding the full string in one go + >>> tokenizer.encode(''.join([message.content for message in messages])) + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + Args: + messages (List[Message]): A list of messages, each containing role, content, + and masked attributes. + + Returns: + Tuple[List[int], List[bool]]: The tokenized messages + """ + return tokenize_messages_no_special_tokens( + tokenizer=self, + messages=messages, + bos_id=self.bos_id, + eos_id=self.eos_id, + max_seq_len=self.max_seq_len, + ) + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/training/torchtune/models/gemma/rms_norm.py b/training/torchtune/models/gemma/rms_norm.py new file mode 100644 index 0000000..c803b5f --- /dev/null +++ b/training/torchtune/models/gemma/rms_norm.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + + +class GemmaRMSNorm(nn.Module): + # Copied from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.scale.float()) + return output.type_as(x) diff --git a/training/torchtune/models/gemma/transformer.py b/training/torchtune/models/gemma/transformer.py new file mode 100644 index 0000000..bdc5778 --- /dev/null +++ b/training/torchtune/models/gemma/transformer.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torchtune.modules import KVCache + +from torchtune.modules.transformer import _get_clones, TransformerDecoderLayer + + +class GemmaTransformerDecoder(nn.Module): + """ + GemmaTransformer Decoder derived from Gemma architecture. A key difference between + the Gemma transformer decoder and :class:`~torchtune.modules.TransformerDecoder` + is that the output projection is replaced instead with a reverse projection + using the transposed token embedding weights from output dim to input dim + (see https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/modeling/reversible_embedding.py#L21). + + Args: + tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move + tokens to an embedding space and as the output projection. + layer (TransformerDecoderLayer): Transformer Decoder layer. + num_layers (int): Number of Transformer Decoder layers. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value. This is used to setup the + :func:`~torchtune.modules.KVCache` + head_dim (int): embedding dimension for each head in self-attention. This is used + to setup the :func:`~torchtune.modules.KVCache` + norm (nn.Module): Callable that applies normalization to the output of the decoder, + before final MLP. + norm_embeddings (bool): Whether to normalize the embeddings before passing them + through the decoder layers. Defaults to False. + + Note: + Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1]) + in the module where they are used. This helps reduces the number of raise + statements in code and improves readability. + """ + + def __init__( + self, + tok_embeddings: nn.Embedding, + layer: TransformerDecoderLayer, + num_layers: int, + max_seq_len: int, + num_heads: int, + head_dim: int, + norm: nn.Module, + norm_embeddings: bool = False, + ) -> None: + super().__init__() + self.tok_embeddings = tok_embeddings + self.layers = _get_clones(layer, num_layers) + self.norm = norm + self.max_seq_len = max_seq_len + self.num_heads = num_heads + self.head_dim = head_dim + self.causal_mask = None + self.norm_embeddings = norm_embeddings + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + for layer in self.layers: + layer.attn.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=self.max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + + # causal_mask is used during inference to ensure we're attending + # to the right tokens + self.causal_mask = torch.tril( + torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool) + ) + + def forward( + self, + tokens: Tensor, + *, + mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + tokens (Tensor): input tensor with shape [b x s] + mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + with shape [b x s x s]. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + Tensor: output tensor with shape [b x s x v] + + Raises: + ValueError: if causal_mask is set but input_pos is None + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - v: vocab size + - d: embed dim + - m_s: max seq len + """ + # input tensor of shape [b, s] + bsz, seq_len = tokens.shape + + # shape: [b, s, d] + h = self.tok_embeddings(tokens) + + if self.causal_mask is not None: + if input_pos is None: + raise ValueError( + "Caches are setup, but the position of input token is missing" + ) + if mask is not None: + raise ValueError( + "An attention mask was set. Cannot use a non-causal mask for inference" + ) + # shape: [1, input_pos_len, m_s] + # in most cases input_pos_len should be 1 + mask = self.causal_mask[None, input_pos] + + if self.norm_embeddings: + hidden_dim = h.size(-1) + h = h * torch.tensor(hidden_dim**0.5, dtype=h.dtype) + + for layer in self.layers: + # shape: [b, s, d] + h = layer(h, mask=mask, input_pos=input_pos) + + # shape: [b, s, d] + h = self.norm(h) + + # shape: [b, s, v] + output = F.linear(h, self.tok_embeddings.weight).float() + return output diff --git a/training/torchtune/models/llama2/__init__.py b/training/torchtune/models/llama2/__init__.py new file mode 100644 index 0000000..a0a0f7d --- /dev/null +++ b/training/torchtune/models/llama2/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import ( + llama2, + llama2_classifier, + lora_llama2, + lora_llama2_classifier, +) + +from ._model_builders import ( # noqa + llama2_13b, + llama2_70b, + llama2_7b, + llama2_reward_7b, + llama2_tokenizer, + lora_llama2_13b, + lora_llama2_70b, + lora_llama2_7b, + lora_llama2_reward_7b, + qlora_llama2_13b, + qlora_llama2_70b, + qlora_llama2_7b, + qlora_llama2_reward_7b, +) +from ._prompt_template import Llama2ChatTemplate +from ._tokenizer import Llama2Tokenizer + +__all__ = [ + "Llama2Tokenizer", + "Llama2ChatTemplate", + "llama2", + "llama2_classifier", + "lora_llama2_classifier", + "llama2_reward_7b", + "lora_llama2_reward_7b", + "qlora_llama2_reward_7b", + "lora_llama2", + "llama2_13b", + "llama2_70b", + "llama2_7b", + "llama2_tokenizer", + "lora_llama2", + "llama2_classifier", + "lora_llama2_13b", + "lora_llama2_70b", + "lora_llama2_7b", + "qlora_llama2_13b", + "qlora_llama2_70b", + "qlora_llama2_7b", +] diff --git a/training/torchtune/models/llama2/_component_builders.py b/training/torchtune/models/llama2/_component_builders.py new file mode 100644 index 0000000..fda6471 --- /dev/null +++ b/training/torchtune/models/llama2/_component_builders.py @@ -0,0 +1,647 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Optional +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torch import nn + +from torchtune.models.llama2._model_utils import scale_hidden_dim_for_mlp + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + FrozenNF4Linear, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Llama2 model and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +# ------------------ Vanilla Llama2 ------------------ + + +def llama2( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + intermediate_dim: Optional[int] = None, + norm_eps: float = 1e-5, +) -> TransformerDecoder: + """ + Build the decoder associated with the Llama2 model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + + Returns: + TransformerDecoder: Instantiation of Llama2 model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def llama2_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: + """ + Build the MLP layer associated with the Llama model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +# ------------------ LoRA Llama2 ------------------ + + +def lora_llama2( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # llama2 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Llama2 (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Llama2 model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_llama2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + if apply_lora_to_mlp: + mlp = lora_llama2_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model + + +def lora_llama2_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + """ + Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + CausalSelfAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules") + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + LoRALinear( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_llama2_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) + + +# ------------------ Llama2 Classifier ------------------ + + +def llama2_classifier( + num_classes: int, + *, + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + intermediate_dim: Optional[int] = None, + norm_eps: float = 1e-5, +) -> TransformerDecoder: + """ + Build a base Llama2 model with the final projection replaced with a classification layer. + + Args: + num_classes (int): number of classes for classification. + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. If specified, + user should ensure `num_heads` % `num_kv_heads` == 0. Default value is + `None`, in which case this is the same as MHA + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + + Returns: + TransformerDecoder: Instantiation of Llama2 model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, num_classes, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def lora_llama2_classifier( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # llama2 classifier args, + num_classes: int, + # llama2 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Llama2 (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + num_classes (int): number of classes for classification. + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Llama2 model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_llama2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + if apply_lora_to_mlp: + mlp = lora_llama2_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, num_classes, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model diff --git a/training/torchtune/models/llama2/_model_builders.py b/training/torchtune/models/llama2/_model_builders.py new file mode 100644 index 0000000..1c8c7d0 --- /dev/null +++ b/training/torchtune/models/llama2/_model_builders.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional +from functools import partial + +from torchtune.models.llama2._component_builders import llama2, lora_llama2, llama2_classifier, lora_llama2_classifier + +from torchtune.modules import TransformerDecoder +from torchtune.models.llama2._tokenizer import Llama2Tokenizer +from torchtune.modules.peft import LORA_ATTN_MODULES + + +""" +Model builders build specific instantiations using component builders. For example +the llama2_7b model builder uses the llama2 component builder to create the +llama2 7B model. +""" + + +def llama2_7b() -> TransformerDecoder: + """ + Builder for creating a Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/abs/2307.09288 + + Returns: + TransformerDecoder: Instantiation of Llama2 7B model + """ + return llama2( + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def llama2_tokenizer(path: str, max_seq_len: Optional[int] = None) -> Llama2Tokenizer: + """ + Tokenizer for Llama2. + + Args: + path (str): path to the tokenizer + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + + Returns: + Llama2Tokenizer: Instantiation of the Llama2 tokenizer + """ + return Llama2Tokenizer(path=path, max_seq_len=max_seq_len) + + +def lora_llama2_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama2 7B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + + Returns: + TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True) + +qlora_llama2_7b.__doc__ = """ +Builder for creating a Llama2 7B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_7b` for full API arguments. +""" + + +def llama2_13b() -> TransformerDecoder: + """ + Builder for creating a Llama2 model initialized w/ the default 13B parameter values + from https://arxiv.org/abs/2307.09288 + + Returns: + TransformerDecoder: Instantiation of Llama2 13B model + """ + return llama2( + vocab_size=32_000, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_llama2_13b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama2 13B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama2 13B model with LoRA applied + """ + + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_000, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_llama2_13b = partial(lora_llama2_13b, quantize_base=True) +qlora_llama2_13b.__doc__ = """ +Builder for creating a Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_13b` for full API arguments. +""" + + +def llama2_70b() -> TransformerDecoder: + """ + Builder for creating a Llama2 model initialized w/ the default 70B parameter values + from https://arxiv.org/abs/2307.09288 + + Returns: + TransformerDecoder: Instantiation of Llama2 70B model + """ + return llama2( + vocab_size=32_000, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_llama2_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama2 70B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_70b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for LoRA linear layers. Default: 0.05 + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama2 70B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_000, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + max_seq_len=4096, + intermediate_dim=28672, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_llama2_70b = partial(lora_llama2_70b, quantize_base=True) +qlora_llama2_70b.__doc__ = """ +Builder for creating a Llama2 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_70b` for full API arguments. +""" + + +def llama2_reward_7b() -> TransformerDecoder: + """ + Builder for creating a Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/abs/2307.09288, where the output layer is a classification layer + projecting to a single class for reward modelling. + + Returns: + TransformerDecoder: Instantiation of Llama2 7B model + """ + return llama2_classifier( + num_classes=1, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_llama2_reward_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama2 7B reward model with LoRA enabled. + + The Llama2 classifier defaults are the same as in :func:`~torchtune.models.llama2.llama2_reward_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied + """ + return lora_llama2_classifier( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + num_classes=1, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_llama2_reward_7b = partial(lora_llama2_7b, quantize_base=True) +qlora_llama2_reward_7b.__doc__ = """ +Builder for creating a Llama2 reward 7b model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_reward_7b` for full API arguments. +""" diff --git a/training/torchtune/models/llama2/_model_utils.py b/training/torchtune/models/llama2/_model_utils.py new file mode 100644 index 0000000..010c1bc --- /dev/null +++ b/training/torchtune/models/llama2/_model_utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int: + """Scale hidden dimension for MLP to keep number of parameters and computation constant. + + Args: + dim (int): Input dimension. + multiple_of (int): Round scaled dimension to nearest multiple of `multiple_of` for clean computation. + + Returns: + Scaled hidden dimension. + """ + # Scale hidden dimension by (2/3)4d for SwiGLU to keep number of + # parameters and computation constant + hidden_dim = 4 * int(2 * dim / 3) + # Round hidden dimension to nearest multiple of `multiple_of` + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + return hidden_dim diff --git a/training/torchtune/models/llama2/_prompt_template.py b/training/torchtune/models/llama2/_prompt_template.py new file mode 100644 index 0000000..519c162 --- /dev/null +++ b/training/torchtune/models/llama2/_prompt_template.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from torchtune.data import Message, PromptTemplateInterface + + +class Llama2ChatTemplate(PromptTemplateInterface): + """ + Prompt template that formats chat data of human and system prompts with appropriate tags + used in Llama2 pre-training. Taken from Meta's official `Llama inference + repository `_. + + .. code-block:: text + + "[INST] <> + You are a helpful, respectful and honest assistant. + <>" + + I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture..." + + + """ + + template = { + "system": ("<>\n", "\n<>\n\n"), + "user": ("[INST] ", " [/INST] "), + "assistant": ("", ""), + "ipython": ("", ""), + } + + def __call__( + self, + messages: List[Message], + ) -> List[Message]: + """ + Format user and system messages with appropriate tags. + + Args: + messages (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + """ + system_message = [] + formatted_dialogue = [] + for message in messages: + if message.role == "system": + system_message = ( + [{"type": "text", "content": self.template["system"][0]}] + + message.content + + [{"type": "text", "content": self.template["system"][1]}] + ) + # Incorporate the system message in the user message - Llama2 only + # looks for the <> tags and not the explicit role so this will + # be treated the same as an actual system message. We do this because + # of the nesting of the system prompt in the user message. + continue + elif message.role == "user": + content = ( + [{"type": "text", "content": self.template["user"][0]}] + + system_message + + message.content + + [{"type": "text", "content": self.template["user"][1]}] + ) + elif message.role == "assistant": + # No special formatting needed for assistant message + content = message.content + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue diff --git a/training/torchtune/models/llama2/_tokenizer.py b/training/torchtune/models/llama2/_tokenizer.py new file mode 100644 index 0000000..71f4ac2 --- /dev/null +++ b/training/torchtune/models/llama2/_tokenizer.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +from torchtune.data import Message +from torchtune.modules.tokenizers import ( + ModelTokenizer, + SentencePieceBaseTokenizer, + tokenize_messages_no_special_tokens, +) +from torchtune.modules.transforms import Transform + +WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] + + +class Llama2Tokenizer(ModelTokenizer, Transform): + """ + Llama2's implementation of the SentencePiece tokenizer. Llama2Tokenizer does + not include any additional special tokens. The prompt template described in + https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2/ describes + [INST][/INST] and <><> as special tokens but these are not registered + as unique ids and are tokenized as normal text. When using this tokenizer on the + pre-trained model for inference, it is strongly encouraged to apply the + :class:`~torchtune.data.Llama2ChatFormat` to your data beforehand to add the + [INST] and <> for optimal performance. For fine-tuning, this is not required. + For more details, see https://pytorch.org/torchtune/main/tutorials/chat.html#tokenizing-prompt-templates-special-tokens. + + Args: + path (str): Path to pretrained SentencePiece tokenizer file. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + + Examples: + >>> tokenizer = Llama2Tokenizer("/path/to/spm_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + max_seq_len: Optional[int] = None, + ): + self._spm_model = SentencePieceBaseTokenizer(path) + + # Original tokenizer has no pad_id, which causes indexing errors when batch training + self._spm_model.pad_id = 0 + + # During generation, stop when eos_id is encountered + self.stop_tokens = [self.eos_id] + + self.max_seq_len = max_seq_len + + @property + def eos_id(self): + return self._spm_model.eos_id + + @property + def bos_id(self): + return self._spm_model.bos_id + + @property + def pad_id(self): + return self._spm_model.pad_id + + @property + def vocab_size(self): + return self._spm_model.vocab_size + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + trim_leading_whitespace: bool = False, + ) -> List[int]: + return self._spm_model.encode( + text, + add_bos=add_bos, + add_eos=add_eos, + trim_leading_whitespace=trim_leading_whitespace, + ) + + def decode( + self, + token_ids: List[int], + ) -> str: + return self._spm_model.decode(token_ids) + + def tokenize_messages( + self, + messages: List[Message], + ) -> Tuple[List[int], List[bool]]: + r"""Tokenize a list of messages one at a time then concatenate them, + returning a list of tokens and a list of masks. + + Note: + sentencepiece has problems where in general + encode(s1 + s2) != encode(s1) + encode(s2) due to whitespace handling. + We can get around this by prepending s2 with a known token and slicing the + beginning off the tokenized s2. + + Example: + >>> tokenizer = Llama2Tokenizer(tokenizer_path, max_seq_len) + >>> messages = [ + Message(role="system", content="system message\n", masked=True), + Message(role="user", content="user prompt\n", masked=True), + Message(role="assistant", content="assistant response\n"), + ] + + >>> # tokenize_messages encodes messages separately and concats + >>> tokenizer.tokenize_messages(messages)[0] + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + >>> # Same result as encoding the full string in one go + >>> tokenizer.encode(''.join([message.content for message in messages])) + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + Args: + messages (List[Message]): A list of messages, each containing role, content, + and masked attributes. + + Returns: + Tuple[List[int], List[bool]]: The tokenized messages + """ + return tokenize_messages_no_special_tokens( + tokenizer=self, + messages=messages, + bos_id=self.bos_id, + eos_id=self.eos_id, + max_seq_len=self.max_seq_len, + ) + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/training/torchtune/models/llama3/__init__.py b/training/torchtune/models/llama3/__init__.py new file mode 100644 index 0000000..90de8c2 --- /dev/null +++ b/training/torchtune/models/llama3/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import llama3, lora_llama3 + +from ._model_builders import ( # noqa + llama3_70b, + llama3_8b, + llama3_tokenizer, + lora_llama3_70b, + lora_llama3_8b, + qlora_llama3_70b, + qlora_llama3_8b, +) +from ._tokenizer import Llama3Tokenizer + +__all__ = [ + "Llama3Tokenizer", + "llama3", + "llama3_8b", + "llama3_70b", + "llama3_tokenizer", + "lora_llama3", + "lora_llama3_8b", + "lora_llama3_70b", + "qlora_llama3_8b", + "qlora_llama3_70b", +] diff --git a/training/torchtune/models/llama3/_component_builders.py b/training/torchtune/models/llama3/_component_builders.py new file mode 100644 index 0000000..155bb5f --- /dev/null +++ b/training/torchtune/models/llama3/_component_builders.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Literal, Optional + +from torch import nn + +from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + FrozenNF4Linear, + KVCache, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) + +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Llama3 model and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +# ------------------ Vanilla Llama3 ------------------ + +def llama3( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 500000.0, + intermediate_dim: Optional[int] = None, + norm_eps: float = 1e-5, +) -> TransformerDecoder: + """ + Build the decoder associated with the Llama3 model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + + Returns: + TransformerDecoder: Instantiation of Llama3 model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + +def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: + """ + Build the MLP layer associated with the Llama model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + + +# ------------------ LoRA Llama3 ------------------ + + +def lora_llama3( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # llama3 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 500000.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Llama3 (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Llama3 model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_llama3_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + if apply_lora_to_mlp: + mlp = lora_llama3_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True) + ) + + return model + + +def lora_llama3_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: float = 500000.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + """ + Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + CausalSelfAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + LoRALinear( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_llama3_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) diff --git a/training/torchtune/models/llama3/_model_builders.py b/training/torchtune/models/llama3/_model_builders.py new file mode 100644 index 0000000..a6dc9e0 --- /dev/null +++ b/training/torchtune/models/llama3/_model_builders.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional +from functools import partial + +from torchtune.models.llama3._component_builders import llama3, lora_llama3 + +from torchtune.modules import TransformerDecoder +from torchtune.models.llama3._tokenizer import Llama3Tokenizer +from torchtune.modules.peft import LORA_ATTN_MODULES +from torchtune.modules.tokenizers import parse_hf_tokenizer_json + + +""" +Model builders build specific instantiations using component builders. For example +the llama3_8b model builder uses the llama3 component builder to create the +Llama3 8B model. +""" + + +def llama3_8b() -> TransformerDecoder: + """ + Builder for creating a Llama3 model initialized w/ the default 8b parameter values. + + Returns: + TransformerDecoder: Instantiation of Llama3 8B model + """ + return llama3( + vocab_size=128_256, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + max_seq_len=8192, + intermediate_dim=14336, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + ) + + +def llama3_70b() -> TransformerDecoder: + """ + Builder for creating a Llama3 model initialized w/ the default 70B parameter values. + + Returns: + TransformerDecoder: Instantiation of Llama3 70 model + """ + return llama3( + vocab_size=128_256, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + max_seq_len=8192, + intermediate_dim=28672, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + ) + + +def llama3_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None) -> Llama3Tokenizer: + """ + Tokenizer for Llama3. + + Args: + path (str): path to the tokenizer + special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face + model files that contains all registered special tokens, or a local json file + structured similarly. Default is None to use the canonical Llama3 special tokens. + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + + Returns: + Llama3Tokenizer: Instantiation of the Llama3 tokenizer + """ + special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None + return Llama3Tokenizer(path=path, special_tokens=special_tokens, max_seq_len=max_seq_len) + + +def lora_llama3_8b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama3 8B model with LoRA enabled. + + The Llama3 defaults are the same as in :func:`~torchtune.models.llama3.llama3_8b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama3 8B model with LoRA applied + """ + return lora_llama3( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=128_256, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + max_seq_len=8192, + intermediate_dim=14336, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +def lora_llama3_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama3 70B model with LoRA enabled. + + The Llama3 defaults are the same as in :func:`~torchtune.models.llama3.llama3_70b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama3 70B model with LoRA applied + """ + return lora_llama3( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=128_256, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + max_seq_len=8192, + intermediate_dim=28672, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +qlora_llama3_8b = partial(lora_llama3_8b, quantize_base=True) + +qlora_llama3_8b.__doc__ = """ +Builder for creating a Llama3 8B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama3_8b` for full API arguments. +""" + +qlora_llama3_70b = partial(lora_llama3_70b, quantize_base=True) + +qlora_llama3_70b.__doc__ = """ +Builder for creating a Llama3 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama3_70b` for full API arguments. +""" diff --git a/training/torchtune/models/llama3/_model_utils.py b/training/torchtune/models/llama3/_model_utils.py new file mode 100644 index 0000000..010c1bc --- /dev/null +++ b/training/torchtune/models/llama3/_model_utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int: + """Scale hidden dimension for MLP to keep number of parameters and computation constant. + + Args: + dim (int): Input dimension. + multiple_of (int): Round scaled dimension to nearest multiple of `multiple_of` for clean computation. + + Returns: + Scaled hidden dimension. + """ + # Scale hidden dimension by (2/3)4d for SwiGLU to keep number of + # parameters and computation constant + hidden_dim = 4 * int(2 * dim / 3) + # Round hidden dimension to nearest multiple of `multiple_of` + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + return hidden_dim diff --git a/training/torchtune/models/llama3/_tokenizer.py b/training/torchtune/models/llama3/_tokenizer.py new file mode 100644 index 0000000..231ffb7 --- /dev/null +++ b/training/torchtune/models/llama3/_tokenizer.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Mapping, Optional, Tuple + +from torchtune.data import Message, truncate +from torchtune.modules.tokenizers import ModelTokenizer, TikTokenBaseTokenizer +from torchtune.modules.transforms import Transform + + +CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa + +SPECIAL_TOKENS = { + "<|begin_of_text|>": 128000, + "<|end_of_text|>": 128001, + "<|reserved_special_token_0|>": 128002, + "<|reserved_special_token_1|>": 128003, + "<|finetune_right_pad_id|>": 128004, + "<|step_id|>": 128005, + "<|start_header_id|>": 128006, + "<|end_header_id|>": 128007, + "<|eom_id|>": 128008, + "<|eot_id|>": 128009, + "<|python_tag|>": 128010, + "<|image|>": 128011, + "<|video|>": 128012, +} + +NUM_RESERVED_SPECIAL_TOKENS = 256 + +RESERVED_TOKENS = { + f"<|reserved_special_token_{2 + i}|>": 128013 + i + for i in range(NUM_RESERVED_SPECIAL_TOKENS - len(SPECIAL_TOKENS)) +} + +LLAMA3_SPECIAL_TOKENS = {**SPECIAL_TOKENS, **RESERVED_TOKENS} + + +class Llama3Tokenizer(ModelTokenizer, Transform): + """ + tiktoken tokenizer configured with Llama3 Instruct's special tokens, as described in + https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3 + + Args: + path (str): Path to pretrained tiktoken tokenizer file. + special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and + their registered token IDs. If left as None, this will be set to the canonical + Llama3 special tokens. + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + + Examples: + >>> tokenizer = Llama3Tokenizer("/path/to/tt_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + special_tokens: Optional[Dict[str, int]] = None, + max_seq_len: Optional[int] = None, + ): + self.special_tokens = ( + special_tokens if special_tokens is not None else LLAMA3_SPECIAL_TOKENS + ) + + self._validate_special_tokens() + + # Encode BOS and EOS, define pad ID + self.bos_id = self.special_tokens["<|begin_of_text|>"] + self.eos_id = self.special_tokens["<|end_of_text|>"] + self.pad_id = self.special_tokens["<|finetune_right_pad_id|>"] + self.step_id = self.special_tokens["<|step_id|>"] + + # Encode extra special tokens + self.start_header_id = self.special_tokens["<|start_header_id|>"] + self.end_header_id = self.special_tokens["<|end_header_id|>"] + self.eot_id = self.special_tokens["<|eot_id|>"] + + self.eom_id = self.special_tokens["<|eom_id|>"] + self.python_tag = self.special_tokens["<|python_tag|>"] + + # Media tokens + self.image_id = self.special_tokens["<|image|>"] + + # During generation, stop when either eos_id or eot_id is encountered + self.stop_tokens = [self.eos_id, self.eot_id] + + self.tt_model = TikTokenBaseTokenizer( + path=path, + name="llama3_tiktoken", + pattern=CL100K_PATTERN, + bos_id=self.bos_id, + eos_id=self.eos_id, + special_tokens=self.special_tokens, + ) + self.max_seq_len = max_seq_len + + def _validate_special_tokens( + self, + ): + """ + Validate that required special tokens are passed into the tokenizer. + """ + for token in [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", + "<|eot_id|>", + "<|python_tag|>", + ]: + if token not in self.special_tokens: + raise ValueError(f"{token} missing from special_tokens") + + @property + def base_vocab_size(self) -> int: + return self.tt_model.base_vocab_size + + @property + def vocab_size(self) -> int: + return self.tt_model.vocab_size + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + ) -> List[int]: + return self.tt_model.encode(text=text, add_bos=add_bos, add_eos=add_eos) + + def decode( + self, + token_ids: List[int], + truncate_at_eos: bool = True, + skip_special_tokens: bool = True, + ) -> str: + """ + Decode a list of token ids into a string. + + Args: + token_ids (List[int]): The list of token ids. + truncate_at_eos (bool): Whether to truncate the string at the end of + sequence token. Default is True. + skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string. + Default is True. + + Returns: + str: The decoded string. + """ + return self.tt_model.decode( + token_ids, + truncate_at_eos=truncate_at_eos, + skip_special_tokens=skip_special_tokens, + ) + + def _tokenize_header(self, message: Message) -> List[int]: + """ + Tokenize header start, message role, and header end as list of ids + """ + return ( + [self.start_header_id] + + self.encode(message.role.strip(), add_bos=False, add_eos=False) + + [self.end_header_id] + + self.encode("\n\n", add_bos=False, add_eos=False) + ) + + def _tokenize_end(self, message: Message) -> List[int]: + """ + Add eot or eom id at the end of the message. + """ + return [self.eot_id] if message.eot else [self.eom_id] + + def _tokenize_body(self, message: Message) -> List[int]: + """ + Tokenize message content as list of ids + """ + tokenized_body = [] + for item in message.content: + if item["type"] == "text": + tokenized_body += self.encode( + item["content"].strip(), add_bos=False, add_eos=False + ) + elif item["type"] == "image": + tokenized_body += [self.image_id] + else: + raise RuntimeError(f"Unsupported message content type: {item['type']}") + + if message.ipython: + tokenized_body = [self.python_tag] + tokenized_body + + return tokenized_body + + def tokenize_message( + self, + message: Message, + tokenize_header: bool = True, + tokenize_end: bool = True, + ) -> List[int]: + """ + Tokenize a message into a list of token ids. + + Args: + message (Message): The message to tokenize. + tokenize_header (bool): Whether to prepend a tokenized header to the message. + tokenize_end (bool): Whether to append eot or eom id at the end of the message. + + Returns: + List[int]: The list of token ids. + """ + + tokenized_header = self._tokenize_header(message) if tokenize_header else [] + + tokenized_body = self._tokenize_body(message) + + tokenized_end = self._tokenize_end(message) if tokenize_end else [] + + tokenized_message = tokenized_header + tokenized_body + tokenized_end + + return tokenized_message + + def tokenize_messages( + self, + messages: List[Message], + add_eos: bool = True, + ) -> Tuple[List[int], List[bool]]: + """ + Tokenize a list of messages into a list of token ids and masks. + + Args: + messages (List[Message]): The list of messages to tokenize. + add_eos (bool): Wether to add the tokenizer's eos_id. Default True. + + Returns: + Tuple[List[int], List[bool]]: The list of token ids and the list of masks. + """ + tokens = [self.bos_id] + # bos and eos are always masked + mask = [True] + for message in messages: + tokenized_message = self.tokenize_message(message) + + tokens = tokens + tokenized_message + mask = mask + ([message.masked] * len(tokenized_message)) + if self.max_seq_len and len(tokens) >= self.max_seq_len: + break + + if add_eos: + tokens = tokens + [self.eos_id] + mask = mask + [True] + if self.max_seq_len: + tokens = truncate(tokens, self.max_seq_len, self.eos_id) + mask = truncate(mask, self.max_seq_len, True) + + return tokens, mask + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/training/torchtune/models/llama3_1/__init__.py b/training/torchtune/models/llama3_1/__init__.py new file mode 100644 index 0000000..e2bd01a --- /dev/null +++ b/training/torchtune/models/llama3_1/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import llama3_1, lora_llama3_1 + +from ._model_builders import ( # noqa + llama3_1_70b, + llama3_1_8b, + lora_llama3_1_70b, + lora_llama3_1_8b, + qlora_llama3_1_70b, + qlora_llama3_1_8b, +) + +__all__ = [ + "llama3_1", + "llama3_1_8b", + "llama3_1_70b", + "lora_llama3_1", + "lora_llama3_1_8b", + "lora_llama3_1_70b", + "qlora_llama3_1_8b", + "qlora_llama3_1_70b", +] diff --git a/training/torchtune/models/llama3_1/_component_builders.py b/training/torchtune/models/llama3_1/_component_builders.py new file mode 100644 index 0000000..ad97767 --- /dev/null +++ b/training/torchtune/models/llama3_1/_component_builders.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Literal, Optional + +from torch import nn + +from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + KVCache, + RMSNorm, + TransformerDecoder, + TransformerDecoderLayer, +) + +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Llama3.1 model and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +# ------------------ Vanilla Llama3 ------------------ + +def llama3_1( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 500000.0, + intermediate_dim: Optional[int] = None, + norm_eps: float = 1e-5, +) -> TransformerDecoder: + """ + Build the decoder associated with the Llama3.1 model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + + Returns: + TransformerDecoder: Instantiation of Llama3.1 model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + +def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Build the MLP layer associated with the Llama model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + + +# ------------------ LoRA Llama3 ------------------ + + +def lora_llama3_1( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # llama3 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 500000.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Llama3.1 (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Llama3.1 model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_llama3_1_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + if apply_lora_to_mlp: + mlp = lora_llama3_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True) + ) + + return model + + +def lora_llama3_1_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: float = 500000.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + """ + Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + CausalSelfAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + output_proj = ( + LoRALinear( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else nn.Linear(embed_dim, embed_dim, bias=False) + ) + rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_llama3_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) diff --git a/training/torchtune/models/llama3_1/_model_builders.py b/training/torchtune/models/llama3_1/_model_builders.py new file mode 100644 index 0000000..714099a --- /dev/null +++ b/training/torchtune/models/llama3_1/_model_builders.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional +from functools import partial + +from torchtune.models.llama3_1._component_builders import llama3_1, lora_llama3_1 + +from torchtune.modules import TransformerDecoder +from torchtune.models.llama3._tokenizer import Llama3Tokenizer +from torchtune.modules.peft import LORA_ATTN_MODULES +from torchtune.modules.tokenizers import parse_hf_tokenizer_json + + +""" +Model builders build specific instantiations using component builders. For example +the llama3_1_8b model builder uses the llama3 component builder to create the +Llama3.1 8B model. +""" + + +def llama3_1_8b() -> TransformerDecoder: + """ + Builder for creating a Llama3.1 model initialized w/ the default 8b parameter values. + + Returns: + TransformerDecoder: Instantiation of Llama3.1 8B model + """ + return llama3_1( + vocab_size=128_256, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + max_seq_len=131072, + intermediate_dim=14336, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + ) + + +def llama3_1_70b() -> TransformerDecoder: + """ + Builder for creating a Llama3.1 model initialized w/ the default 70B parameter values. + + Returns: + TransformerDecoder: Instantiation of Llama3.1 70B model + """ + return llama3_1( + vocab_size=128_256, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + max_seq_len=131072, + intermediate_dim=28672, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + ) + + +def lora_llama3_1_8b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama3.1 8B model with LoRA enabled. + + The Llama3.1 defaults are the same as in :func:`~torchtune.models.llama3_1.llama3_1_8b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama3.1 8B model with LoRA applied + """ + return lora_llama3_1( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=128_256, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + max_seq_len=131072, + intermediate_dim=14336, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +def lora_llama3_1_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama3.1 70B model with LoRA enabled. + + The Llama3.1 defaults are the same as in :func:`~torchtune.models.llama3_1.llama3_1_70b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama3.1 70B model with LoRA applied + """ + return lora_llama3_1( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=128_256, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + max_seq_len=131072, + intermediate_dim=28672, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500000.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +qlora_llama3_1_8b = partial(lora_llama3_1_8b, quantize_base=True) + +qlora_llama3_1_8b.__doc__ = """ +Builder for creating a Llama3.1 8B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama3_1_8b` for full API arguments. +""" + +qlora_llama3_1_70b = partial(lora_llama3_1_70b, quantize_base=True) + +qlora_llama3_1_70b.__doc__ = """ +Builder for creating a Llama3.1 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama3_1_70b` for full API arguments. +""" diff --git a/training/torchtune/models/llama3_1/_position_embeddings.py b/training/torchtune/models/llama3_1/_position_embeddings.py new file mode 100644 index 0000000..8bd127d --- /dev/null +++ b/training/torchtune/models/llama3_1/_position_embeddings.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional + +import torch + +from torch import nn, Tensor + + +class Llama3ScaledRoPE(nn.Module): + """ + This class implements Rotary Positional Embeddings (RoPE) + proposed in https://arxiv.org/abs/2104.09864 with additional + scaling from https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1. + + In this implementation we cache the embeddings for each position upto + ``max_seq_len`` by computing this during init. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ````embed_dim`` // ``num_heads```` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self.is_cache_built = False + + # We need to explicitly define reset_parameters for FSDP initialization, see + # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885 + def reset_parameters(self): + self._rope_init() + + def _rope_init(self): + freqs = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + theta = self.apply_scaling(freqs) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + self.is_cache_built = True + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def apply_scaling(self, freqs: torch.Tensor): + """From the following Meta-Llama code: + https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1/api/model.py#L41""" + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [b, s, n_h, h_d] + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with RoPE applied + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + # TODO: Remove this hack for handling scaling for Meta device + if not self.is_cache_built: + with torch.device(x.device): + self._rope_init() + + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) diff --git a/training/torchtune/models/mistral/__init__.py b/training/torchtune/models/mistral/__init__.py new file mode 100644 index 0000000..cb8e859 --- /dev/null +++ b/training/torchtune/models/mistral/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import ( + lora_mistral, + lora_mistral_classifier, + mistral, + mistral_classifier, +) +from ._model_builders import ( + lora_mistral_7b, + lora_mistral_reward_7b, + mistral_7b, + mistral_reward_7b, + mistral_tokenizer, + qlora_mistral_7b, + qlora_mistral_reward_7b, +) +from ._prompt_template import MistralChatTemplate +from ._tokenizer import MistralTokenizer + +__all__ = [ + "MistralTokenizer", + "MistralChatTemplate", + "lora_mistral", + "lora_mistral_classifier", + "mistral", + "mistral_classifier", + "mistral_reward_hf_to_tune", + "mistral_reward_tune_to_hf", + "lora_mistral_7b", + "lora_mistral_reward_7b", + "mistral_7b", + "mistral_reward_7b", + "mistral_tokenizer", + "qlora_mistral_7b", + "qlora_mistral_reward_7b", +] diff --git a/training/torchtune/models/mistral/_component_builders.py b/training/torchtune/models/mistral/_component_builders.py new file mode 100644 index 0000000..e755e63 --- /dev/null +++ b/training/torchtune/models/mistral/_component_builders.py @@ -0,0 +1,639 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook +from typing import List + +from torch import nn + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + FrozenNF4Linear, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Mistral 7B models and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +def mistral( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, +) -> TransformerDecoder: + """ + Build the decoder associated with the mistral model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + This does NOT currently include inference-time optimizations such as + sliding-window attention + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + + Returns: + TransformerDecoder: Instantiation of mistral model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def mistral_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: + """ + Build the MLP layer associated with the Mistral model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +def lora_mistral( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # mistral args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Mistral (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with + intermediate_dim (int): intermediate dimension for MLP. + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms. + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Mistral model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_mistral_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_mistral_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + else: + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model + + +def lora_mistral_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 10_000, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + """ + Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + CausalSelfAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules") + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + LoRALinear( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) + ) + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_mistral_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) + + +def mistral_classifier( + num_classes: int, + *, + # base mistral args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, +) -> TransformerDecoder: + """ + Build a base mistral model with an added classification layer. + See :func:`~torchtune.models.mistral.mistral_classifier` + for details on the base mistral classifier model. + + Args: + num_classes (int): number of classes for the classification layer. + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + + Returns: + TransformerDecoder: Instantiation of mistral classification model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, num_classes, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def lora_mistral_classifier( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # mistral classifier args + num_classes: int, + # mistral args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Mistral classifier (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied to some of the linear layers in its self-attention modules. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + num_classes (int): number of classes for the classification layer. + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with + intermediate_dim (int): intermediate dimension for MLP. + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms. + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Mistral classifier model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_mistral_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_mistral_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + else: + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, num_classes, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model diff --git a/training/torchtune/models/mistral/_model_builders.py b/training/torchtune/models/mistral/_model_builders.py new file mode 100644 index 0000000..ebe56fa --- /dev/null +++ b/training/torchtune/models/mistral/_model_builders.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional + +from torchtune.models.mistral._component_builders import ( + mistral, + lora_mistral, + mistral_classifier, + lora_mistral_classifier, +) + +from torchtune.modules import TransformerDecoder +from torchtune.models.mistral._tokenizer import MistralTokenizer +from torchtune.modules.peft import LORA_ATTN_MODULES +from functools import partial + + +""" +Model builders build specific instantiations using component builders. For example +the ``mistral_7b`` model builder uses the ``mistral`` component builder. +""" + + +def mistral_7b() -> TransformerDecoder: + """ + Builder for creating a Mistral 7B model initialized w/ the default 7b parameter values + from https://mistral.ai/news/announcing-mistral-7b/ + + + Returns: + TransformerDecoder: Instantiation of Mistral 7B model + """ + return mistral( + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + intermediate_dim=14336, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def mistral_tokenizer(path: str, max_seq_len: Optional[int] = None) -> MistralTokenizer: + """ + Tokenizer for Mistral models. + + Args: + path (str): path to the tokenizer + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + + Returns: + MistralTokenizer: Instantiation of the Mistral tokenizer + """ + return MistralTokenizer(path=path, max_seq_len=max_seq_len) + + +def lora_mistral_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Mistral 7B model with LoRA enabled. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Mistral 7B model with LoRA applied + """ + return lora_mistral( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + intermediate_dim=14336, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=10_000, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +qlora_mistral_7b = partial(lora_mistral_7b, quantize_base=True) + +qlora_mistral_7b.__doc__ = """ +Builder for creating a Mistral model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_mistral_7b` for full API arguments. +""" + + +def mistral_reward_7b() -> TransformerDecoder: + """ + Builder for creating a Mistral 7B model initialized w/ the default 7b + parameter values from: + https://huggingface.co/Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback + where the output layer is a classification layer projecting to a single class for reward modelling. + + Returns: + TransformerDecoder: Instantiation of Mistral 7B classifier model + """ + return mistral_classifier( + num_classes=1, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + intermediate_dim=14336, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_mistral_reward_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Mistral reward 7B model with LoRA enabled. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Mistral 7B model with LoRA applied + """ + return lora_mistral_classifier( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + num_classes=1, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + intermediate_dim=14336, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=10_000, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +qlora_mistral_reward_7b = partial(lora_mistral_reward_7b, quantize_base=True) + +qlora_mistral_reward_7b.__doc__ = """ +Builder for creating a Mistral reward 7B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_mistral_reward_7b` for full API arguments. +""" diff --git a/training/torchtune/models/mistral/_prompt_template.py b/training/torchtune/models/mistral/_prompt_template.py new file mode 100644 index 0000000..93b364a --- /dev/null +++ b/training/torchtune/models/mistral/_prompt_template.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from torchtune.data import Message, PromptTemplateInterface + + +class MistralChatTemplate(PromptTemplateInterface): + """ + Formats according to Mistral's `instruct model + `_. + + It is identical to :class:`~torchtune.data.Llama2ChatTemplate`, except it does not support system + prompts. + + Note: + This template is only recommended for Mistral's Instruct-v0.1 and Instruct-v0.2 models. + Instruct-v0.3 adds additional tags for tool calls, which is not yet supported by this + template. + + .. code-block:: text + + "[INST] I am going to Paris, what should I see? [/INST] Paris, the capital + of France, is known for its stunning architecture..." + + """ + + template = { + "system": None, + "user": ("[INST] ", " [/INST] "), + "assistant": ("", ""), + "ipython": ("", ""), + } + + def __call__( + self, + messages: List[Message], + ) -> List[Message]: + """ + Format user and system messages with appropriate tags. + + Args: + messages (List[Message]): a single conversation, structured as a list + of `Message` objects + + Returns: + The formatted list of messages + + Raises: + ValueError: If system prompts are provided + """ + formatted_dialogue = [] + for message in messages: + if message.role == "system": + raise ValueError( + "System prompts are not supported in MistralChatTemplate" + ) + else: + content = ( + [{"type": "text", "content": self.template[message.role][0]}] + + message.content + + [{"type": "text", "content": self.template[message.role][1]}] + ) + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue diff --git a/training/torchtune/models/mistral/_tokenizer.py b/training/torchtune/models/mistral/_tokenizer.py new file mode 100644 index 0000000..7c63c17 --- /dev/null +++ b/training/torchtune/models/mistral/_tokenizer.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +from torchtune.data import Message +from torchtune.modules.tokenizers import ( + ModelTokenizer, + SentencePieceBaseTokenizer, + tokenize_messages_no_special_tokens, +) +from torchtune.modules.transforms import Transform + +WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] + + +class MistralTokenizer(ModelTokenizer, Transform): + """ + Mistral's implementation of the SentencePiece tokenizer + + Args: + path (str): Path to pretrained tokenizer file. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + + Examples: + >>> tokenizer = MistralTokenizer("/path/to/spm_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + max_seq_len: Optional[int] = None, + ): + self._spm_model = SentencePieceBaseTokenizer(path) + + # Original tokenizer has no pad_id, which causes indexing errors when batch training + self._spm_model.pad_id = 0 + + # During generation, stop when eos_id is encountered + self.stop_tokens = [self.eos_id] + + self.max_seq_len = max_seq_len + + @property + def eos_id(self): + return self._spm_model.eos_id + + @property + def bos_id(self): + return self._spm_model.bos_id + + @property + def pad_id(self): + return self._spm_model.pad_id + + @property + def vocab_size(self): + return self._spm_model.vocab_size + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + trim_leading_whitespace: bool = False, + ) -> List[int]: + """ + Encode a string into a list of token IDs + + Args: + text (str): The input text to be encoded, unbatched. + add_bos (bool): Whether to prepend BOS special token (Beginning of Sentence) to the input, defaults to True. + add_eos (bool): Whether to append EOS special token (End of Sentence) to the input, defaults to True. + trim_leading_whitespace (bool): Whether to trim leading whitespace from + underlying sentencepiece tokenization. Sentencepiece normally prepends + whitespace to any tokenized text, which can cause differences where + encode(s1) + encode(s2) != encode(s1 + s2) due to leading whitespace + added to s2. Default: False + Returns: + List[int]: The encoded token IDs. + """ + return self._spm_model.encode( + text, + add_bos=add_bos, + add_eos=add_eos, + trim_leading_whitespace=trim_leading_whitespace, + ) + + def decode( + self, + token_ids: List[int], + ) -> str: + """Decode token IDs to strings. + + Args: + token_ids (List[int]): The input token IDs to be decoded. + + Returns: + str: The decoded text. + """ + return self._spm_model.decode(token_ids) + + def tokenize_messages( + self, messages: List[Message] + ) -> Tuple[List[int], List[bool]]: + r"""Tokenize a list of messages one at a time then concatenate them, + returning a list of tokens and a list of masks. + + Note: + sentencepiece has problems where in general + encode(s1 + s2) != encode(s1) + encode(s2) due to whitespace handling. + We can get around this by prepending s2 with a known token and slicing the + beginning off the tokenized s2. + + Example: + >>> tokenizer = MistralTokenizer(tokenizer_path, max_seq_len) + >>> messages = [ + Message(role="system", content="system message\n", masked=True), + Message(role="user", content="user prompt\n", masked=True), + Message(role="assistant", content="assistant response\n"), + ] + + >>> # tokenize_messages encodes messages separately and concats + >>> tokenizer.tokenize_messages(messages)[0] + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + >>> # Same result as encoding the full string in one go + >>> tokenizer.encode(''.join([message.content for message in messages])) + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + Args: + messages (List[Message]): A list of messages, each containing role, content, + and masked attributes. + + Returns: + Tuple[List[int], List[bool]]: The tokenized messages + """ + return tokenize_messages_no_special_tokens( + tokenizer=self, + messages=messages, + bos_id=self.bos_id, + eos_id=self.eos_id, + max_seq_len=self.max_seq_len, + ) + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/training/torchtune/models/phi3/__init__.py b/training/torchtune/models/phi3/__init__.py new file mode 100644 index 0000000..114b3d4 --- /dev/null +++ b/training/torchtune/models/phi3/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import lora_phi3, phi3 # noqa +from ._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf # noqa +from ._model_builders import ( # noqa + lora_phi3_mini, + phi3_mini, + phi3_mini_tokenizer, + qlora_phi3_mini, +) +from ._position_embeddings import Phi3RotaryPositionalEmbeddings # noqa +from ._tokenizer import Phi3MiniTokenizer # noqa + +__all__ = [ + "phi3_mini", + "phi3_mini_tokenizer", + "lora_phi3_mini", + "qlora_phi3_mini", + "Phi3RotaryPositionalEmbeddings", + "Phi3MiniTokenizer", + "phi3_hf_to_tune", + "phi3_tune_to_hf", + "phi3", + "lora_phi3", +] diff --git a/training/torchtune/models/phi3/_component_builders.py b/training/torchtune/models/phi3/_component_builders.py new file mode 100644 index 0000000..3e84fe4 --- /dev/null +++ b/training/torchtune/models/phi3/_component_builders.py @@ -0,0 +1,424 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List + +from torch import nn + +from torchtune.models.phi3._position_embeddings import Phi3RotaryPositionalEmbeddings +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + FrozenNF4Linear, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) + +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Phi3 4K Mini Instruct model. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + +def phi3( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, +) -> TransformerDecoder: + """ + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + + Returns: + TransformerDecoder: Instantiation of Phi3 Mini 4K Instruct model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + +def phi3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: + """ + Build the MLP layer associated with the Phi3 Mini 4K Instruct model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +# ------------------ LoRA Phi3 ------------------ + + +def lora_phi3( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # phi3 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Phi3 (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms. + rope_base (int): base value for Rotary Position Embeddings. + Default: 10000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Llama3 model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_phi3_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_phi3_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial(reparametrize_as_dtype_state_dict_post_hook, + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True) + ) + + return model + + +def lora_phi3_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 10_000, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + """ + Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + rope_base (int): base value for Rotary Position Embeddings. + Default: 10000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + CausalSelfAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + LoRALinear( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) + ) + rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_phi3_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) diff --git a/training/torchtune/models/phi3/_convert_weights.py b/training/torchtune/models/phi3/_convert_weights.py new file mode 100644 index 0000000..e7215e9 --- /dev/null +++ b/training/torchtune/models/phi3/_convert_weights.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + + +_PHI3_MINI = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.norm.weight": "norm.scale", + "lm_head.weight": "output.weight", +} + + +def phi3_hf_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convertor from HF state dict to torchtune state dict. This handles: + - Splitting the fused q,k and v matrix + - Splitting the fused gate and up projection matrix + """ + converted_state_dict = {} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, _PHI3_MINI) + if "qkv" in key: + ( + q, + k, + v, + ) = value.chunk(3, dim=0) + converted_state_dict[new_key] = q + converted_state_dict[new_key.replace("q_proj", "k_proj")] = k + converted_state_dict[new_key.replace("q_proj", "v_proj")] = v + elif "gate" in key: + w1, w3 = value.chunk(2, dim=0) + converted_state_dict[new_key] = w1 + converted_state_dict[new_key.replace("w1", "w3")] = w3 + else: + converted_state_dict[new_key] = value + return converted_state_dict + + +def phi3_tune_to_hf(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convertor from torchtune state dict to HF state dict. This handles: + - Fusing q,k and v matrix + - Fusing gate and up projection matrix + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _PHI3_MINI.items()} + + for key, value in state_dict.items(): + if "k_proj" in key or "v_proj" in key or "w3" in key: + # these keys are accounted for separately and should be skipped + continue + new_key = get_mapped_key(key, inverted_mapping_dict) + + if "q_proj" in key: + q = value + k = state_dict[key.replace("q_proj", "k_proj")] + v = state_dict[key.replace("q_proj", "v_proj")] + qkv = torch.cat([q, k, v], dim=0) + # q_proj maps to qkv_proj; no need to string replace + converted_state_dict[new_key] = qkv + + elif "w1" in key: + gate_proj = value + up_proj = state_dict[key.replace("w1", "w3")] + gate_up_proj = torch.cat([gate_proj, up_proj], dim=0) + # w1 maps to gate_up_proj; no need to string replace + converted_state_dict[new_key] = gate_up_proj + + else: + converted_state_dict[new_key] = value + return converted_state_dict diff --git a/training/torchtune/models/phi3/_model_builders.py b/training/torchtune/models/phi3/_model_builders.py new file mode 100644 index 0000000..9208175 --- /dev/null +++ b/training/torchtune/models/phi3/_model_builders.py @@ -0,0 +1,122 @@ +from typing import List, Optional + +from torchtune.models.phi3._component_builders import phi3, lora_phi3 +from torchtune.models.phi3._tokenizer import Phi3MiniTokenizer + +from torchtune.modules import TransformerDecoder +from torchtune.modules.peft import LORA_ATTN_MODULES +from functools import partial +from torchtune.modules.tokenizers import parse_hf_tokenizer_json + + +""" +Model builders build specific instantiations using component builders. For example +the ``phi3_mini`` model builder uses the ``phi3`` component builder. +""" + + +def phi3_mini() -> TransformerDecoder: + """ + Builder for creating the Phi3 Mini 4K Instruct Model. + Ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + + Note: + This model does not currently support 128K context length nor optimizations + such as sliding window attention. + + Returns: + TransformerDecoder: Instantiation of Phi3 Mini 4K Instruct Model + """ + return phi3( + vocab_size=32_064, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=3072, + intermediate_dim=8192, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + ) + +def phi3_mini_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None) -> Phi3MiniTokenizer: + """Phi-3 Mini tokenizer. + Ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/tokenizer_config.json + + Args: + path (str): Path to the SPM tokenizer model. + special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face + model files that contains all registered special tokens, or a local json file + structured similarly. Default is None to use the canonical Phi3 special tokens. + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + + Note: + This tokenizer includes typical LM EOS and BOS tokens like + , , and . However, to support chat completion, + it is also augmented with special tokens like + and . + + Returns: + Phi3MiniSentencePieceBaseTokenizer: Instantiation of the SPM tokenizer. + """ + special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None + return Phi3MiniTokenizer(path=path, special_tokens=special_tokens, max_seq_len=max_seq_len) + + +def lora_phi3_mini( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Phi3 Mini (3.8b) model with LoRA enabled. + + The Phi3 defaults are the same as in :func:`~torchtune.models.phi3.phi3_mini`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Phi3 Mini model with LoRA applied + """ + return lora_phi3( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_064, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=3072, + intermediate_dim=8192, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=quantize_base, + ) + + +qlora_phi3_mini = partial(lora_phi3_mini, quantize_base=True) +qlora_phi3_mini.__doc__ = """ +Builder for creating a Phi3 mini model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_phi3_mini` for full API arguments. +""" diff --git a/training/torchtune/models/phi3/_position_embeddings.py b/training/torchtune/models/phi3/_position_embeddings.py new file mode 100644 index 0000000..9519f2d --- /dev/null +++ b/training/torchtune/models/phi3/_position_embeddings.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torch import nn, Tensor + + +class Phi3RotaryPositionalEmbeddings(nn.Module): + """ + RoPE Embeddings used in the Phi3 model. + Ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + + This class is not numerically equivalent to the RoPE Embedding module + used by Llama2 and Llama3. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ``embed_dim`` // ``num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self._rope_init() + + def _rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # We cache the cos and sin embeddings instead of the IDs. This helps + # ensure we have correct behavior when training with bf16 + # Size: [max_seq_len, (dim * 2)] + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [b, s, n_h, h_d] + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with RoPE applied + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + + TODO: The implementation below can be made more efficient + for inference. + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + head_dim = x.size(-1) + + # extract the values based on whether input_pos is set or not. When + # input_pos is provided, we're in inference mode + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d * 2] if packed samples, + # otherwise has shape [1, s, 1, h_d * 2] + rope_cache = rope_cache.view(-1, seq_len, 1, head_dim * 2) + + # [b, s, 1, h_d] + cos = rope_cache[..., :head_dim] + sin = rope_cache[..., head_dim:] + + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + rotated = torch.cat((-x2, x1), dim=-1) + + # cos: [b, s, 1, h_d] + # x: [b, s, n_h, h_d] + x_out = (x * cos) + (rotated * sin) + return x_out.type_as(x) diff --git a/training/torchtune/models/phi3/_tokenizer.py b/training/torchtune/models/phi3/_tokenizer.py new file mode 100644 index 0000000..221db90 --- /dev/null +++ b/training/torchtune/models/phi3/_tokenizer.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Mapping, Optional, Tuple + +from torchtune.data import Message, truncate +from torchtune.modules.tokenizers import ModelTokenizer, SentencePieceBaseTokenizer +from torchtune.modules.transforms import Transform + +PHI3_SPECIAL_TOKENS = { + "<|endoftext|>": 32000, + "<|assistant|>": 32001, + "<|placeholder1|>": 32002, + "<|placeholder2|>": 32003, + "<|placeholder3|>": 32004, + "<|placeholder4|>": 32005, + "<|system|>": 32006, + "<|end|>": 32007, + "<|placeholder5|>": 32008, + "<|placeholder6|>": 32009, + "<|user|>": 32010, +} + + +class Phi3MiniTokenizer(ModelTokenizer, Transform): + """ + SentencePiece tokenizer configured with Phi3 Mini's special tokens. + + Args: + path (str): Path to pretrained tokenizer file. + special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and + their registered token IDs. If left as None, this will be set to the canonical + Phi3 special tokens. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + + Examples: + >>> tokenizer = Phi3MiniTokenizer("/path/to/spm_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + special_tokens: Optional[Dict[str, int]] = None, + max_seq_len: Optional[int] = None, + ): + self._spm_model = SentencePieceBaseTokenizer(path) + + self.special_tokens = ( + special_tokens if special_tokens is not None else PHI3_SPECIAL_TOKENS + ) + + # Use custom EOS and pad ids instead of SentencePiece's + self.eos_id = self.special_tokens["<|endoftext|>"] + self.pad_id = self.special_tokens["<|endoftext|>"] + + # During generation, stop when eos_id is encountered + self.stop_tokens = [self.eos_id] + + self.max_seq_len = max_seq_len + + @property + def vocab_size(self): + return self._spm_model.vocab_size + + @property + def bos_id(self): + return self._spm_model.bos_id + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + trim_leading_whitespace: bool = False, + ) -> List[int]: + return self._spm_model.encode( + text, + add_bos=add_bos, + add_eos=add_eos, + trim_leading_whitespace=trim_leading_whitespace, + ) + + def decode(self, ids: List[int]) -> str: + """Decode token IDs to strings. + + Args: + ids (List[int]): The input token IDs to be decoded. + + Returns: + str: The decoded text. + """ + ids_for_decode = [] + for token_id in ids: + # Filter out special tokens and the placeholder tokens added + # by the Phi3 team + if token_id >= 32_000 and token_id <= 32_064: + continue + else: + ids_for_decode.append(token_id) + return self._spm_model.decode(ids_for_decode) + + def tokenize_messages( + self, + messages: List[Message], + *, + add_eos: bool = False, + ignore_system_prompts: bool = True, + ) -> Tuple[List[int], List[bool]]: + r"""Tokenize a list of messages one at a time then concatenate them, + returning a list of tokens and a list of masks. + + Example: + >>> tokenizer = Phi3MiniTokenizer(tokenizer_path, max_seq_len) + >>> messages = [ + Message(role="system", content="system message\n", masked=True), + Message(role="user", content="user prompt\n", masked=True), + Message(role="assistant", content="assistant response\n"), + ] + + >>> # tokenize_messages encodes messages separately and concats + >>> tokenizer.tokenize_messages(messages)[0] + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + >>> # Same result as encoding the full string in one go + >>> tokenizer.encode(''.join([message.content for message in messages])) + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + Args: + messages (List[Message]): A list of messages, each containing role, content, + and masked attributes. + add_eos (bool): Whether to append EOS after assistant message, default to False + ignore_system_prompts (bool): Whether to ignore system prompts. This matches the HF implementation, default to True. + + Raises: + ValueError: If the role is not "user", "assistant", or "system". + + Returns: + Tuple[List[int], List[bool]]: The tokenized messages + """ + start_of_turn = True + end_of_turn = False + tokenized_messages = [] + mask = [] + + # The chat template in HF adds a bunch of newlines + new_line_token_id = self.encode("\n", add_bos=False, add_eos=False) + + for message in messages: + # Skip system prompt + if ignore_system_prompts and message.role == "system": + continue + + # Prepend BOS on start of new turns + if start_of_turn: + tokenized_messages.append(self.bos_id) + mask.append(message.masked) + + # Add special tokens + if message.role == "user": + tokenized_messages.append(self.special_tokens["<|user|>"]) + mask.append(message.masked) + elif message.role == "assistant": + tokenized_messages.append(self.special_tokens["<|assistant|>"]) + # If assistant message, this is the end of a turn + end_of_turn = True + mask.append(message.masked) + elif message.role == "system": + tokenized_messages.append(self.special_tokens["<|system|>"]) + mask.append(message.masked) + else: + raise ValueError( + f"Unknown role '{message.role}' for message: '{message.content}'" + ) + + # Add new line token + tokenized_messages.extend(new_line_token_id) + mask.extend([message.masked] * len(new_line_token_id)) + + # Tokenize current message, append with masks + tokens = [] + for item in message.content: + if item["type"] == "text": + tokens = tokens + self.encode( + item["content"].rstrip(" "), + add_bos=False, + add_eos=False, + trim_leading_whitespace=True, # Always trim whitespace (just to match HF tokenizer implementation) + ) + else: + raise RuntimeError( + f"Unsupported message content type: {item['type']}" + ) + + tokens = tokens + [self.special_tokens["<|end|>"]] + new_line_token_id + tokenized_messages.extend(tokens) + mask.extend([message.masked] * len(tokens)) + + # If assistant message, append EOS at end + if end_of_turn and add_eos: + tokenized_messages.append(self.eos_id) + mask.append(message.masked) + end_of_turn = False + start_of_turn = True + else: + start_of_turn = False + + # Break out early if we reach max_seq_len + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: + break + + # Finally, truncate if necessary + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: + tokenized_messages = truncate( + tokenized_messages, self.max_seq_len, self.eos_id + ) + mask = truncate(mask, self.max_seq_len, message.masked) + + return tokenized_messages, mask + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/training/torchtune/models/qwen2/__init__.py b/training/torchtune/models/qwen2/__init__.py new file mode 100644 index 0000000..a2f36a3 --- /dev/null +++ b/training/torchtune/models/qwen2/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import lora_qwen2, qwen2 # noqa +from ._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf # noqa +from ._model_builders import ( + lora_qwen2_0_5b, + lora_qwen2_1_5b, + lora_qwen2_7b, + qwen2_0_5b, + qwen2_1_5b, + qwen2_7b, + qwen2_tokenizer, +) +from ._positional_embeddings import Qwen2RotaryPositionalEmbeddings +from ._tokenizer import Qwen2Tokenizer + +__all__ = [ + "qwen2_7b", + "qwen2_0_5b", + "qwen2_1_5b", + "qwen2_tokenizer", + "lora_qwen2_7b", + "lora_qwen2_0_5b", + "lora_qwen2_1_5b", + "qwen2", + "lora_qwen2", + "qwen2_hf_to_tune", + "qwen2_tune_to_hf", + "Qwen2RotaryPositionalEmbeddings", + "Qwen2Tokenizer", +] diff --git a/training/torchtune/models/qwen2/_component_builders.py b/training/torchtune/models/qwen2/_component_builders.py new file mode 100644 index 0000000..d89f5b9 --- /dev/null +++ b/training/torchtune/models/qwen2/_component_builders.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Union +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torch import nn + +from torchtune.modules.transformer import TransformerDecoder, TiedEmbeddingTransformerDecoder +from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + RMSNorm, + TransformerDecoderLayer, +) + + +from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Qwen2 model and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +def qwen2( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 1_000_000.0, + tie_word_embeddings: bool = False, +) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]: + """ + Build the decoder associated with the Qwen2 model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + rope_base (float): the base period of the RoPE embeddings. + tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied. + + Returns: + TransformerDecoder: Instantiation of Qwen2 model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = Qwen2RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=True), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = qwen2_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = None if tie_word_embeddings else nn.Linear(embed_dim, vocab_size, bias=False) + if output_proj is None: + return TiedEmbeddingTransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + ) + else: + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def qwen2_mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Build the MLP layer associated with the Qwen2 model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +def lora_qwen2( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # qwen2 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 1_000_000.0, + tie_word_embeddings: bool = False, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]: + """ + Return a version of Qwen2 (an instance of :func:`~torchtune.models.qwen2.transformer.Qwen2TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + rope_base (float): the base period of the RoPE embeddings. + tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Qwen2 model with LoRA applied to + a subset of the attention projections in each layer. + + Raises: + ValueError: if ``apply_lora_to_output`` and ``tie_word_embeddings``. + + """ + + self_attn = lora_qwen2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_qwen2_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = qwen2_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + if tie_word_embeddings: + if apply_lora_to_output: + raise ValueError( + "apply_lora_to_output is incompatible with tie_word_embeddings," + " as there would be no output to apply lora to!" + ) + output_proj = None + else: + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + if output_proj is None: + model = TiedEmbeddingTransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + ) + else: + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model + + +def lora_qwen2_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # CausalSelfAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: float = 1_000_000.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> CausalSelfAttention: + """ + Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + rope_base (float): the base period of the RoPE embeddings. Default: 1_000_000.0 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + CausalSelfAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules") + + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + q_proj = ( + LoRALinear( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=True, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else nn.Linear(embed_dim, num_heads * head_dim, bias=True) + ) + k_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=True, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True) + ) + v_proj = ( + LoRALinear( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=True, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True) + ) + output_proj = ( + LoRALinear( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else nn.Linear(embed_dim, embed_dim, bias=False) + ) + rope = Qwen2RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_qwen2_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, +) -> FeedForward: + gate_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = LoRALinear( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = LoRALinear( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) diff --git a/training/torchtune/models/qwen2/_convert_weights.py b/training/torchtune/models/qwen2/_convert_weights.py new file mode 100644 index 0000000..d23f674 --- /dev/null +++ b/training/torchtune/models/qwen2/_convert_weights.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +# state dict key mappings from HF's format to torchtune's format +_FROM_HF = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", + "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", + "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.norm.weight": "norm.scale", + "lm_head.weight": "output.weight", +} + + +QWEN2_TIED_KEY = "lm_head.weight" + + +def qwen2_hf_to_tune( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, + tie_word_embeddings: bool = False, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to TorchTune's format, which contains the weights + of a Qwen2 model. + State dicts from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but may not load + output projection weights. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in HF's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of the head. If not provided, it will be calculated + as dim // num_heads. + tie_word_embeddings (bool): Whether the model's input and output word embeddings should be tied. + + Returns: + Dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + if head_dim is None: + head_dim = dim // num_heads + + for key, value in state_dict.items(): + if ( + tie_word_embeddings and QWEN2_TIED_KEY in key + ): # Skip loading the output projection weights + continue + if "rotary_emb.inv_freq" in key: # Skip loading the position embeddings + continue + + new_key = get_mapped_key(key, _FROM_HF) + converted_state_dict[new_key] = value + return converted_state_dict + + +def qwen2_tune_to_hf( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, + tie_word_embeddings: bool = False, +): + """ + Convert a state dict from torchtune's format to HF's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of the head. If not provided, it will be calculated + as dim // num_heads. + tie_word_embeddings (bool): Whether the model's input and output word embeddings should be tied. + + Returns: + Dict[str, torch.Tensor]: State dict in HF's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()} + + if head_dim is None: + head_dim = dim // num_heads + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/training/torchtune/models/qwen2/_model_builders.py b/training/torchtune/models/qwen2/_model_builders.py new file mode 100644 index 0000000..b1af75e --- /dev/null +++ b/training/torchtune/models/qwen2/_model_builders.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional, Union + +from torchtune.models.qwen2._component_builders import qwen2, lora_qwen2 +from torchtune.models.qwen2._tokenizer import Qwen2Tokenizer +from torchtune.modules import TransformerDecoder, TiedEmbeddingTransformerDecoder +from torchtune.modules.peft import LORA_ATTN_MODULES +from torchtune.modules.tokenizers import parse_hf_tokenizer_json + +""" +Model builders build specific instantiations using component builders. For example +the qwen2_7b model builder uses the qwen2 component builder to create the +qwen2 7B model. +""" + + +def qwen2_7b() -> TransformerDecoder: + """ + Builder for creating a Qwen2 model initialized w/ the default 7B parameter values + from https://huggingface.co/Qwen/Qwen2-7B-Instruct + + Returns: + TransformerDecoder: Instantiation of Qwen2 7B model + """ + return qwen2( + vocab_size=152064, + num_layers=28, + num_heads=28, + num_kv_heads=4, + embed_dim=3584, + intermediate_dim=18944, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-06, + rope_base=1000000.0, + ) + + +def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder: + """ + Builder for creating a Qwen2 model initialized w/ the default 0.5B parameter values + from https://huggingface.co/Qwen/Qwen2-0.5B-Instruct + + Returns: + TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model + + Note: + Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default + and returns an instance of `TiedEmbeddingTransformerDecoder`. + """ + return qwen2( + vocab_size=151936, + num_layers=24, + num_heads=14, + num_kv_heads=2, + embed_dim=896, + intermediate_dim=4864, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-06, + rope_base=1000000.0, + tie_word_embeddings=True, + ) + + +def qwen2_1_5b() -> TiedEmbeddingTransformerDecoder: + """ + Builder for creating a Qwen2 model initialized w/ the default 1.5B parameter values + from https://huggingface.co/Qwen/Qwen2-1.5B-Instruct + + Returns: + TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model + + Note: + Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default + and returns an instance of `TiedEmbeddingTransformerDecoder`. + """ + return qwen2( + vocab_size=151936, + num_layers=28, + num_heads=12, + num_kv_heads=2, + embed_dim=1536, + intermediate_dim=8960, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-06, + rope_base=1000000.0, + tie_word_embeddings=True, + ) + + +def qwen2_tokenizer( + path: str, + merges_file: str = None, + special_tokens_path: Optional[str] = None, + max_seq_len: Optional[int] = None, + **kwargs, +) -> Qwen2Tokenizer: + """ + Tokenizer for Qwen2. + + Args: + path (str): path to the vocab.json file. + merges_file (str): path to the merges.txt file. + special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face + model files that contains all registered special tokens, or a local json file + structured similarly. Default is None to use the canonical Qwen2 special tokens. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + Returns: + Qwen2Tokenizer: Instantiation of the Qwen2 tokenizer + """ + special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None + return Qwen2Tokenizer(path=path, merges_file=merges_file, special_tokens=special_tokens, max_seq_len=max_seq_len, **kwargs) + + +def lora_qwen2_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Qwen2 7B model with LoRA enabled. + + The Qwen2 defaults are the same as in :func:`~torchtune.models.qwen2.qwen2_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Qwen2 7B model with LoRA applied + """ + return lora_qwen2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=152064, + num_layers=28, + num_heads=28, + num_kv_heads=4, + embed_dim=3584, + intermediate_dim=18944, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-6, + rope_base=1000000.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +def lora_qwen2_0_5b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TiedEmbeddingTransformerDecoder: + """ + Builder for creating a Qwen2 0.5B model with LoRA enabled. + + The Qwen2 defaults are the same as in :func:`~torchtune.models.qwen2.qwen2_0_5b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied + + Note: + Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default + and returns an instance of `TiedEmbeddingTransformerDecoder`. + """ + return lora_qwen2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=False, + vocab_size=151936, + num_layers=24, + num_heads=14, + num_kv_heads=2, + embed_dim=896, + intermediate_dim=4864, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-6, + rope_base=1000000.0, + tie_word_embeddings=True, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +def lora_qwen2_1_5b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TiedEmbeddingTransformerDecoder: + """ + Builder for creating a Qwen2 1.5B model with LoRA enabled. + + The Qwen2 defaults are the same as in :func:`~torchtune.models.qwen2.qwen2_1_5b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied + + Note: + Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default + and returns an instance of `TiedEmbeddingTransformerDecoder`. + """ + return lora_qwen2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=False, + vocab_size=151936, + num_layers=28, + num_heads=12, + num_kv_heads=2, + embed_dim=1536, + intermediate_dim=8960, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-6, + rope_base=1000000.0, + tie_word_embeddings=True, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) diff --git a/training/torchtune/models/qwen2/_positional_embeddings.py b/training/torchtune/models/qwen2/_positional_embeddings.py new file mode 100644 index 0000000..6ea5322 --- /dev/null +++ b/training/torchtune/models/qwen2/_positional_embeddings.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torch import nn, Tensor + + +class Qwen2RotaryPositionalEmbeddings(nn.Module): + """ + RoPE Embeddings used in the Qwen2 model. + Ref: https://huggingface.co/Qwen/Qwen2-7B-Instruct + + This class is not numerically equivalent to the RoPE Embedding module + used by Llama2 and Llama3. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ``embed_dim`` // ``num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (float): The base for the geometric progression used to compute + the rotation angles + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: float = 1_000_000.0, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self._rope_init() + + def _rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # We cache the cos and sin embeddings instead of the IDs. This helps + # ensure we have correct behavior when training with bf16 + # Size: [max_seq_len, (dim * 2)] + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [b, s, n_h, h_d] + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with RoPE applied + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + + TODO: The implementation below can be made more efficient + for inference. + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + head_dim = x.size(-1) + + # extract the values based on whether input_pos is set or not. When + # input_pos is provided, we're in inference mode + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d * 2] if packed samples, + # otherwise has shape [1, s, 1, h_d * 2] + rope_cache = rope_cache.view(-1, seq_len, 1, head_dim * 2) + + # [b, s, 1, h_d] + cos = rope_cache[..., :head_dim].to(x.dtype) + sin = rope_cache[..., head_dim:].to(x.dtype) + + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + rotated = torch.cat((-x2, x1), dim=-1) + + # cos: [b, s, 1, h_d] + # x: [b, s, n_h, h_d] + x_out = (x * cos) + (rotated * sin) + return x_out.type_as(x) diff --git a/training/torchtune/models/qwen2/_tokenizer.py b/training/torchtune/models/qwen2/_tokenizer.py new file mode 100644 index 0000000..42a8fac --- /dev/null +++ b/training/torchtune/models/qwen2/_tokenizer.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import json +import unicodedata +from functools import lru_cache +from typing import Dict, List, Optional, Tuple + +import regex as re + +from torchtune.data import Message, truncate +from torchtune.modules.tokenizers import ModelTokenizer + +PRETOKENIZE_REGEX = ( + r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + r"[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" +) + +QWEN2_SPECIAL_TOKENS = { + "<|endoftext|>": 151643, + "<|im_start|>": 151644, + "<|im_end|>": 151645, +} + + +ENDOFTEXT = "<|endoftext|>" +IM_START = "<|im_start|>" +IM_END = "<|im_end|>" + +DEFAULT_QWEN2_TOKENIZER_BPE_CACHE_SIZE = 151646 + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoid mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Qwen2Tokenizer(ModelTokenizer): + """This class construct a Qwen2 tokenizer, based on GPT-2 byte-level BPE tokenization. + + See . + + Args: + path (str): Path to vocab.json file. + merges_file (str): Path to merges.txt file. + merges.txt contains all BPE merge operations, and this file is required to split a single word into + byte-level BPE tokens. + special_tokens (Optional[Dict[str, int]]): Special tokens to add to the tokenizer. Default is None. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + errors (str): Paradigm to follow when decoding bytes to UTF-8. Defaults to "replace". + See [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (Optional[str]): The unknown token. A token that is not in the vocabulary cannot be converted + to an ID and is set to be this token instead. Defaults to ``<|endoftext|>``. + bos_token (Optional[str]): The beginning of sequence token. Defaults to None. + eos_token (str): The end of sequence token. Defaults to ``<|endoftext|>``. + pad_token (Optional[str]): The token used for padding. Defaults to ``<|endoftext|>``. + bpe_cache_size (int): BPE token cache size in Qwen2Tokenizer. + NOTE: large cache size will speed up tokenization, but the cache object will get really + large for long running processes (esp. for texts of language that do not use space between + word, e.g. Chinese); technically not a memory leak but appears as one. + By default, we set the cache size equals to size of the official Qwen2 tokenizer. + + Attributes: + system (str): Qwen2 system prompt. + user (str): Qwen2 user prompt. + assistant (str): Qwen2 assistant prompt. + assistant_for_generation (str): Qwen2 assistant prompt for generation. + + Example: + >>> tokenizer = Qwen2Tokenizer(path="/path/to/vocab.json", merges_file="/path/to/merges.txt") + >>> tokenized_text = tokenizer.encode("Hello world!") + >>> print(tokenized_text) + [39, 385, 78, 675, 0, 2000] + """ + + system: str = f"{IM_START}system\n{{content}}{IM_END}\n" + user: str = f"{IM_START}user\n{{content}}{IM_END}\n" + assistant: str = f"{IM_START}assistant\n{{content}}{IM_END}\n" + assistant_for_generation: str = f"{IM_START}assistant\n" + + def __init__( + self, + path: str, + merges_file: str, + special_tokens: Optional[Dict[str, int]] = None, + max_seq_len: Optional[int] = None, + *, + errors: str = "replace", + unk_token: Optional[str] = ENDOFTEXT, + bos_token: Optional[str] = None, + eos_token: str = ENDOFTEXT, + pad_token: Optional[str] = ENDOFTEXT, + bpe_cache_size: int = DEFAULT_QWEN2_TOKENIZER_BPE_CACHE_SIZE, + ): + with open(path, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_merges = [] + with open(merges_file, encoding="utf-8") as merges_handle: + for i, line in enumerate(merges_handle): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + bpe_merges.append(tuple(line.split())) + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + + self._bpe = lru_cache(maxsize=bpe_cache_size)(self._bpe_without_cache) + + self.pat = re.compile(PRETOKENIZE_REGEX) + + self.special_tokens = ( + special_tokens if special_tokens is not None else QWEN2_SPECIAL_TOKENS + ) + self._special_tokens_reversed = {v: k for k, v in self.special_tokens.items()} + + self.unk_id = None if unk_token is None else self.special_tokens[unk_token] + self.bos_id = None if bos_token is None else self.special_tokens[bos_token] + self.eos_id = None if eos_token is None else self.special_tokens[eos_token] + self.pad_id = None if pad_token is None else self.special_tokens[pad_token] + self.im_start_id = self.special_tokens[IM_START] + self.im_end_id = self.special_tokens[IM_END] + self.stop_tokens = [self.eos_id, self.im_end_id] + + # Pattern for special tokens. + self._pattern_split_special_tokens = re.compile( + r"(\L)", options=self.special_tokens.keys() + ) + + self.max_seq_len = max_seq_len + + def _bpe_without_cache(self, token): + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self._bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.unk_id) + + def encode( + self, text: str, add_bos: bool = True, add_eos: bool = True + ) -> List[int]: + """ + Encode a string into a list of token ids. + + Args: + text (str): The string to encode. + add_bos (bool): (Optional) Whether to add the beginning of sequence token. + add_eos (bool): (Optional) Whether to add the end of sequence token. + + Returns: + List[int]: The list of token ids. + + Note: + This method follows + and + . + """ + + text = unicodedata.normalize("NFC", text) + + tokens = self._pattern_split_special_tokens.split(text) + + tokenized_text = [] + for token in tokens: + if not token: + continue + if token in self.special_tokens: + tokenized_text.append(token) + else: + tokenized_text.extend(self._tokenize(token)) + + # Convert tokenized text to token ids. + token_ids = [] + if add_bos and self.bos_id is not None: + token_ids.append(self.bos_id) + for token in tokenized_text: + if token in self.special_tokens: + token_id = self.special_tokens[token] + else: + token_id = self._convert_token_to_id(token) + token_ids.append(token_id) + if add_eos and self.eos_id is not None: + token_ids.append(self.eos_id) + + return token_ids + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) in a token (str) using the vocab.""" + token = self._special_tokens_reversed.get(index, None) + if token is None: + return self.decoder.get(index) + return token + + def _convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) + return text + + def decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + ) -> str: + """ + Decode a list of token ids into a string. + + Args: + token_ids (List[int]): The list of token ids. + skip_special_tokens (bool): Whether the special tokens should be removed from the decoded string. + + Returns: + str: The decoded string. + """ + sub_texts = [] + current_sub_text = [] + for token_id in token_ids: + token = self._convert_id_to_token(token_id) + if token_id in self._special_tokens_reversed: + if current_sub_text: + string = self._convert_tokens_to_string(current_sub_text) + if string: + sub_texts.append(string) + current_sub_text = [] + if not skip_special_tokens: + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self._convert_tokens_to_string(current_sub_text)) + + text = "".join(sub_texts) + return text + + def tokenize_messages( + self, + messages: List[Message], + apply_chat_template: bool = True, + ) -> Tuple[List[int], List[bool]]: + """ + Given a list of messages, return a list of tokens for the concatenated + and formatted messages. + + Args: + messages (List[Message]): The message list to tokenize. + apply_chat_template (bool): Whether to apply Qwen2 chat template. + + Returns: + Tuple[List[int], List[bool]]: The list of token ids and the list of masks. + """ + tokens = [] + mask = [] + is_generation = False + for index, message in enumerate(messages): + content = "" + if message.role == "system": + content = self.system.format(content=message.text_content) + elif message.role == "user": + content = self.user.format(content=message.text_content) + elif message.role == "assistant": + if index == len(messages) - 1 and not message.text_content: + content = self.assistant_for_generation + is_generation = True + else: + content = self.assistant.format(content=message.text_content) + tokenized_message = self.encode(content, add_bos=False, add_eos=False) + tokens.extend(tokenized_message) + mask.extend([message.masked] * len(tokenized_message)) + + if self.max_seq_len and len(tokens) >= self.max_seq_len: + break + + if not is_generation: + tokens = tokens + [self.eos_id] + last_message_masked = False + if messages: + last_message_masked = messages[-1].masked + mask = mask + [last_message_masked] + if self.max_seq_len: + tokens = truncate(tokens, self.max_seq_len, self.eos_id) + mask = truncate(mask, self.max_seq_len, True) + return tokens, mask diff --git a/training/torchtune/modules/__init__.py b/training/torchtune/modules/__init__.py new file mode 100644 index 0000000..51878b3 --- /dev/null +++ b/training/torchtune/modules/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .attention import CausalSelfAttention # noqa +from .common_utils import reparametrize_as_dtype_state_dict_post_hook +from .feed_forward import FeedForward # noqa +from .kv_cache import KVCache # noqa +from .layer_norm import Fp32LayerNorm # noqa +from .low_precision import FrozenNF4Linear # noqa +from .lr_schedulers import get_cosine_schedule_with_warmup # noqa +from .position_embeddings import RotaryPositionalEmbeddings # noqa +from .rms_norm import RMSNorm # noqa +from .transformer import ( # noqa + TiedEmbeddingTransformerDecoder, + TransformerDecoder, + TransformerDecoderLayer, +) +from .vision_transformer import VisionTransformer + +__all__ = [ + "CausalSelfAttention", + "FeedForward", + "FrozenNF4Linear", + "get_cosine_schedule_with_warmup", + "KVCache", + "RotaryPositionalEmbeddings", + "RMSNorm", + "Fp32LayerNorm", + "VisionTransformer", + "TransformerDecoder", + "TiedEmbeddingTransformerDecoder", + "TransformerDecoderLayer", + "reparametrize_as_dtype_state_dict_post_hook", +] diff --git a/training/torchtune/modules/attention.py b/training/torchtune/modules/attention.py new file mode 100644 index 0000000..a107112 --- /dev/null +++ b/training/torchtune/modules/attention.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor +from torchtune.modules.kv_cache import KVCache +from torchtune.utils.logging import get_logger + +logger = get_logger("DEBUG") + + +class CausalSelfAttention(nn.Module): + """Multi-headed grouped query self-attention (GQA) layer introduced + in https://arxiv.org/abs/2305.13245v1. + + GQA is a version of multiheaded attention (MHA) which uses fewer + key/value heads than query heads by grouping n query heads for each + key and value head. Multi-Query Attention is an extreme + version where we have a single key and value head shared by all + query heads. + + Following is an example of MHA, GQA and MQA with num_heads = 4 + + (credit for the documentation: + https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py). + + + :: + + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ │ │ │ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + MHA GQA MQA + n_kv_heads =4 n_kv_heads=2 n_kv_heads=1 + + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + head_dim (int): dimension of each head, calculated by ``embed_dim`` // ``num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (nn.Module): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value. + If not specified, then no caching is used. + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. + + Raises: + ValueError: If `num_heads` % `num_kv_heads` != 0 + ValueError: If `embed_dim` % `num_heads` != 0 + ValueError: If `attn_dropout` < 0 or > 1 + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: nn.Module, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + attn_dropout: float = 0.0, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.pos_embeddings = pos_embeddings + + def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=self.max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: Tensor, + *, + mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [batch_size x seq_length x embed_dim] + mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + with shape [batch_size x seq_length x seq_length]. This is applied after + the query-key multiplication and before the softmax. A value of True in row i + and column j means token i attends to token j. A value of False means token i + does not attend to token j. If no mask is specified, a causal mask + is used by default. Default is None. + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with attention applied + + Raises: + ValueError: if seq_len of x is bigger than max_seq_len + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + + TODO: + - Return the attention weights + - Make application of positional embeddings optional + """ + # input has shape [b, s, d] + bsz, seq_len, _ = x.shape + + if self.kv_cache and input_pos is None: + cache_size = self.kv_cache.size + input_pos = torch.arange(cache_size, cache_size + seq_len, device=x.device) + + if seq_len > self.max_seq_len: + raise ValueError( + f"seq_len ({seq_len}) of input tensor should be smaller " + f"than max_seq_len ({self.max_seq_len})" + ) + + # q has shape [b, s, num_heads * head_dim] + # k has shape [b, s, num_kv_heads * head_dim] + # v has shape [b, s, num_kv_heads * head_dim] + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + + # q: [b, s, n_kv, q_per_kv, h_d] + # k: [b, s, n_kv, 1, h_d] + # v: [b, s, n_kv, 1, h_d] + q = q.view(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim) + k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + + # if needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim) + + # llama2 applies the RoPE embeddings on tensors with shape + # [b, s, n_h, h_d] + # Reshape the tensors before we apply RoPE + q = q.reshape(bsz, seq_len, -1, self.head_dim) + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # Apply positional embeddings + q = self.pos_embeddings(q, input_pos=input_pos) + k = self.pos_embeddings(k, input_pos=input_pos) + + # [b, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update key-value cache + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + # shape: [b, 1, s, s] + if mask is not None: + mask = mask[:, None, :, :] + + # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/ + output = nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.attn_dropout, + is_causal=self.kv_cache is None and mask is None, + ) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + return self.output_proj(output) diff --git a/training/torchtune/modules/common_utils.py b/training/torchtune/modules/common_utils.py new file mode 100644 index 0000000..9c588fa --- /dev/null +++ b/training/torchtune/modules/common_utils.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Tuple + +import torch + +import torch.nn as nn +from torchao.dtypes.nf4tensor import NF4Tensor + + +def reparametrize_as_dtype_state_dict_post_hook( + model: nn.Module, + state_dict: Dict[str, Any], + *args: Tuple[Any, ...], + dtype: torch.dtype = torch.bfloat16, + offload_to_cpu: bool = True, + **kwargs: Dict[Any, Any], +): + """ + A state_dict hook that replaces NF4 tensors with their restored + higher-precision weight and optionally offloads the restored weight to CPU. + Use this hook to avoid increased peak GPU memory usage during checkpoint + save when training with QLoRA. + + This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e. + + >>> m = MyModule() + >>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook) + + If the hook is registered per the above process, this hook will be called _after_ the module's + ``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing + them to the original dtype, and optionally offload the restored weight to CPU. + + Args: + model (nn.Module): the model to take ``state_dict()`` on + state_dict (Dict[str, Any]): the state dict to modify + *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook. + dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``. + offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``. + **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook. + """ + for k, v in state_dict.items(): + if isinstance(v, NF4Tensor): + state_dict[k] = v.to(dtype) + if offload_to_cpu: + state_dict[k] = state_dict[k].cpu() diff --git a/training/torchtune/modules/feed_forward.py b/training/torchtune/modules/feed_forward.py new file mode 100644 index 0000000..c69cd17 --- /dev/null +++ b/training/torchtune/modules/feed_forward.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn, Tensor + + +class FeedForward(nn.Module): + """This class implements the feed-forward network derived from Llama2. + + Args: + gate_proj (nn.Module): Projection from input dim to hidden dim, fed through activation + and multiplied by up_proj. + down_proj (nn.Module): Final projection to output dim. + up_proj (nn.Module): Projection from input dim to hidden dim, multiplied by + activation(gate_proj). + activation (nn.Module): Activation function to use. Default is nn.SiLU(). + """ + + def __init__( + self, + *, + gate_proj: nn.Module, + down_proj: nn.Module, + up_proj: nn.Module, + activation: nn.Module = nn.SiLU(), + ): + super().__init__() + self.w1 = gate_proj + self.w2 = down_proj + self.w3 = up_proj + self.activation = activation + + def forward(self, x: Tensor) -> Tensor: + return self.w2(self.activation(self.w1(x)) * self.w3(x)) diff --git a/training/torchtune/modules/kv_cache.py b/training/torchtune/modules/kv_cache.py new file mode 100644 index 0000000..1ad55a4 --- /dev/null +++ b/training/torchtune/modules/kv_cache.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from torch import nn, Tensor + + +class KVCache(nn.Module): + """ + Standalone ``nn.Module`` containing a kv-cache to cache past key and values during inference. + + Args: + batch_size (int): batch size model will be run with + max_seq_len (int): maximum sequence length model will be run with + num_heads (int): number of heads. We take num_heads instead of num_kv_heads because + the cache is created after we've expanded the key and value tensors to have the + same shape as the query tensor. See attention.py for more details + head_dim (int): per-attention head embedding dimension + dtype (torch.dtype): dtype for the caches + """ + + def __init__( + self, + batch_size: int, + max_seq_len: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + ) -> None: + super().__init__() + cache_shape = (batch_size, num_heads, max_seq_len, head_dim) + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False + ) + self.size = 0 + self.batch_size = batch_size + + def reset(self) -> None: + """Reset the cache to zero.""" + self.k_cache.zero_() + self.v_cache.zero_() + + def update( + self, input_pos: Tensor, k_val: Tensor, v_val: Tensor + ) -> Tuple[Tensor, Tensor]: + """Update KV cache with the new k_val, v_val and return the updated cache. + + Raises an assertion error if ``input_pos`` is longer than the maximum sequence length. + + Args: + input_pos (Tensor): Current position tensor with shape [S] + k_val (Tensor): Current key tensor with shape [B, H, S, D] + v_val (Tensor): Current value tensor with shape [B, H, S, D] + + Returns: + Tuple[Tensor, Tensor]: Updated KV cache with key first + """ + assert input_pos.shape[0] == k_val.shape[2] + self.size = input_pos.max().item() + 1 + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out diff --git a/training/torchtune/modules/layer_norm.py b/training/torchtune/modules/layer_norm.py new file mode 100644 index 0000000..a2154e7 --- /dev/null +++ b/training/torchtune/modules/layer_norm.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any + +import torch +from torch import nn + + +class Fp32LayerNorm(nn.LayerNorm): + """ + Wrapper around :class:`~torch.nn.LayerNorm` to support mixed-precision training. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The normalized output tensor. + """ + output = nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) diff --git a/training/torchtune/modules/loss/__init__.py b/training/torchtune/modules/loss/__init__.py new file mode 100644 index 0000000..f44df6b --- /dev/null +++ b/training/torchtune/modules/loss/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from .dpo import DPOLoss, IPOLoss, RSOLoss, SimPOLoss +from .ppo import PPOLoss + +__all__ = ["DPOLoss", "RSOLoss", "IPOLoss", "SimPOLoss", "PPOLoss"] diff --git a/training/torchtune/modules/loss/dpo.py b/training/torchtune/modules/loss/dpo.py new file mode 100644 index 0000000..af9ca25 --- /dev/null +++ b/training/torchtune/modules/loss/dpo.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DPOLoss(nn.Module): + """ + Direct Preference Optimization (DPO) Loss module: https://arxiv.org/abs/2305.18290. + Simply stated from the paper: + + Intuitively, the DPO update increases the relative log probability of preferred to dispreferred responses, + but it incorporates a dynamic, per-example importance weight that prevents + the model degeneration that we find occurs with a naive probability ratio objective. + + Based on the implementation in HF's TRL library: + https://github.com/huggingface/trl/blob/5d1deb1445828cfd0e947cb3a7925b1c03a283fc/trl/trainer/dpo_trainer.py#L844 + + DPO retains similarities to PPO (https://arxiv.org/abs/2009.01325), where it optimizes a policy + (language) model to align with human preferences, and regularizes the loss function using a baseline + reference (the frozen, initial language model) to prevent over-fitting to the preference dataset. + It differs from PPO by optimizing the policy model directly using labelled preference data, rather + than using an additional reward model to provide feedback. + This significantly simplifies training and reduces compute overhead. + + Args: + beta (float): Temperature parameter for the DPO loss, typically in the range of 0.1 to 0.5. Default is 0.1. + label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0. + """ + + def __init__( + self, + beta: float = 0.1, + label_smoothing: float = 0.0, + ): + super().__init__() + self.beta = beta + self.label_smoothing = label_smoothing + + def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps (torch.Tensor): Log probabilities of the policy model + for the chosen responses. Shape: (batch_size) + policy_rejected_logps (torch.Tensor): Log probabilities of the policy model + for the rejected responses. Shape: (batch_size) + reference_chosen_logps (torch.Tensor): Log probabilities of the reference model + for the chosen responses. Shape: (batch_size) + reference_rejected_logps (torch.Tensor): Log probabilities of the reference model + for the rejected responses. Shape: (batch_size) + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors: + - losses: The DPO loss for each example in the batch. + - chosen_rewards: Rewards for the chosen responses. + - rejected_rewards: Rewards for the rejected responses. + + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + logits = pi_logratios - ref_logratios + + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative DPO loss. + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + + chosen_rewards = ( + self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + ) + rejected_rewards = ( + self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + ) + + return losses, chosen_rewards, rejected_rewards + + +class RSOLoss(nn.Module): + """ + Statistical Rejection Sampling Optimization (RSO) or "hinge" loss module: https://arxiv.org/abs/2309.06657. + Intuition from the paper: + + DPO is a logistic regression on human preference data, and SLiC (https://arxiv.org/abs/2305.10425) is almost + equivalent to a support vector machine (SVM) with hinge loss. [RSO] improve[s] SLiC as the SVM counter part of DPO. + + Based on the implementation in HF's TRL library: + https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1141 + + Args: + gamma (float): Equivalent temperature parameter (from DPO) for the RSO loss. + """ + + def __init__( + self, + gamma: float = 0.1, + ): + super().__init__() + self.gamma = gamma + + def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the RSO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps (torch.Tensor): Log probabilities of the policy model + for the chosen responses. Shape: (batch_size) + policy_rejected_logps (torch.Tensor): Log probabilities of the policy model + for the rejected responses. Shape: (batch_size) + reference_chosen_logps (torch.Tensor): Log probabilities of the reference model + for the chosen responses. Shape: (batch_size) + reference_rejected_logps (torch.Tensor): Log probabilities of the reference model + for the rejected responses. Shape: (batch_size) + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors: + - losses: The RSO loss for each example in the batch. + - chosen_rewards: Rewards for the chosen responses. + - rejected_rewards: Rewards for the rejected responses. + + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + logits = pi_logratios - ref_logratios + + losses = torch.relu(1 - self.gamma * logits) + + chosen_rewards = ( + self.gamma * (policy_chosen_logps - reference_chosen_logps).detach() + ) + rejected_rewards = ( + self.gamma * (policy_rejected_logps - reference_rejected_logps).detach() + ) + + return losses, chosen_rewards, rejected_rewards + + +class IPOLoss(nn.Module): + """ + Identity Preference Optimisation (IPO) Loss module: https://arxiv.org/abs/2310.12036. + Intuition from the paper: + + (Given a policy pi and reference policy, pi_ref) + + IPO learns from preferences dataset simply by regressing the gap between log-likelihood ratios + + log(pi(chosen)/pi(rejected)) and log(pi_ref(chosen)/pi_ref(rejected)) + + to 1/(2*tau), where tau is the temperature parameter. [T]he weaker the regularisation becomes, the + higher would be the log-likelihood ratio of chosen to rejected logprobs. In other words IPO, unlike DPO, + always regularizes its solution towards pi_ref by controlling the gap between the log-likelihood ratios + + log(pi(chosen)/pi(rejected)) and log(pi_ref(chosen)/pi_ref(rejected)) + + thus avoiding the over-fitting to the preference dataset. + + Based on the implementation in HF's TRL library: + https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1143 + + + Args: + tau (float): Equivalent temperature scaling parameter (from DPO) for the IPO loss. From the TRL documentation: + + the [tau] parameter is the reciprocal of the gap between the log-likelihood ratios of the + chosen vs the rejected completion pair and thus the smaller the tau the larger this gap is. + """ + + def __init__( + self, + tau: float = 0.1, + ): + super().__init__() + self.tau = tau + + def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps (torch.Tensor): Log probabilities of the policy model + for the chosen responses. Shape: (batch_size) + policy_rejected_logps (torch.Tensor): Log probabilities of the policy model + for the rejected responses. Shape: (batch_size) + reference_chosen_logps (torch.Tensor): Log probabilities of the reference model + for the chosen responses. Shape: (batch_size) + reference_rejected_logps (torch.Tensor): Log probabilities of the reference model + for the rejected responses. Shape: (batch_size) + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors: + - losses: The DPO loss for each example in the batch. + - chosen_rewards: Rewards for the chosen responses. + - rejected_rewards: Rewards for the rejected responses. + + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + logits = pi_logratios - ref_logratios + + losses = (logits - 1 / (2 * self.tau)) ** 2 + + chosen_rewards = ( + self.tau * (policy_chosen_logps - reference_chosen_logps).detach() + ) + rejected_rewards = ( + self.tau * (policy_rejected_logps - reference_rejected_logps).detach() + ) + + return losses, chosen_rewards, rejected_rewards + + +class SimPOLoss(nn.Module): + """ + SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734. + Intuition from the paper: + + The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as + the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to + encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. + + Based on the TRL implementation: + https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603 + + SimPO is pretty much identitcal to DPO but uses average logprobs to eliminate the need for a reference model to regularize + the policy during training. It also uses a target reward margin to guide the policy towards better responses. + This is kind of the same intuition in:class:`~torchtune.modules.loss.IPO`, but instead of optimizing against a margin + between the reference policy and policy models, we're optimizing against a margin between the chosen and rejected responses. + + Args: + beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5. Default is 2.0. + gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``. + Default is 0.5. + label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0. + """ + + def __init__( + self, + beta: float = 2.0, + gamma: float = 0.5, + label_smoothing: float = 0.0, + ): + super().__init__() + self.beta = beta + self.gamma = gamma + self.label_smoothing = label_smoothing + + def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the SimPO loss for a batch chosen and rejected average log probabilities. + + Args: + policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model + for the chosen responses with shape [b,]. + policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model + for the rejected responses with shape [b,]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]; A tuple of three tensors with shape [b,]: + - losses: The SimPO loss for each example in the batch. + - chosen_rewards: Rewards for the chosen responses. + - rejected_rewards: Rewards for the rejected responses. + """ + + pi_logratios = policy_chosen_logps - policy_rejected_logps + + gamma_logratios = self.gamma / self.beta + logits = pi_logratios - gamma_logratios + + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + + chosen_rewards = self.beta * (policy_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards diff --git a/training/torchtune/modules/loss/ppo.py b/training/torchtune/modules/loss/ppo.py new file mode 100644 index 0000000..0cef4a5 --- /dev/null +++ b/training/torchtune/modules/loss/ppo.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torchtune.modules import rlhf + + +class PPOLoss(nn.Module): + """ + Proximal Policy Optimization (PPO) Loss module. + This implementation uses the following references: + + https://arxiv.org/abs/1707.06347 eqn. 7 + + https://github.com/vwxyzjn/lm-human-preference-details/blob/ccc19538e817e98a60d3253242ac15e2a562cb49/lm_human_preference_details/train_policy_accelerate.py#L719 + + https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75 + + + Args: + epsilon (float): clipping range for PPO update. + value_clip_range (float): clipping range for value function update. + value_coeff (float): coefficient for the value function loss contribution. + """ + + def __init__( + self, + epsilon: float = 0.1, + value_clip_range: float = 0.2, + value_coeff: float = 0.1, + ): + super().__init__() + self.epsilon = epsilon + self.value_clip_range = value_clip_range + self.value_coeff = value_coeff + + def forward( + self, + pi_old_logprobs: torch.Tensor, + pi_logprobs: torch.Tensor, + advantages: torch.Tensor, + phi_old_values: torch.Tensor, + phi_values: torch.Tensor, + returns: torch.Tensor, + padding_masks: Optional[torch.Tensor] = None, + value_padding_masks: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + + Forward pass of the PPO loss module. + + Args: + pi_old_logprobs (torch.Tensor): Log probabilities of the old policy. + pi_logprobs (torch.Tensor): Log probabilities of the current policy. + advantages (torch.Tensor): Advantage values. + phi_old_values (torch.Tensor): Value predictions of the old value function. + phi_values (torch.Tensor): Value predictions of the current value function. + returns (torch.Tensor): Return values. + padding_masks (Optional[torch.Tensor]): Padding token masks of the same shape as ``pi_logprobs``, + where True indicates the corresponding loss values should participage in policy loss calculation. + value_padding_masks (Optional[torch.Tensor]): Padding token masks of the same shape as ``pi_logprobs``, + where True indicates the corresponding loss values should participage in value loss calculation. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of five tensors: + - loss: The total PPO loss. + - policy_loss: The policy function loss. + - value_loss: The value function loss. + - ratios: The ratio between the current and old policy probabilities. + - clipfrac: The fraction of ratios that were clipped. + + """ + ratios = torch.exp(pi_logprobs - pi_old_logprobs) + clipped_ratios = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon) + + policy_losses_clipped = -advantages * clipped_ratios + policy_losses_unclipped = -advantages * ratios + + clipfrac = (policy_losses_clipped > policy_losses_unclipped).float() + clipfrac = ( + clipfrac.mean() + if padding_masks is None + else rlhf.masked_mean(clipfrac, padding_masks) + ) + + policy_loss = torch.maximum(policy_losses_clipped, policy_losses_unclipped) + policy_loss = ( + policy_loss.mean() + if padding_masks is None + else rlhf.masked_mean(policy_loss, padding_masks) + ) + + values_clipped = torch.clamp( + phi_values, + phi_old_values - self.value_clip_range, + phi_old_values + self.value_clip_range, + ) + value_loss = torch.maximum( + (phi_values - returns) ** 2, (values_clipped - returns) ** 2 + ) + value_loss = ( + 0.5 * value_loss.mean() + if value_padding_masks is None + else 0.5 * rlhf.masked_mean(value_loss, value_padding_masks) + ) + + loss = policy_loss + (value_loss * self.value_coeff) + return ( + loss, + policy_loss.detach(), + value_loss.detach(), + ratios.mean().detach(), + clipfrac.detach(), + ) diff --git a/training/torchtune/modules/low_precision/__init__.py b/training/torchtune/modules/low_precision/__init__.py new file mode 100644 index 0000000..8bf6448 --- /dev/null +++ b/training/torchtune/modules/low_precision/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .nf4_linear import FrozenNF4Linear + +__all__ = [ + "FrozenNF4Linear", +] diff --git a/training/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py b/training/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py new file mode 100644 index 0000000..b057a1f --- /dev/null +++ b/training/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchao.dtypes.nf4tensor import implements as nf4_tensor_impl, to_nf4 +from torchtune.modules.low_precision._utils import _get_torchao_version + + +@nf4_tensor_impl([torch.ops.aten.clone.default]) +def clone(func, *args, **kwargs): + """ + __torch_dispatch__ override that is called when cloning an NF4Tensor. + This is implemented by creating a new NF4Tensor with the unquantized weight + of the input tensor. Note that this is not an exact "clone" due to the loss + in precision. + """ + return to_nf4(args[0][0].get_original_weight()) + + +should_define_inplace_copy = True +ao_version, is_nightly = _get_torchao_version() +if ao_version: + if (is_nightly and ao_version >= "2024.5.20") or ( + not is_nightly and ao_version >= "0.2.0" + ): + should_define_inplace_copy = False + +if should_define_inplace_copy: + # TorchAO have `NF4.copy_` starting from `0.2.0` + # it's a superset of `inplace_copy` since it covers `NF4.copy_(NF4)` + @nf4_tensor_impl([torch.ops.aten.copy_.default]) + def inplace_copy(func, *args, **kwargs): + """ + Performs an inplace copy of an incoming tensor into the tensor + being copied into. The inplace tensor is given by args[0][1] and the + tensor being copied into is given by args[0][0]. The copy is performed + by copying over all attributes. This method would have to be updated + if additional attributes are added to NF4Tensor. + """ + dest_tensor = args[0][0] # tensor we are inplace copying into + ref_tensor = to_nf4( + args[0][1].to(dest_tensor.device) + ) # TODO check if nf4 tensor takes in device arg + dest_tensor.block_size = ref_tensor.block_size + dest_tensor.n_blocks = ref_tensor.n_blocks + dest_tensor.scaler_block_size = ref_tensor.scaler_block_size + dest_tensor.quantized_scalers = ref_tensor.quantized_scalers + dest_tensor.quantization_factor = ref_tensor.quantization_factor + dest_tensor.scaler_mean = ref_tensor.scaler_mean + dest_tensor.quantized_data = ref_tensor.quantized_data + dest_tensor.nf4 = ref_tensor.nf4 diff --git a/training/torchtune/modules/low_precision/_utils.py b/training/torchtune/modules/low_precision/_utils.py new file mode 100644 index 0000000..60ddf0b --- /dev/null +++ b/training/torchtune/modules/low_precision/_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from datetime import datetime +from importlib.metadata import PackageNotFoundError, version +from typing import Optional, Tuple + +import torch + +import torchao + + +def _is_fbcode(): + return not hasattr(torch.version, "git_version") + + +def _nightly_version_ge(ao_version_str: str, date: str) -> bool: + """ + Compare a torchao nightly version to a date of the form + %Y-%m-%d. + + Returns True if the nightly version is greater than or equal to + the date, False otherwise + """ + ao_datetime = datetime.strptime(ao_version_str.split("+")[0], "%Y.%m.%d") + return ao_datetime >= datetime.strptime(date, "%Y-%m-%d") + + +def _get_torchao_version() -> Tuple[Optional[str], Optional[bool]]: + """ + Get torchao version. Returns a tuple of two elements, the first element + is the version string, the second element is whether it's a nightly version. + For fbcode usage, return None, None. + + Checks: + 1) is_fbcode, then + 2) importlib's version(torchao-nightly) for nightlies, then + 3) torchao.__version__ (only defined for torchao >= 0.3.0), then + 4) importlib's version(torchao) for non-nightly + + + If none of these work, raise an error. + + """ + if _is_fbcode(): + return None, None + # Check for nightly install first + try: + ao_version = version("torchao-nightly") + is_nightly = True + except PackageNotFoundError: + try: + ao_version = torchao.__version__ + is_nightly = False + except AttributeError: + ao_version = "unknown" + if ao_version == "unknown": + try: + ao_version = version("torchao") + is_nightly = False + except Exception as e: + raise PackageNotFoundError("Could not find torchao version") from e + return ao_version, is_nightly diff --git a/training/torchtune/modules/low_precision/nf4_linear.py b/training/torchtune/modules/low_precision/nf4_linear.py new file mode 100644 index 0000000..6626688 --- /dev/null +++ b/training/torchtune/modules/low_precision/nf4_linear.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +import torch.nn as nn +from torch import Tensor +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 + + +class FrozenNF4Linear(nn.Linear): + """ + A linear layer similar to ``torch.nn.Linear`` but uses a quantized + NF4Tensor as its weight. This class also freezes its ``weight`` parameter + and is meant to be used as the base Linear layer for modeling + use cases such as QLoRA where base model parameters are frozen. + NOTE: biases are currently not supported. + + Args: + in_dim (int): input dimension + out_dim (int): output dimension + device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default + device given by `torch.get_default_device()`. + **kwargs: any additional arguments to pass to the underlying Linear layer. + + Raises: + RuntimeError: if ``bias`` is set to ``True`` + """ + + def __init__( + self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs + ): + if "bias" in kwargs and kwargs.pop("bias"): + raise RuntimeError("FrozenNF4Linear does not currently support biases!") + + super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs) + self.weight.requires_grad_(False) + self.nf4_weight = to_nf4(self.weight) + # re-register self.weight as the nf4 weight, so that the nf4 weight + # shows up as expected in .parameters, state_dict, etc. + torch.utils.swap_tensors( + self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False) + ) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs linear operation with input tensor as given by `input`. Computation happens in higher + precision, though only the nf4 weight is saved for backward for gradient computation to ensure + additional memory is not used. + Args: + input (Tensor): input tensor + + Returns: + Tensor: output tensor + """ + return linear_nf4(input=input, weight=self.weight) diff --git a/training/torchtune/modules/lr_schedulers.py b/training/torchtune/modules/lr_schedulers.py new file mode 100644 index 0000000..188f0ff --- /dev/null +++ b/training/torchtune/modules/lr_schedulers.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from torch.optim.lr_scheduler import LambdaLR + + +def get_cosine_schedule_with_warmup( + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, +) -> LambdaLR: + """ + Create a learning rate schedule that linearly increases the learning rate from + 0.0 to lr over num_warmup_steps, then decreases to 0.0 on a cosine schedule over + the remaining num_training_steps-num_warmup_steps (assuming num_cycles = 0.5). + + This is based on the Hugging Face implementation + https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/optimization.py#L104. + + Args: + optimizer (torch.optim.Optimizer): The optimizer for which to + schedule the learning rate. + num_warmup_steps (int): The number of steps for the warmup phase. + num_training_steps (int): The total number of training steps. + num_cycles (float): The number of waves in the cosine schedule. Defaults to 0.5 + (decrease from the max value to 0 following a half-cosine). + last_epoch (int): The index of the last epoch when resuming training. Defaults to -1 + + Returns: + torch.optim.lr_scheduler.LambdaLR with the appropriate schedule. + """ + + def lr_lambda(current_step: int) -> float: + # linear warmup phase + if current_step < num_warmup_steps: + return current_step / max(1, num_warmup_steps) + + # cosine + progress = (current_step - num_warmup_steps) / max( + 1, num_training_steps - num_warmup_steps + ) + + cosine_lr_multiple = 0.5 * ( + 1.0 + math.cos(math.pi * num_cycles * 2.0 * progress) + ) + return max(0.0, cosine_lr_multiple) + + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/training/torchtune/modules/peft/__init__.py b/training/torchtune/modules/peft/__init__.py new file mode 100644 index 0000000..770ac41 --- /dev/null +++ b/training/torchtune/modules/peft/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .lora import LoRALinear +from .peft_utils import ( # noqa + AdapterModule, + disable_adapter, + get_adapter_params, + LORA_ATTN_MODULES, + set_trainable_params, + validate_missing_and_unexpected_for_lora, + validate_state_dict_for_lora, +) + +__all__ = [ + "LoRALinear", + "AdapterModule", + "get_adapter_params", + "set_trainable_params", + "validate_missing_and_unexpected_for_lora", + "validate_state_dict_for_lora", + "disable_adapter", +] diff --git a/training/torchtune/modules/peft/lora.py b/training/torchtune/modules/peft/lora.py new file mode 100644 index 0000000..2ba53dd --- /dev/null +++ b/training/torchtune/modules/peft/lora.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import math +from typing import List + +import torch.nn.functional as F + +from torch import nn, Tensor + +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 +from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401 +from torchtune.modules.peft.peft_utils import AdapterModule + + +class LoRALinear(nn.Module, AdapterModule): + """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models `_. + + LoRA perturbs a given layer via a low-rank approximation where only + the rank decomposition matrices are trainable. In a linear layer instead of + :math:`x \\mapsto W_0x` a LoRALinear layer is defined as + :math:`x \\mapsto W_0x + (\\alpha / r)BAx`, where :math:`r` is the rank of + the matrices :math:`A` and :math:`B` and :math:`\\alpha` is a scaling factor. + As in the original implementation, we support dropout before multiplication + by the low-rank matrices. + + Args: + in_dim (int): input dimension + out_dim (int): output dimension + rank (int): rank of the low-rank approximation + alpha (float): scaling factor for the low-rank approximation + dropout (float): dropout probability. Default: 0.0 + use_bias (bool): whether to include bias in the original linear layer. + Default: False + quantize_base (bool): Whether to quantize base linear weight or not. + Default: False + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + quantize_base: bool = False, + ): + super().__init__() + self.in_dim = in_dim + self.rank = rank + self.alpha = alpha + self.out_dim = out_dim + self.use_bias = use_bias + self._quantize_base = quantize_base + weight, bias = self._create_weight_and_bias() + # 'self.disabled' is a flag showing whether to turn off LoRA adapters, + # this can be used in DPO for treating the lora adapters as the policy model + # and disabling it to treat the base model as the reference model + self.disabled = False + self.register_parameter("weight", nn.Parameter(weight)) + self.register_parameter( + "bias", nn.Parameter(bias) if bias is not None else None + ) + self.dropout = nn.Dropout(p=dropout) + self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) + self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + self.merged = False + # Note: FSDP's meta device initialization contract assumes that a module's + # reset_parameters method only initializes its own parameters (i.e. no child + # params are initialized, as is done in initialize_parameters below). + # For that reason, we patch reset_parameters directly on lora_a and lora_b submodules + # when using meta device. This is done in + # torchtune.utils.prepare_model_for_fsdp_with_meta_device. + # See this issue for more details: https://github.com/pytorch/pytorch/issues/104187. + # Without meta device, we only need the following: + self.initialize_parameters() + + def initialize_parameters(self): + # Initialize as in + # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119 + _lora_a_init_params(self.lora_a) + _lora_b_init_params(self.lora_b) + + def _create_weight_and_bias(self): + """ + Creates a linear weight and bias tensor, using NF4 dtype if we're quantizing + (indicated via quantize_base=True). + """ + in_dim, out_dim, use_bias = self.in_dim, self.out_dim, self.use_bias + linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias) + weight = linear.weight if not self._quantize_base else to_nf4(linear.weight) + bias = None + if self.use_bias: + if self._quantize_base: + raise NotImplementedError( + "Quantized LoRALinear does not support bias at the moment." + ) + bias = linear.bias + return weight, bias + + def adapter_params(self) -> List[str]: + """ + Return lora_a.weight and lora_b.weight as adapter params. + If bias is enabled, also return lora_a.bias and lora_b.bias. + """ + # NOTE: this function has to be updated if the names of "lora_a" and "lora_b" + # in this module change. + adapter_params = ["lora_a.weight", "lora_b.weight"] + return adapter_params + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape ``(..., in_dim)`` + + Returns: + Tensor: output tensor with shape ``(..., out_dim)`` + + """ + if self._quantize_base: + out = linear_nf4(input=x, weight=self.weight) + else: + out = F.linear(x, self.weight, self.bias) + if self.disabled: + return out + lora_out = self.lora_a(self.dropout(x)) + lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) + return out + lora_out + + +def _lora_a_init_params(x: nn.Linear) -> None: + """ + Initialize LoRA A weight to Kaiming uniform. + """ + nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5)) + + +def _lora_b_init_params(x: nn.Linear) -> None: + """ + Initialize LoRA B weight to zeros. + """ + nn.init.zeros_(x.weight) diff --git a/training/torchtune/modules/peft/peft_utils.py b/training/torchtune/modules/peft/peft_utils.py new file mode 100644 index 0000000..eb86c6a --- /dev/null +++ b/training/torchtune/modules/peft/peft_utils.py @@ -0,0 +1,342 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set + +from torch import nn + +# Modules from CausalSelfAttention that LoRA can be applied to +LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"] + + +class AdapterModule(Protocol): + """ + Interface for an nn.Module containing adapter weights. + Note that an adapter module does not have to explicitly implement this protocol, + but it must define the ``adapter_params(self)`` method. + """ + + def adapter_params(self) -> List[str]: + """ + Return a list of strings corresponding to the names of the nn.Parameters in + the model coming from the adapter. + E.g. if an nn.Module has adapter ``self.proj = nn.Linear(in_dim, out_dim)``, + then adapter_params should return ``['proj.weight', 'proj.bias']``. + + See LoRALinear's :func:`~torchtune.modules.peft.LoRALinear.adapter_params` for an example. + """ + pass + + +def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: + """ + Return the subset of parameters from a model that correspond to an adapter. + Assumes that any adapter class has defined the + :func:`~torchtune.modules.peft.AdapterModule.adapter_params` method. + + Args: + model (nn.Module): Instance of model class containing some adapter params. + + Returns: + Dict[str, nn.Parameter]: the subset of model's state dict containing + only adapter parameters. + + """ + adapter_params = {} + for k, v in model.named_modules(): + if hasattr(v, "adapter_params") and callable(v.adapter_params): + current_adapter_params = v.adapter_params() + for n, p in v.named_parameters(recurse=True): + if n in current_adapter_params: + full_key = f"{k}.{n}" if k else n + adapter_params.update({full_key: p}) + current_adapter_params.remove(n) + assert ( + current_adapter_params == [] + ), f"Adapter params {current_adapter_params} not converted" + return adapter_params + + +def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None: + """ + Set trainable parameters for an nn.Module based on a state dict of adapter parameters. + + Args: + model (nn.Module): Instance of model class containing some adapter params. + adapter_params (Dict[str, Any]): State dict mapping adapter key names to their + respective nn.Parameters (i.e. outputs of :func:`~torchtune.modules.peft.get_adapter_params`.) + + Returns: + None + """ + for k, v in model.named_parameters(): + v.requires_grad_(k in adapter_params) + + +def get_lora_module_names( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool, + apply_lora_to_output: bool, +) -> List[str]: + """ + Return a list of the names of modules in the model that have LoRA applied. Note that + the names here are local to their modules and not the fully qualified names from the + model state dict. + + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether LoRA is applied to each MLP linear. + apply_lora_to_output (bool): whether LoRA is applied to the final output projection. + + Returns: + List[str]: list of module names in the model that have LoRA applied. + """ + lora_module_keys = lora_attn_modules + if apply_lora_to_mlp: + lora_module_keys = lora_module_keys + ["w1", "w2", "w3"] + if apply_lora_to_output: + lora_module_keys.append("output") + return lora_module_keys + + +def validate_state_dict_for_lora( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool, + apply_lora_to_output: bool, + full_model_state_dict_keys: List[str], + lora_state_dict_keys: Optional[List[str]] = None, + base_model_state_dict_keys: Optional[List[str]] = None, +) -> None: + """ + Validate that the state dict keys for a LoRA model are as expected. + + (1) If lora_state_dict_keys are passed, this function will confirm that they match exactly the + LoRA param names from the full model (as determined by lora_modules). + (2) If base_model_state_dict_keys are passed, this function will confirm that they are exactly the + complement of the LoRA param names from the full model. + (3) If both lora_state_dict_keys and base_model_state_dict_keys are passed, this function will + confirm that the full model's params are exactly their disjoint union. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether LoRA is applied to each MLP linear. + apply_lora_to_output (bool): whether LoRA is applied to the final output projection. + full_model_state_dict_keys (List[str]): List of keys in the full model state dict. + lora_state_dict_keys (Optional[List[str]]): List of keys in the LoRA state dict. + If none, LoRA state dict keys will not be validated. + base_model_state_dict_keys (Optional[List[str]]): List of keys in the base model state dict. + If none, base model keys will not be validated. + + Returns: + None + + Raises: + AssertionError: If base model state dict is missing any non-LoRA params from the full model. + AssertionError: If LoRA state dict is missing any LoRA params from the full model. + AssertionError: If base model state dict has any LoRA params. + AssertionError: If LoRA state dict has any non-LoRA params. + AssertionError: If base model and LoRA state dicts have overlapping keys. + AssertionError: If full model state dict is missing keys from either base model or LoRA state dict. + + """ + lora_modules = get_lora_module_names( + lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output + ) + is_lora_param = lambda x: any([".".join([k, "lora"]) in x for k in lora_modules]) + for k in full_model_state_dict_keys: + if not is_lora_param(k): + if base_model_state_dict_keys is not None: + if k not in base_model_state_dict_keys: + raise AssertionError( + f"Missing non-LoRA key {k} from base model state dict" + ) + if lora_state_dict_keys is not None: + if k in lora_state_dict_keys: + raise AssertionError(f"Non-LoRA key {k} found in LoRA state dict") + else: + if base_model_state_dict_keys is not None: + if k in base_model_state_dict_keys: + raise AssertionError(f"LoRA key {k} found in base model state dict") + if lora_state_dict_keys is not None: + if k not in lora_state_dict_keys: + raise AssertionError(f"Missing LoRA key {k} From LoRA state dict") + + # Full model is disjoint union of base model and LoRA weights + if lora_state_dict_keys is not None and base_model_state_dict_keys is not None: + combined_state_dict_keys = set(lora_state_dict_keys).union( + base_model_state_dict_keys + ) + shared_state_dict_keys = set(lora_state_dict_keys).intersection( + base_model_state_dict_keys + ) + assert ( + shared_state_dict_keys == set() + ), "Base model and LoRA state dict have overlapping keys" + assert combined_state_dict_keys == set( + full_model_state_dict_keys + ), "Extra keys not present in full model" + + +def _get_lora_modules(state_dict: Dict[str, Any]) -> Set[str]: + """ + Get the keys from a state dict that correspond to LoRALinear modules. + + For example, if state_dict is the state dict of model and model.x.y.z is a + LoRALinear, this method will return "model.x.y.z", not + "model.x.y.z.lora_a.weight" or "model.x.y.z.lora_b.weight". + + Args: + state_dict (Dict[str, Any]): State dict from a model. + + Returns: + Set[str]: Set of keys in the state dict that correspond to LoRA modules. + """ + lora_keys = [k for k in state_dict.keys() if "lora" in k] + return set( + [ + k.replace(".lora_a.weight", "").replace(".lora_b.weight", "") + for k in lora_keys + ] + ) + + +def get_merged_lora_ckpt( + state_dict: Dict[str, Any], rank: int, alpha: float +) -> Dict[str, Any]: + """ + Merge LoRA weights into the base model format for efficient inference. + NOTE: This function modifies state_dict inplace. If you do not want to do that, + make a copy prior to calling this function. + + For every LoRA module in the state dict, this function will convert its + weight -> weight + (alpha / rank) * lora_b @ lora_a, + then delete the lora_a and lora_b weights. + + Args: + state_dict (Dict[str, Any]): State dict from a model. + rank (int): The rank of LoRA matrices. + alpha (float): The alpha value used for scaling LoRA decompositions. + + Returns: + Dict[str, Any]: The merged state dict. + """ + lora_modules = _get_lora_modules(state_dict) + for module in lora_modules: + lora_a_weight = state_dict[f"{module}.lora_a.weight"] + lora_b_weight = state_dict[f"{module}.lora_b.weight"] + state_dict[f"{module}.weight"] += (alpha / rank) * lora_b_weight @ lora_a_weight + del state_dict[f"{module}.lora_a.weight"] + del state_dict[f"{module}.lora_b.weight"] + return state_dict + + +@contextlib.contextmanager +def disable_adapter(model: nn.Module) -> Generator[None, None, None]: + """ + Temporarily disable the adapters in a neural network model. This can be used, + for example, in DPO for treating the lora adapters as the policy model + and disabling it to treat the base model as the reference model. + + This context manager goes through all modules in the provided neural network model, + and if a module has an 'adapter_params' attribute that is callable and a 'disabled' attribute, + it sets 'disabled' to True. Then, the control is given back to caller. Once that finalizes, + it sets 'disabled' back to False for all modules that were temporarily disabled. + + Args: + model (nn.Module): The neural network model whose adapters are to be temporarily disabled. + Yields: + None: This function yields control back to the caller, with the adapters disabled. + Example: + >>> with disable_adapter(model): + ... # Perform operations with adapters disabled + ... pass + + """ + for _, module in model.named_modules(): + if ( + hasattr(module, "adapter_params") + and callable(module.adapter_params) + and hasattr(module, "disabled") + ): + module.disabled = True + try: + yield + finally: + for _, module in model.named_modules(): + if ( + hasattr(module, "adapter_params") + and callable(module.adapter_params) + and hasattr(module, "disabled") + ): + module.disabled = False + + +def validate_missing_and_unexpected_for_lora( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool, + apply_lora_to_output: bool, + base_missing: Optional[List[str]] = None, + base_unexpected: Optional[List[str]] = None, + lora_missing: Optional[List[str]] = None, + lora_unexpected: Optional[List[str]] = None, +) -> None: + """ + A more memory-efficient way to validate that LoRA state dict loading was done properly. + + Similar to :func:`validate_state_dict_for_lora`, this function uses a model's LoRA config to + check that LoRA and/or base model weights are loaded into the full model correctly. + Unlike that function, this method relies only on the values of missing and unexpected + as returned by the load_state_dict API with strict=False. This allows us to do the + validation without any additional calls to .state_dict(), which use additional memory. + This API should only be used for single-device recipes, or on multi-device after + https://github.com/pytorch/pytorch/pull/120600. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether LoRA is applied to each MLP linear. + apply_lora_to_output (bool): whether LoRA is applied to the final output projection. + base_missing (Optional[List[str]]): List of missing keys when loading base model weights. + Default: None + base_unexpected (Optional[List[str]]): List of unexpected keys when loading base model weights. + Default: None + lora_missing (Optional[List[str]]): List of missing keys when loading LoRA weights. + Default: None + lora_unexpected (Optional[List[str]]): List of unexpected keys when loading LoRA weights. + Default: None + + Returns: + None + + Raises: + AssertionError: if base_missing contains any base model keys. + AssertionError: if base_unexpected is nonempty. + AssertionError: if lora_missing contains any LoRA keys. + AssertionError: if lora_unexpected is nonempty. + """ + lora_modules = get_lora_module_names( + lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output + ) + is_lora_param = lambda x: any([".".join([k, "lora"]) in x for k in lora_modules]) + if base_missing: + for k in base_missing: + if not is_lora_param(k): + raise AssertionError(f"Missing non-LoRA key {k} from base model dict") + if base_unexpected: + raise AssertionError("Unexpected key loading base model") + if lora_missing: + for k in lora_missing: + if is_lora_param(k): + raise AssertionError(f"Missing LoRA key {k} from adapter state dict") + if lora_unexpected: + raise AssertionError("Unexpected key loading adapter") diff --git a/training/torchtune/modules/position_embeddings.py b/training/torchtune/modules/position_embeddings.py new file mode 100644 index 0000000..bf94eb0 --- /dev/null +++ b/training/torchtune/modules/position_embeddings.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torch import nn, Tensor + + +class RotaryPositionalEmbeddings(nn.Module): + """ + This class implements Rotary Positional Embeddings (RoPE) + proposed in https://arxiv.org/abs/2104.09864. + + Reference implementation (used for correctness verfication) + can be found here: + https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 + + In this implementation we cache the embeddings for each position upto + ``max_seq_len`` by computing this during init. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ````embed_dim`` // ``num_heads```` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self._rope_init() + + # We need to explicitly define reset_parameters for FSDP initialization, see + # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885 + def reset_parameters(self): + self._rope_init() + + def _rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [b, s, n_h, h_d] + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with RoPE applied + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + + TODO: The implementation below can be made more efficient + for inference. + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) diff --git a/training/torchtune/modules/rlhf/__init__.py b/training/torchtune/modules/rlhf/__init__.py new file mode 100644 index 0000000..73b5c3d --- /dev/null +++ b/training/torchtune/modules/rlhf/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._generation import ( + generate_next_token_with_logits, + generate_with_logits, + get_causal_mask, +) + +from ._types import PPOStats, Trajectory +from .collate import left_padded_collate, padded_collate_dpo +from .rewards import ( + estimate_advantages, + get_reward_penalty_mask, + get_rewards_ppo, + masked_mean, + masked_var, + whiten, +) +from .sequence_processing import ( + get_batch_log_probs, + logits_to_logprobs, + truncate_sequence_at_first_stop_token, + truncate_sequence_for_logprobs, +) + +__all__ = [ + "generate_with_logits", + "generate_next_token_with_logits", + "truncate_sequence_at_first_stop_token", + "get_causal_mask", + "logits_to_logprobs", + "truncate_sequence_for_logprobs", + "get_reward_penalty_mask", + "left_padded_collate", + "padded_collate_dpo", + "estimate_advantages", + "get_rewards_ppo", + "whiten", + "masked_mean", + "masked_var", + "PPOStats", + "get_batch_log_probs", + "Trajectory", +] diff --git a/training/torchtune/modules/rlhf/_generation.py b/training/torchtune/modules/rlhf/_generation.py new file mode 100644 index 0000000..a7f54b6 --- /dev/null +++ b/training/torchtune/modules/rlhf/_generation.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from torchtune.modules.transformer import TransformerDecoder + + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + """Samples from a multinomial distribution.""" + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: int = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + """Generic sample from a probability distribution.""" + # scale the logits based on temperature + logits = logits / max(temperature, 1e-5) + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # select the very last value from the top_k above as the pivot + pivot = v.select(-1, -1).unsqueeze(-1) + # set everything smaller than pivot value to inf since these + # should be pruned + logits = torch.where(logits < pivot, -float("Inf"), logits) + # change logits into probabilities + probs = torch.nn.functional.softmax(logits, dim=-1) + return multinomial_sample_one(probs, rng) + + +def generate_next_token_with_logits( + model: TransformerDecoder, + input_pos: torch.Tensor, + x: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generates the next tokens given a prompt, and also returns the corresponding logits. + + Args: + model (TransformerDecoder): model used for generation + input_pos (torch.Tensor): tensor with the positional encodings associated with the given prompt, + with shape [bsz x seq_length]. + x (torch.Tensor): tensor with the token IDs associated with the given prompt, + with shape [bsz x seq_length]. + mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length], + default None. + temperature (float): value to scale the predicted logits by, default 1.0. + top_k (Optional[int]): Top-k value to use for sampling, default None. + rng (Optional[torch.Generator]): random number generator, default None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: + - logits (torch.Tensor): tensor with the logits associated with the generated tokens, + with shape [bsz x seq_length x vocab_size]. + - tokens (torch.Tensor): tensor with the generated tokens, + with shape [bsz x 1]. + + """ + # model produces logits in [bsz, seq_length, vocab_size] + # we want to take the last token's logits as the input to the next model call + logits = model(x, input_pos=input_pos, mask=mask) + return logits, sample(logits[:, -1].clone(), temperature, top_k, rng) + + +def get_causal_mask( + padding_mask: torch.Tensor, +) -> torch.Tensor: + """ + Converts an attention mask of shape ``[bsz, seq_len]`` to a causal attention mask suitable for + consumption by :func:`~torch.nn.functional.scaled_dot_product_attention~`. + + HF uses a similar implementation internally, see + https://github.com/huggingface/transformers/blob/a564d10afe1a78c31934f0492422700f61a0ffc0/src/transformers/models/mistral/modeling_mistral.py#L1096 + + Args: + padding_mask (torch.Tensor): Boolean tensor where True indicates participation in attention + with shape [bsz x seq_length] + Returns: + torch.Tensor: Boolean causal mask with shape [bsz x seq_length x seq_length] + """ + _, seq_len = padding_mask.shape + mask = torch.tril( + torch.ones(seq_len, seq_len, device=padding_mask.device, dtype=bool), diagonal=0 + ) + mask = mask & (padding_mask[:, None, :] & padding_mask[:, :, None]) + mask.diagonal(dim1=1, dim2=2)[:] = True + return mask + + +@torch.inference_mode() +def generate_with_logits( + model: TransformerDecoder, + prompt: torch.Tensor, + *, + max_generated_tokens: int, + pad_id: int = 0, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +): + """ + Generates tokens from a model conditioned on a prompt, and also returns logits for the generations. + + Args: + model (TransformerDecoder): model used for generation + prompt (torch.Tensor): tensor with the token IDs associated with the given prompt, + with shape either [seq_length] or [bsz x seq_length]. + max_generated_tokens (int): number of tokens to be generated + pad_id (int): token ID to use for padding, default 0. + temperature (float): value to scale the predicted logits by, default 1.0. + top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities, + default None. + rng (Optional[torch.Generator]): random number generator, default None. + + Examples: + >>> model = torchtune.models.llama3.llama3_8b() + >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() + >>> prompt = [0, 0, 0] + tokenizer("Hi my name is") # substitute 0 with pad_id + >>> rng = torch.Generator() # optionally place on device + >>> rng.manual_seed(42) + >>> output = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0, rng=rng) + >>> print(tokenizer.decode(output[0])) + ?? ?? ?? Hi my name is Jeremy and I'm a friendly language model assistant! + + Returns: + torch.Tensor: Generated tokens. + """ + prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt + + _, prompt_length = prompt.size() + generated_tokens = prompt.clone() + + for i in range(max_generated_tokens): + padding_masks = generated_tokens == pad_id + if padding_masks.any(): + mask = get_causal_mask(~padding_masks) + input_pos = (~padding_masks).cumsum(-1) - (~padding_masks).long() + input_pos = input_pos.to(torch.int) + else: + mask = None + input_pos = torch.arange( + 0, prompt_length + i, device=generated_tokens.device + ) + + logits, tokens = generate_next_token_with_logits( + model, + input_pos=input_pos, + x=generated_tokens, + mask=mask, + temperature=temperature, + top_k=top_k, + rng=rng, + ) + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + + return generated_tokens, logits diff --git a/training/torchtune/modules/rlhf/_types.py b/training/torchtune/modules/rlhf/_types.py new file mode 100644 index 0000000..729a403 --- /dev/null +++ b/training/torchtune/modules/rlhf/_types.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import NamedTuple + +import torch + + +class Trajectory(NamedTuple): + """ + Contains a collection of tensors describing a generated trajectory during RLHF + + Attributes: + query_responses (torch.Tensor): (query, response) pairs + shape [b, context_length + max_generated_tokens] + logprobs (torch.Tensor): log probabilities of the generated responses with shape [b, max_generated_tokens] + ref_logprobs (torch.Tensor): log probabilities of the generated responses using the reference policy + shape [b, max_generated_tokens] + values (torch.Tensor): value estimates of the generated responses with shape [b, max_generated_tokens] + masks (torch.Tensor): attention masks for input ids-generated responses pairs + shape [b, context_length + max_generated_tokens, context_length + max_generated_tokens] + position_ids (torch.Tensor): position IDs for input ids-generated responses pairs + shape [b, context_length + max_generated_tokens] + response_padding_masks (torch.Tensor): padding masks for the truncated and padded generated responses + shape [b, max_generated_tokens] + value_padding_masks (torch.Tensor): padding masks for the values with + shape [b, max_generated_tokens] + value_seq_idxs (torch.Tensor): indexes of the token + after the last valid (non-padding) token in the responses with shape [b] + scores (torch.Tensor): scores from the reward model with shape [b] + seq_lens (torch.Tensor): sequence lengths of truncated generated responses with shape [b] + """ + + query_responses: torch.Tensor + logprobs: torch.Tensor + ref_logprobs: torch.Tensor + values: torch.Tensor + masks: torch.Tensor + position_ids: torch.Tensor + response_padding_masks: torch.Tensor + value_padding_masks: torch.Tensor + value_seq_idxs: torch.Tensor + scores: torch.Tensor + seq_lens: torch.Tensor + + +class PPOStats(NamedTuple): + """ + Contains PPO loss statistics (metrics) + + Attributes: + loss (torch.Tensor): The total PPO loss. + policy_loss (torch.Tensor): The policy function loss. + value_loss (torch.Tensor): The value function loss. + ratios (torch.Tensor): The ratio between the current and old policy probabilities. + clipfrac (torch.Tensor): The fraction of ratios that were clipped. + approx_policy_kls (torch.Tensor): Average estimated KL divergence between the policy before and after the optimisation step. + + """ + + loss: torch.Tensor + policy_loss: torch.Tensor + value_loss: torch.Tensor + ratios: torch.Tensor + clipfrac: torch.Tensor + approx_policy_kls: torch.Tensor diff --git a/training/torchtune/modules/rlhf/collate.py b/training/torchtune/modules/rlhf/collate.py new file mode 100644 index 0000000..fa20beb --- /dev/null +++ b/training/torchtune/modules/rlhf/collate.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Tuple + +import torch +from torch.nn.utils.rnn import pad_sequence + +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX + + +def left_padded_collate( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, +) -> torch.Tensor: + """ + Pads a batch of sequences with left padding to the maximum sequence length in the batch. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries containing inputs. + padding_idx (int): The padding index. Defaults to 0. + + Returns: + torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len]. + + Example: + >>> padding_idx = -8 + >>> batch = [ + >>> {"tokens": [1, 2] }, + >>> {"tokens": [3] }, + >>> {"tokens": [4, 5, 6, 7]}, + >>> ] + >>> left_padded_collate(batch, padding_idx) + >>> tensor([[-8, -8, 1, 2], + >>> [-8, -8, -8, 3], + >>> [ 4, 5, 6, 7]]) + + """ + pad_toks = pad_sequence( + [torch.tensor(x["tokens"][::-1]) for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + seq_idxs_rev = torch.arange(pad_toks.shape[-1] - 1, -1, -1) + return torch.stack([tok[seq_idxs_rev] for tok in pad_toks]) + + +def padded_collate_dpo( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, + ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pad a batch of sequences for Direct Preference Optimization (DPO). + + This function takes a batch of sequences, where each sequence is represented + as a dictionary with multiple key-value pairs. Each key corresponds to a different + sequence component, such as input_ids or labels. + + This will raise: + AssertionError: if the length of chosen_input_ids and rejected_input_ids differ. + AssertionError: if the length of chosen_labels and rejected_labels differ. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries, where each dictionary + represents a sequence with multiple components, 'chosen_input_ids', + 'chosen_labels', 'rejected_input_ids', and 'rejected_labels' are required. + padding_idx (int): Padding index for input ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing concatenated and padded + input ids and labels. + + + Example: + >>> batch = [ + >>> {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5], + >>> 'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]}, + >>> {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15], + >>> 'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]}, + >>> ] + >>> padded_collate_dpo(batch) + >>> (tensor([[ 1, 2, 3], + >>> [11, 12, 0], + >>> [ 4, 5, 0], + >>> [13, 14, 15]]), + >>> tensor([[ 6, 7, 8], + >>> [16, 17, -100], + >>> [ 9, 10, -100], + >>> [18, 19, 20]])) + """ + chosen_input_ids = [torch.tensor(ex["chosen_input_ids"]) for ex in batch] + rejected_input_ids = [torch.tensor(ex["rejected_input_ids"]) for ex in batch] + chosen_labels = [torch.tensor(ex["chosen_labels"]) for ex in batch] + rejected_labels = [torch.tensor(ex["rejected_labels"]) for ex in batch] + + assert len(chosen_input_ids) == len(rejected_input_ids) + assert len(chosen_labels) == len(rejected_labels) + + to_pad_input_ids = chosen_input_ids + rejected_input_ids + to_pad_labels = chosen_labels + rejected_labels + + concatenated_input_ids = pad_sequence( + to_pad_input_ids, batch_first=True, padding_value=padding_idx + ) + concatenated_labels = pad_sequence( + to_pad_labels, batch_first=True, padding_value=ignore_idx + ) + + return concatenated_input_ids, concatenated_labels diff --git a/training/torchtune/modules/rlhf/rewards.py b/training/torchtune/modules/rlhf/rewards.py new file mode 100644 index 0000000..0e5994f --- /dev/null +++ b/training/torchtune/modules/rlhf/rewards.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch + + +def get_reward_penalty_mask( + padding_masks: torch.Tensor, + seq_lens: torch.Tensor, + penalise_no_eos: bool = True, + min_response_length: int = None, +) -> torch.Tensor: + """ + Calculates a mask to penalise scores corresponding to sequences generated during PPO, where True indicates the score + at the corresponding position should be penalised. + This function assumes sequences have already been truncated at an EOS, if present, and padded to length, + e.g. by :func:`torchtune.modules.rlhf.sequence_processing.truncate_sequence_at_first_stop_token`. + + Scores are penalised such that: + - If ``min_response_length`` is set, scores for sequences with ``length < min_response_length`` are penalised. + - If ``penalise_no_eos`` is True, scores for sequences with no EOS token are penalised. + + Args: + padding_masks (torch.Tensor): Tensor where True indicates a padding token in the generated + sequence, and False otherwise. Shape: (b, reponse_len) + seq_lens (torch.Tensor): The length of each generated sequence. Shape: (b,) + penalise_no_eos (bool, optional): Whether to penalise sequences with no EOS token. Defaults to True. + min_response_length (int, optional): The minimum length of the response. If set, any responses is shorter + than this length will be penalised. Defaults to None. + Returns: + torch.Tensor: A mask tensor with shape (b,) where True indicates the corresponding score should be penalised. + """ + reward_penalty_mask = torch.zeros_like(seq_lens).to(bool) + + # since sequences will have been truncated at EOS, we can mask based on the presence of any padding tokens + if penalise_no_eos: + reward_penalty_mask = ~padding_masks.any(-1) + + if min_response_length is not None: + reward_penalty_mask |= ~(seq_lens >= min_response_length) + return reward_penalty_mask + + +def get_rewards_ppo( + scores: torch.Tensor, + logprobs: torch.Tensor, + ref_logprobs: torch.Tensor, + kl_coeff: float, + valid_score_idxs: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calculates PPO rewards for the given scores, logprobs, and reference logprobs. + + Args: + scores (torch.Tensor): Reward model scores, shape (b,). + logprobs (torch.Tensor): Policy logprobs, shape (b, reponse_len). + ref_logprobs (torch.Tensor): Reference base model, shape (b, reponse_len). + kl_coeff (float): KL reward contribution coefficient. + valid_score_idxs (Optional[torch.Tensor]): A tensor of indexes for valid (non-padded) token predictions. + This is useful when calculating rewards for padded sequences, as scores and value estimates are defined + for the last valid predicted token. Shape: (b,). Default None. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of tensors with shape [b, response_len] each: + - total_reward: total reward combining per-token kl rewards and reward model score. + - kl: kl divergence between policy and reference policy logprobs. + - kl_reward: kl divergence scaled by ``kl_coeff``. + + Notation used for tensor shapes: + - b: batch size + - response_len: model response length + """ + + # 1. calculate kl between logprobs and reflogprobs + # 2. calculate kl reward using adaptive scaling value + # 3. calculate total reward by summing above + # return all + kl = logprobs - ref_logprobs + kl_reward = -kl_coeff * kl + + total_reward = kl_reward.clone() + + # adding reward to kl at final valid position + # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L153 + + if valid_score_idxs is not None: + total_reward[ + torch.arange(scores.shape[0], device=scores.device), valid_score_idxs + ] += scores + else: + total_reward[:, -1] += scores + + return total_reward, kl, kl_reward + + +def masked_mean( + x: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None +) -> torch.Tensor: + """ + Compute mean of tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py + + Args: + x (torch.Tensor): The input tensor. + mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x`` + should participate in the mean calculation. + dim (Optional[int]): The axis to calculate the mean over. Default None. + + Returns: + torch.Tensor: The mean tensor. + """ + return (x * mask).sum(dim=dim) / mask.sum(dim=dim) + + +def masked_var( + x: torch.Tensor, mask: torch.Tensor, unbiased: bool = True +) -> torch.Tensor: + """ + Compute variance of tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py + + Args: + x (torch.Tensor): The input tensor. + mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x`` + should participate in the mean calculation. + unbiased (bool): Whether to use the unbiased variance. + + Returns: + torch.Tensor: The variance tensor. + + Raises: + ValueError: If the sum of the mask is zero. + """ + mean = masked_mean(x, mask) + centered_values = x - mean + var = masked_mean(centered_values.pow(2), mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when ``ppo_batch_size=1``;" + "try increase the ``ppo_batch_size`` or ``gradient_accumulation_steps``" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + var = var * bessel_correction + return var + + +def whiten( + x: torch.Tensor, mask: Optional[torch.Tensor] = None, shift_mean: bool = True +) -> torch.Tensor: + """ + Whiten (normalises) values, optionally with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py + Args: + x (torch.Tensor): The input tensor. + mask (Optional[torch.Tensor]): The bool mask tensor, where True indicates the corresponding value in ``x`` + should participate in the mean calculation. Default None. + shift_mean (bool): Whether to shift normalised values by the mean. + + Returns: + torch.Tensor: The whitened tensor. + """ + if mask is not None: + mean = masked_mean(x, mask) + var = masked_var(x, mask) if mask.any() else x.var() + else: + mean, var = x.mean(), x.var() + whitened = (x - mean) * torch.rsqrt(var + 1e-8) + if shift_mean: + whitened += mean + return whitened + + +def estimate_advantages( + values: torch.Tensor, + rewards: torch.Tensor, + gamma: float, + lmbda: float, + masks: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Estimates the advantages and returns for the PPO algorithm using Generalized Advantage Estimation + https://arxiv.org/pdf/1506.02438.pdf + + Args: + values (torch.Tensor): The predicted values for each state. Shape: (b, reponse_len) + rewards (torch.Tensor): The rewards received at each time step. Shape: (b, reponse_len) + gamma (float): The discount factor. + lmbda (float): The GAE-Lambda parameter. + masks (Optional[torch.Tensor]): A bool mask tensor, where True indicates the corresponding value in ``values`` + should participate in the mean calculation. Default None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the estimated advantages and returns. + - advantages (torch.Tensor): The estimated advantages. Shape: (b, reponse_len) + - returns (torch.Tensor): The estimated returns. Shape: (b, reponse_len) + Notation: + - b: batch size + - reponse_len: model response length + """ + + last_gae_lam = 0 + advantages_reversed = [] + + response_length = values.shape[-1] + + # estimate advantage for every predicted token position + for t in reversed(range(response_length)): + # value of the next state + next_values = values[:, t + 1] if t < response_length - 1 else 0.0 + # exponentially discounted temporal difference error: + # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) + delta = rewards[:, t] + gamma * next_values - values[:, t] + # GAE-Lambda advantage discounting saved for the next iteration + # as A_t = delta_t + gamma * lambda * A_{t+1} + ... + last_gae_lam = delta + gamma * lmbda * last_gae_lam + advantages_reversed.append(last_gae_lam) + + advantages = torch.stack(advantages_reversed[::-1], axis=1) + + # returns are the expected value of taking action a_t at each timepoint over + # a trajectory. the value estimates v_t are the expected value over all actions + # over a trajectory - the advantage is the difference between the two + returns = advantages + values + + # normalize advantages across the batch of trajectories to reduce variance + if masks is not None: + advantages = whiten(advantages, mask=masks) + advantages[~masks] = 0.0 + else: + advantages = whiten(advantages) + + return advantages, returns diff --git a/training/torchtune/modules/rlhf/sequence_processing.py b/training/torchtune/modules/rlhf/sequence_processing.py new file mode 100644 index 0000000..adbce7c --- /dev/null +++ b/training/torchtune/modules/rlhf/sequence_processing.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.modules import rlhf + + +def truncate_sequence_at_first_stop_token( + sequences: torch.Tensor, stop_tokens: torch.Tensor, fill_value: int = 0 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Truncates sequence(s) after the first stop token and pads with ``fill_value``. + + Args: + sequences (torch.Tensor): tensor of shape [batch_size, sequence_length] or [sequence_length]. + stop_tokens (torch.Tensor): tensor containing stop tokens. + fill_value (int): value to pad the sequence with after the first stop token, usually ``pad_id``. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors with the same shape as ``sequences``: + - padding_mask (torch.Tensor): a bool tensor where True indicates the token has been truncated. + - sequences (torch.Tensor) a tensor of truncated and padded sequences. + + Example: + >>> stop_token_ids = torch.tensor([2, 869]) + >>> fill_value = 0 + >>> sequences = torch.tensor( + >>> [ + >>> [869, 30, 869], + >>> [2, 30, 869], + >>> [869, 30, 2], + >>> [50, 30, 869], + >>> [13, 30, 2], + >>> [13, 30, 5], + >>> [13, 2, 20], + >>> [13, 2, 2], + >>> [2, 2, 2], + >>> ] + >>> ) + >>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token( + >>> sequences, stop_token_ids, fill_value + >>> ) + >>> eos_mask + >>> torch.tensor([ + >>> [False, True, True], + >>> [False, True, True], + >>> [False, True, True], + >>> [False, False, False], + >>> [False, False, False], + >>> [False, False, False], + >>> [False, False, True], + >>> [False, False, True], + >>> [False, True, True], + >>> ] + >>> ) + >>> truncated_sequences + >>> torch.tensor([ + >>> [869, 0, 0], + >>> [2, 0, 0], + >>> [869, 0, 0], + >>> [50, 30, 869], + >>> [13, 30, 2], + >>> [13, 30, 5], + >>> [13, 2, 0], + >>> [13, 2, 0], + >>> [2, 0, 0], + >>> ] + >>> ) + """ + eos_mask = torch.isin(sequences, stop_tokens) + seq_lens = torch.cumsum(eos_mask, dim=1) + padding_mask = (seq_lens > 1) | ((seq_lens == 1) & ~eos_mask) + sequences[padding_mask] = fill_value + return padding_mask, sequences + + +def logits_to_logprobs( + logits: torch.Tensor, sequences: torch.Tensor, temperature: float = 1.0 +) -> torch.Tensor: + """ + Converts logits corresponding to a generated sequence to logprobs over the generated tokens. + + Args: + logits (torch.Tensor): The logits tensor of shape [b, response_length, vocab_size]. + sequences (torch.Tensor): The corresponding tokens of shape [b, response_length]. + temperature (float): The temperature to scale the logits. Default 1.0 + Returns: + torch.Tensor: The log probabilities corresponding to each token in ``sequences``. Shape [b, response_length]. + """ + return torch.gather( + F.log_softmax(logits / temperature, dim=-1), + 2, + sequences.unsqueeze(-1), + ).squeeze(-1) + + +def get_batch_log_probs( + logits: torch.FloatTensor, + labels: torch.LongTensor, + label_pad_token_id: int = CROSS_ENTROPY_IGNORE_IDX, + return_average_logprobs: bool = False, +) -> torch.FloatTensor: + """ + Calculate log probabilities based on provided logits and labels. + + Args: + logits (torch.FloatTensor): direct logits output of the model of shape (b, s, v) + labels (torch.LongTensor): ground-truth labels to compute log probs with, shape (b, s). + Label tokens with a value of label_pad_token_id are ignored. + label_pad_token_id (int): token id to ignore in labels. + return_average_logprobs (bool): If True, return the average log probs across the sequence. Default + is False. See https://github.com/eric-mitchell/direct-preference-optimization/blob/f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L96 # noqa + + Returns: + Calculated log probs of shape (b, ) + + Raises: + ValueError: If logits and labels have different shapes. + """ + + if logits.shape[:-1] != labels.shape: + raise ValueError( + "Logits (batch and sequence length dim) and labels must have the same shape." + ) + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + labels[labels == label_pad_token_id] = 0 + # take log-likelihood of the labels given our model + per_token_log_probs = logits_to_logprobs(logits, labels, temperature=1.0) + + if return_average_logprobs: + return rlhf.masked_mean(per_token_log_probs, loss_mask, dim=-1) + else: + return (per_token_log_probs * loss_mask).sum(-1) + + +def truncate_sequence_for_logprobs( + query_response_logits: torch.Tensor, context_length: int +) -> torch.Tensor: + """ + Truncates logits generated over a sequence for estimating logprobs over the tokens in the sequence. + This assumes the sequence is of the (query, response) format with length (context_length + response_length) + Args: + query_response_logits (torch.Tensor): The logits tensor of shape [b, context_length + response_length, vocab_size]. + context_length (int): The length of the context. + + Returns: + torch.Tensor: The truncated logits for the response with shape [b, response_length, vocab_size].""" + return query_response_logits[:, context_length - 1 : -1] diff --git a/training/torchtune/modules/rlhf/utils/__init__.py b/training/torchtune/modules/rlhf/utils/__init__.py new file mode 100644 index 0000000..e8fa398 --- /dev/null +++ b/training/torchtune/modules/rlhf/utils/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._convert_weights import reward_hf_to_tune, reward_tune_to_hf # noqa + +__all__ = [ + "reward_hf_to_tune", + "reward_tune_to_hf", +] diff --git a/training/torchtune/modules/rlhf/utils/_convert_weights.py b/training/torchtune/modules/rlhf/utils/_convert_weights.py new file mode 100644 index 0000000..2e50138 --- /dev/null +++ b/training/torchtune/modules/rlhf/utils/_convert_weights.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +_REWARD = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.norm.weight": "norm.scale", + "score.weight": "output.weight", +} + + +def reward_hf_to_tune( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to torchtune's format, which contains the weights + of a reward model (i.e. a classifier with a single class). + State dicts from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but with a different mapping. + + Eg of HF-format state dict can be found in the ``Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback`` + repo in HF. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in HF's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of the head. If not provided, it will be calculated + as dim // num_heads. + + Returns: + Dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, 2, head_dim // 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + # ignore output layer bias - these are not used in the reward model + # and some HF pipelines (e.g. TRL) may save them + if key == "score.bias": + continue + # Skip loading the position embeddings + if "rotary_emb.inv_freq" not in key: + new_key = get_mapped_key(key, _REWARD) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + converted_state_dict[new_key] = value + return converted_state_dict + + +def reward_tune_to_hf( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Hugging Face's format for a reward model. + + This function takes a state dictionary in torchtune's format, which contains the weights of a reward model + (i.e. a classifier with a single class), and converts it into a format that can be loaded into a Hugging Face model. + The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but with a different mapping. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + num_heads (int, optional): Number of heads in the model. Defaults to 32. + num_kv_heads (int, optional): Number of heads in the key/value projection layers. Defaults to 32. + dim (int, optional): Dimension of the model. Defaults to 4096. + + Returns: + Dict[str, torch.Tensor]: State dict in Hugging Face's format. + + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _REWARD.items()} + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, head_dim // 2, 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/training/torchtune/modules/rms_norm.py b/training/torchtune/modules/rms_norm.py new file mode 100644 index 0000000..a2e4e2a --- /dev/null +++ b/training/torchtune/modules/rms_norm.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from torch import nn, Tensor + + +class RMSNorm(nn.Module): + """ + Implements Root Mean Square Normalization introduced in + https://arxiv.org/abs/1910.07467. + + Reference implementation (used for correctness verfication) + can be found here: + https://github.com/facebookresearch/llama/blob/main/llama/model.py + + Args: + dim (int): embedding size + eps (float): small value to avoid division by zero. Default: 1e-6 + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): input tensor to normalize + + Returns: + Tensor: The output tensor after applying RMSNorm. + """ + # computation is in fp32 + x_fp32 = x.float() + x_normed = ( + x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + ).type_as(x) + return x_normed * self.scale diff --git a/training/torchtune/modules/tokenizers/__init__.py b/training/torchtune/modules/tokenizers/__init__.py new file mode 100644 index 0000000..2fecc27 --- /dev/null +++ b/training/torchtune/modules/tokenizers/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._sentencepiece import SentencePieceBaseTokenizer +from ._tiktoken import TikTokenBaseTokenizer +from ._utils import ( + BaseTokenizer, + ModelTokenizer, + parse_hf_tokenizer_json, + tokenize_messages_no_special_tokens, +) + +__all__ = [ + "SentencePieceBaseTokenizer", + "TikTokenBaseTokenizer", + "ModelTokenizer", + "BaseTokenizer", + "tokenize_messages_no_special_tokens", + "parse_hf_tokenizer_json", +] diff --git a/training/torchtune/modules/tokenizers/_sentencepiece.py b/training/torchtune/modules/tokenizers/_sentencepiece.py new file mode 100644 index 0000000..ed157f5 --- /dev/null +++ b/training/torchtune/modules/tokenizers/_sentencepiece.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +from sentencepiece import SentencePieceProcessor +from torchtune.modules.tokenizers._utils import BaseTokenizer + +WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] + + +class SentencePieceBaseTokenizer(BaseTokenizer): + """ + A light-weight wrapper around SentencePieceProcessor that additionally handles + trimming leading whitespaces. + + Args: + path (str): Path to pretrained tokenizer file. + + Examples: + >>> tokenizer = SentencePieceBaseTokenizer("/path/to/spm_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + ): + spm_model = SentencePieceProcessor() + spm_model.load(path) + self.spm_model = spm_model + self.vocab_size = spm_model.vocab_size() + self.bos_id = spm_model.bos_id() + self.eos_id = spm_model.eos_id() + self.pad_id = spm_model.pad_id() + + # If the tokenizer does not encode whitespace, + # then we can more easily split strings + # on whitespace characters and encode them separately. + self.encodes_whitespace = any( + [self.spm_model.encode(c) for c in WHITESPACE_CHARS] + ) + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + trim_leading_whitespace: bool = False, + prefix: Optional[str] = None, + ) -> List[int]: + """Encode text into token IDs. + + Args: + text (str): The input text to be encoded, unbatched. + add_bos (bool): Whether to prepend BOS to the input, defaults to True. + add_eos (bool): Whether to append EOS to the input, defaults to True. + trim_leading_whitespace (bool): Whether to trim leading whitespace from + underlying sentencepiece tokenization. Sentencepiece normally prepends + whitespace to any tokenized text, which can cause differences where + encode(s1) + encode(s2) != encode(s1 + s2) due to leading whitespace + added to s2. This will only trim leading whitespace if the underlying + ``SentencePieceProcessor`` encodes whitespace. Default: False + prefix (Optional[str]): Optional string to encode for trimming leading + whitespaces. Used only if trim_leading_whitespace=True. Default: None + + Returns: + List[int]: The encoded token IDs. + """ + # We typically trim leading whitespace on the next message when + # it is a continuation of the turn (i.e. not the first message) + # or the previous message did not end with a space. This is handled + # by the caller of this method. We only need to trim leading whitespace + # if the underlying SentencePieceProcessor encodes whitespace. + if trim_leading_whitespace and self.encodes_whitespace: + # Can define our own custom prefix depending on vocab if needed + if not hasattr(self, "prefix"): + self.prefix = prefix or "\n" + self.encoded_prefix = self.spm_model.encode( + self.prefix, add_bos=False, add_eos=False + ) + start_idx = len(self.encoded_prefix) + int(add_bos) + return self.spm_model.encode( + self.prefix + text, + add_bos=add_bos, + add_eos=add_eos, + out_type=int, + )[start_idx:] + else: + return self.spm_model.encode( + text, + add_bos=add_bos, + add_eos=add_eos, + out_type=int, + ) + + def decode(self, ids: List[int]) -> str: + """Decode token IDs to strings. + + Args: + ids (List[int]): The input token IDs to be decoded. + + Returns: + str: The decoded text. + """ + return self.spm_model.decode(ids) diff --git a/training/torchtune/modules/tokenizers/_tiktoken.py b/training/torchtune/modules/tokenizers/_tiktoken.py new file mode 100644 index 0000000..7fe29b2 --- /dev/null +++ b/training/torchtune/modules/tokenizers/_tiktoken.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Iterator, List + +from tiktoken import Encoding +from tiktoken.load import load_tiktoken_bpe +from torchtune.modules.tokenizers._utils import BaseTokenizer + +# Constants controlling encode logic +MAX_ENCODE_CHARS = 400_000 +MAX_NO_WHITESPACE_CHARS = 25_000 + + +class TikTokenBaseTokenizer(BaseTokenizer): + """ + A lightweight wrapper around tiktoken Encoding. This class additionally handles + breaking up the input text into substrings of a max length and splitting up long + repetitions to improve encode speed. + + Args: + path (str): Path to pretrained tokenizer checkpoint file. + name (str): Name of the tokenizer (used by tiktoken for identification). + pattern (str): Regex pattern used to split input text into chunks before passing + to byte-pair encoding. + bos_id (int): beginning-of-sequence token id. This can be present or absent in ``special_tokens``. + eos_id (int): end-of-sequence token id. This can be present or absent in ``special_tokens``. + special_tokens (Dict[str, int]): Mapping of special tokens to their ids. + + Examples: + >>> tokenizer = TikTokenBaseTokenizer("/path/to/tt_model") + >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True) + >>> print(tokenized_text) + [1, 31587, 29644, 102, 2] + """ + + def __init__( + self, + path: str, + name: str, + pattern: str, + bos_id: int, + eos_id: int, + special_tokens: Dict[str, int], + ): + mergeable_ranks = load_tiktoken_bpe(path) + self.tt_model = Encoding( + name=name, + pat_str=pattern, + mergeable_ranks=mergeable_ranks, + special_tokens=special_tokens, + ) + # Vocab size without special tokens + self.base_vocab_size = len(mergeable_ranks) + # Vocab size with special tokens + self.vocab_size = self.tt_model.n_vocab + self.bos_id = bos_id + self.eos_id = eos_id + + def _split_long_repetitions( + self, s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Split the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + ) -> List[int]: + """ + Encode a string into a list of token ids. Assumes that the string + contains no special tokens. + + Args: + text (str): The string to encode. + add_bos (bool): Whether to add the tokenizer's bos_id to the encoded string. + Default True. + add_eos (bool): Whether to add the tokenizer's eos_id to the encoded string. + Default True. + + Returns: + List[int]: The list of token ids. + """ + substrs: List[str] = [] + tokens = [] + if not text: + return [] + for i in range(0, len(text), MAX_ENCODE_CHARS): + substr = text[i : i + MAX_ENCODE_CHARS] + # See https://github.com/openai/tiktoken/issues/195 + sliced_substr = self._split_long_repetitions( + substr, MAX_NO_WHITESPACE_CHARS + ) + substrs.extend(sliced_substr) + for substr in substrs: + # allowed_special and disallowed_special are used by tiktoken to define + # how special tokens are encoded. Our setting here is to encode any + # special token as regular text and prevent tiktoken from raising errors. + # This means we should only call encode on strings not containing special tokens. + tokens.extend( + self.tt_model.encode( + substr, + allowed_special=set(), + disallowed_special=(), + ) + ) + if add_bos: + tokens = [self.bos_id] + tokens + if add_eos: + tokens = tokens + [self.eos_id] + return tokens + + def decode( + self, + token_ids: List[int], + truncate_at_eos: bool = True, + skip_special_tokens: bool = True, + ) -> str: + """ + Decode a list of token ids into a string. + + Args: + token_ids (List[int]): The list of token ids. + truncate_at_eos (bool): Whether to truncate the string at the end of + sequence token. Default is True. + skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string. + Default is True. + + Returns: + str: The decoded string. + """ + if truncate_at_eos: + try: + k = token_ids.index(self.eos_id) + except ValueError: + k = None + if k: + token_ids = token_ids[:k] + if skip_special_tokens: + token_ids = [ + token_id + for token_id in token_ids + if token_id not in self.tt_model._special_tokens.values() + and token_id != self.bos_id + ] + return self.tt_model.decode(token_ids) diff --git a/training/torchtune/modules/tokenizers/_utils.py b/training/torchtune/modules/tokenizers/_utils.py new file mode 100644 index 0000000..a658008 --- /dev/null +++ b/training/torchtune/modules/tokenizers/_utils.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from typing import Any, Dict, List, Optional, Protocol, Tuple + +from torchtune.data._messages import Message +from torchtune.data._utils import truncate + + +class BaseTokenizer(Protocol): + """ + Abstract token encoding model that implements ``encode`` and ``decode`` methods. + """ + + def encode(self, text: str, **kwargs: Dict[str, Any]) -> List[int]: + """ + Given a string, return the encoded list of token ids. + + Args: + text (str): The text to encode. + **kwargs (Dict[str, Any]): kwargs. + + Returns: + List[int]: The encoded list of token ids. + """ + pass + + def decode(self, token_ids: List[int], **kwargs: Dict[str, Any]) -> str: + """ + Given a list of token ids, return the decoded text, optionally including special tokens. + + Args: + token_ids (List[int]): The list of token ids to decode. + **kwargs (Dict[str, Any]): kwargs. + + Returns: + str: The decoded text. + """ + pass + + +class ModelTokenizer(Protocol): + """ + Abstract tokenizer that implements model specific special token logic in + the ``tokenize_messages`` method. + """ + + special_tokens: Dict[str, int] + max_seq_len: Optional[int] + + def tokenize_messages( + self, messages: List[Message], **kwargs: Dict[str, Any] + ) -> Tuple[List[int], List[bool]]: + """ + Given a list of messages, return a list of tokens and list of masks for + the concatenated and formatted messages. + + Args: + messages (List[Message]): The list of messages to tokenize. + **kwargs (Dict[str, Any]): kwargs. + + Returns: + Tuple[List[int], List[bool]]: The list of token ids and the list of masks. + """ + pass + + +def tokenize_messages_no_special_tokens( + tokenizer: ModelTokenizer, + messages: List[Message], + bos_id: int, + eos_id: int, + max_seq_len: Optional[int] = None, +) -> Tuple[List[int], List[bool]]: + r"""Tokenize a list of messages one at a time then concatenate them, + returning a list of tokens and a list of masks. Does not add any special + tokens except for BOS and EOS. This serves as a common starting point for + model tokenizers that do not rely heavily on special tokens. + + Examples: + >>> messages = [ + ... Message(role="system", content="system message\n", masked=True), + ... Message(role="user", content="user prompt\n", masked=True), + ... Message(role="assistant", content="assistant response\n"), + ... ] + # tokenize_messages encodes messages separately and concats + >>> tokens = tokenize_messages_no_special_tokens( + ... tokenizer, + ... messages, + ... tokenizer.bos_id, + ... tokenizer.eos_id, + ... max_seq_len + ... )[0] + >>> print(tokens) + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + # Same result as encoding the full string in one go + >>> print(tokenizer.encode(''.join([message.content for message in messages]))) + [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] + + + Args: + tokenizer (ModelTokenizer): Tokenizer to encode messages with. + messages (List[Message]): A list of messages, each containing role, content, + and masked attributes. + bos_id (int): Beggining-of-sequence token id. + eos_id (int): End-of-sequence token id. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None + + Returns: + Tuple[List[int], List[bool]]: The tokenized messages. + + Raises: + RuntimeError: if any message in ``messages`` does not satisfy ``message['type'] == 'text'``. + """ + start_of_turn = True + end_of_turn = False + prev_ends_with_space = False + tokenized_messages = [] + mask = [] + for message in messages: + # If assistant message, this is the end of a turn + end_of_turn = message.role == "assistant" + + # Prepend BOS on start of new turns + if start_of_turn: + tokenized_messages.append(bos_id) + mask.append(message.masked) + + # We want to trim leading whitespace on the next message when + # (a) it is a continuation of the turn (i.e. not the first message) + # (b) the vocabulary explicitly encodes whitespace characters (checked inside + # the base tokenizer's encode method), and + # (c) the previous message did not end with a space + trim_leading_whitespace = (not start_of_turn) and not prev_ends_with_space + + # Tokenize current message, append with masks + tokens = [] + for item in message.content: + if item["type"] == "text": + tokens = tokens + tokenizer.encode( + item["content"].rstrip(" "), + add_bos=False, + add_eos=False, + trim_leading_whitespace=trim_leading_whitespace, + ) + else: + raise RuntimeError(f"Unsupported message content type: {item['type']}") + prev_ends_with_space = item["content"].endswith(" ") + tokenized_messages.extend(tokens) + mask.extend([message.masked] * len(tokens)) + + # If assistant message, append EOS at end + if end_of_turn: + tokenized_messages.append(eos_id) + mask.append(message.masked) + end_of_turn = False + start_of_turn = True + else: + start_of_turn = False + + # Break out early if we reach max_seq_len + if max_seq_len and len(tokenized_messages) >= max_seq_len: + break + + # Finally, truncate if necessary + if max_seq_len: + tokenized_messages = truncate(tokenized_messages, max_seq_len, eos_id) + mask = truncate(mask, max_seq_len, message.masked) + + return tokenized_messages, mask + + +def parse_hf_tokenizer_json(tokenizer_json_path: str) -> Dict[str, int]: + """ + Parse the ``tokenizer.json`` file from a Hugging Face model to extract the + special token str to id mapping. + + Args: + tokenizer_json_path (str): Path to the ``tokenizer.json`` file. + + Returns: + Dict[str, int]: The special token str to id mapping. + """ + with open(tokenizer_json_path, "r") as f: + tokenizer_json = json.load(f) + + return {token["content"]: token["id"] for token in tokenizer_json["added_tokens"]} diff --git a/training/torchtune/modules/transformer.py b/training/torchtune/modules/transformer.py new file mode 100644 index 0000000..e536d2c --- /dev/null +++ b/training/torchtune/modules/transformer.py @@ -0,0 +1,408 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import copy +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from torchtune.modules import CausalSelfAttention + + +class TransformerDecoderLayer(nn.Module): + """Transformer layer derived from the Llama2 model. Normalization is applied before the attention **and** FF layer. + + Args: + attn (CausalSelfAttention): Attention module. + mlp (nn.Module): Feed-forward module. + sa_norm (nn.Module): Normalization to be applied before self-attention. + mlp_norm (nn.Module): Normalization to be applied before the feed-forward layer. + """ + + def __init__( + self, + attn: CausalSelfAttention, + mlp: nn.Module, + sa_norm: nn.Module, + mlp_norm: nn.Module, + ) -> None: + super().__init__() + self.sa_norm = sa_norm + self.attn = attn + self.mlp_norm = mlp_norm + self.mlp = mlp + + def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + self.attn.setup_cache(batch_size, dtype) + + @property + def cache_enabled(self) -> bool: + """Check if the key value caches are setup.""" + return self.attn.kv_cache is not None + + def reset_cache(self): + """Reset the key value caches.""" + self.attn.reset_cache() + + def forward( + self, + x: Tensor, + *, + mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [batch_size x seq_length x embed_dim] + mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + with shape [batch_size x seq_length x seq_length]. This is applied after + the query-key multiplication and before the softmax. A value of True in row i + and column j means token i attends to token j. A value of False means token i + does not attend to token j. If no mask is specified, a causal mask + is used by default. Default is None. + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with same shape as input + [batch_size x seq_length x embed_dim] + + TODO: + - Make position of norm configurable + """ + # Input tensor and attention output have the same shape + # [b, s, d] + # Norm applied before self-attention + attn_out = self.attn(self.sa_norm(x), mask=mask, input_pos=input_pos) + + # Residual connection; shape: [batch_size, seq_length, embed_dim] + h = attn_out + x + + # Norm applied before the feedforward layer + mlp_out = self.mlp(self.mlp_norm(h)) + + # Residual connection; shape: [batch_size, seq_length, embed_dim] + out = h + mlp_out + return out + + +def _get_clones(module: nn.Module, n: int) -> nn.ModuleList: + """ + Return a list of ``n`` identical layers. + + Args: + module (nn.Module): module to be cloned + n (int): number of clones + + Returns: + nn.ModuleList: list of ``n`` identical layers + """ + # FIXME: copy.deepcopy() is not defined on nn.module + return nn.ModuleList([copy.deepcopy(module) for i in range(n)]) + + +class TransformerDecoder(nn.Module): + """ + Transformer Decoder derived from the Llama2 architecture. + + Args: + tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move + tokens to an embedding space. + layer (TransformerDecoderLayer): Transformer Decoder layer. + num_layers (int): Number of Transformer Decoder layers. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value. This is used to setup the + :func:`~torchtune.modules.KVCache` + head_dim (int): embedding dimension for each head in self-attention. This is used + to setup the :func:`~torchtune.modules.KVCache` + norm (nn.Module): Callable that applies normalization to the output of the decoder, + before final MLP. + output (nn.Linear): Callable that applies a linear transformation to the output of + the decoder. + + Note: + Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1]) + in the module where they are used. This helps reduces the number of raise + statements in code and improves readability. + """ + + def __init__( + self, + tok_embeddings: nn.Embedding, + layer: TransformerDecoderLayer, + num_layers: int, + max_seq_len: int, + num_heads: int, + head_dim: int, + norm: nn.Module, + output: nn.Linear, + ) -> None: + super().__init__() + + self.tok_embeddings = tok_embeddings + self.layers = _get_clones(layer, num_layers) + self.norm = norm + self.output = output + self.max_seq_len = max_seq_len + self.num_heads = num_heads + self.head_dim = head_dim + self.causal_mask = None + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + for layer in self.layers: + layer.setup_cache(batch_size, dtype) + + # causal_mask is used during inference to ensure we're attending + # to the right tokens + self.causal_mask = torch.tril( + torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool) + ) + + def caches_are_enabled(self) -> bool: + """Check if the key value caches are setup.""" + return self.layers[0].cache_enabled + + def reset_caches(self): + """Reset the key value caches.""" + if not self.caches_are_enabled(): + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + + for layer in self.layers: + layer.reset_cache() + + def forward( + self, + tokens: Tensor, + *, + mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + tokens (Tensor): input tensor with shape [b x s] + mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + with shape [b x s x s]. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + Tensor: output tensor with shape [b x s x v] + + Raises: + ValueError: if causal_mask is set but input_pos is None + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - v: vocab size + - d: embed dim + - m_s: max seq len + """ + # shape: [b, s, d] + h = self.tok_embeddings(tokens) + + if self.causal_mask is not None: + if input_pos is None: + raise ValueError( + "Caches are setup, but the position of input token is missing" + ) + if mask is not None: + raise ValueError( + "An attention mask was set. Cannot use a non-causal mask for inference" + ) + # shape: [1, input_pos_len, m_s] + # in most cases input_pos_len should be 1 + mask = self.causal_mask[None, input_pos] + + for layer in self.layers: + # shape: [b, s, d] + h = layer(h, mask=mask, input_pos=input_pos) + + # shape: [b, s, d] + h = self.norm(h) + + # shape: [b, s, out_dim] - out_dim is usually the vocab size + output = self.output(h).float() + return output + + +class TiedEmbeddingTransformerDecoder(nn.Module): + """ + Transformer Decoder with tied embedding weight. A key difference between + this class and :class:`~torchtune.modules.TransformerDecoder` + is that the output projection is replaced with token embeddings weights. + + Args: + tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move + tokens to an embedding space. + layer (TransformerDecoderLayer): Transformer Decoder layer. + num_layers (int): Number of Transformer Decoder layers. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value. This is used to setup the + :func:`~torchtune.modules.KVCache` + head_dim (int): embedding dimension for each head in self-attention. This is used + to setup the :func:`~torchtune.modules.KVCache` + norm (nn.Module): Callable that applies normalization to the output of the decoder, + before final MLP. + + Note: + Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1]) + in the module where they are used. This helps reduces the number of raise + statements in code and improves readability. + """ + + def __init__( + self, + tok_embeddings: nn.Embedding, + layer: TransformerDecoderLayer, + num_layers: int, + max_seq_len: int, + num_heads: int, + head_dim: int, + norm: nn.Module, + ) -> None: + super().__init__() + + self.tok_embeddings = tok_embeddings + self.layers = _get_clones(layer, num_layers) + self.norm = norm + self.max_seq_len = max_seq_len + self.num_heads = num_heads + self.head_dim = head_dim + self.causal_mask = None + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + for layer in self.layers: + layer.setup_cache(batch_size, dtype) + + # causal_mask is used during inference to ensure we're attending + # to the right tokens + self.causal_mask = torch.tril( + torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool) + ) + + def caches_are_enabled(self) -> bool: + """Check if the key value caches are setup.""" + return self.layers[0].cache_enabled + + def reset_caches(self): + """Reset the key value caches.""" + if not self.caches_are_enabled(): + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + + for layer in self.layers: + layer.reset_cache() + + def forward( + self, + tokens: Tensor, + *, + mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + tokens (Tensor): input tensor with shape [b x s] + mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + with shape [b x s x s]. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + Tensor: output tensor with shape [b x s x v] + + Raises: + ValueError: if causal_mask is set but input_pos is None + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - v: vocab size + - d: embed dim + - m_s: max seq len + """ + # input tensor of shape [b, s] + bsz, seq_len = tokens.shape + + # shape: [b, s, d] + h = self.tok_embeddings(tokens) + + if self.causal_mask is not None: + if input_pos is None: + raise ValueError( + "Caches are setup, but the position of input token is missing" + ) + if mask is not None: + raise ValueError( + "An attention mask was set. Cannot use a non-causal mask for inference" + ) + # shape: [1, input_pos_len, m_s] + # in most cases input_pos_len should be 1 + mask = self.causal_mask[None, input_pos] + + for layer in self.layers: + # shape: [b, s, d] + h = layer(h, mask=mask, input_pos=input_pos) + + # shape: [b, s, d] + h = self.norm(h) + + # shape: [b, s, out_dim] - out_dim is usually the vocab size + output = F.linear(h, self.tok_embeddings.weight).float() + return output diff --git a/training/torchtune/modules/transforms/__init__.py b/training/torchtune/modules/transforms/__init__.py new file mode 100644 index 0000000..c317e7d --- /dev/null +++ b/training/torchtune/modules/transforms/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask +from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import ( # noqa + find_supported_resolutions, + get_canvas_best_fit, +) +from torchtune.modules.transforms.vision_utils.resize_with_pad import ( # noqa + resize_with_pad, +) +from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop # noqa + +__all__ = [ + "Transform", + "get_canvas_best_fit", + "resize_with_pad", + "tile_crop", + "find_supported_resolutions", + "VisionCrossAttentionMask", +] diff --git a/training/torchtune/modules/transforms/_transforms.py b/training/torchtune/modules/transforms/_transforms.py new file mode 100644 index 0000000..3121fa6 --- /dev/null +++ b/training/torchtune/modules/transforms/_transforms.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Protocol + +import torch + + +class Transform(Protocol): + """ + Loose interface for all data and model transforms. Transforms operate at the + sample level and perform operations on a sample dict, returning the updated dict. + """ + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + pass + + +class VisionCrossAttentionMask(Transform): + """ + Computes the cross-attention mask for text + image inputs. Text tokens that + participate in cross-attention with an image token will show True in the mask + and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper + (https://arxiv.org/pdf/2204.14198): + + (1) Text tokens immediately following the image token up until the next image token + (2) Consecutive image tokens attend to subsequent text tokens + + :: + + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + These are two dogs. This is a cat. + + + + Resultant mask is constructed per image and is of shape (text_seq_len, image_seq_len), + where True indicates that the token outputted from the image encoder attends + to the token in the text sequence in cross-attention. A list of these masks + are returned with length equal to number of images in the sample. + + Args: + tile_size (int): The size of the image tiles from the image transform + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + image_token_id (int): Token ID of the image special token. + """ + + def __init__(self, tile_size: int, patch_size: int, image_token_id: int): + patch_grid_size = tile_size // patch_size + self.patches_per_tile = patch_grid_size**2 + self.image_token_id = image_token_id + + def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]: + """ + Returns a list of lists of the form [start, end) where start is the index + of the current image token and end is the index of the next image token, exclusive. + + Args: + tokens (List[int]): List of token IDs in the text sequence + + Returns: + List[List[int]]: List of lists of the form [start, end) indicating + range of positions in text sequence that should attend to the image + + Example: + >>> text = "These are two dogs. This is a cat." + >>> image_token_id = 1 + >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415] + >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1) + >>> intervals = transform._get_image_attention_intervals(tokens) + >>> print(intervals) + [[0, 7], [1, 7], [7, 12]] + """ + end = len(tokens) + vision_token_locations = [ + i for i, token in enumerate(tokens) if token == self.image_token_id + ] + # Return empty list if there are no images + if len(vision_token_locations) == 0: + return [] + # If there is only one image, it will attend to subsequent text until end + if len(vision_token_locations) == 1: + return [[vision_token_locations[0], end]] + + # Construct intervals from previous image token to next image token + vision_masks = [ + [tok_idx_prev, tok_idx_next] + # Offset by one to get consecutive indices + for tok_idx_prev, tok_idx_next in zip( + vision_token_locations[:-1], vision_token_locations[1:] + ) + ] + # Last image will attend to subsequent text until end + vision_masks.append([vision_token_locations[-1], end]) + + # If there are consecutive vision tokens, they should all attend to the + # same subsequent text + last_mask_end = vision_masks[-1][1] + for vision_mask in vision_masks[::-1]: + if vision_mask[0] == vision_mask[1] - 1: + vision_mask[1] = last_mask_end + last_mask_end = vision_mask[1] + return vision_masks + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Generates the vision cross-attention mask for the given sample based on + the image token locations interleaved in the text sequence. + + Args: + sample (Mapping[str, Any]): Sample dict containing the following keys: + - tokens (List[int]): List of token IDs in the text sequence. Number of + image token IDs in the sequence must match the number of images. + - images (List[torch.Tensor]): List of image Tensors post-tiling of shape + (n_tiles, c, h, w) each. + + Returns: + Mapping[str, Any]: updated sample with the following keys: + - encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len), + where length of list == number of images in sample + - tokens (List[int]): original tokens + - images (List[torch.Tensor]): original images + + Raises: + RuntimeError: if the number of images in the batch does not match the number of image tokens in the batch. + """ + tokens, images = sample["tokens"], sample["images"] + # One sample can have multiple images - verify the number of image tokens + # is the same + n_img = len(images) + intervals = self._get_image_attention_intervals(tokens) + if len(intervals) != n_img: + raise RuntimeError( + f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})." + ) + + # Create mask for each individual image based on its number of tokens, + # which can vary based on number of tiles since they are not yet tile padded. + # The masks are padded and concatenated together in the batch collator + text_seq_len = len(tokens) + masks = [] + for image_num, interval in enumerate(intervals): + # Identify what part of text sequence should be attended + start, end = interval + # Compute this image's number of tokens based on num tiles, patches per tile + n_tiles = images[image_num].shape[0] + image_seq_len = n_tiles * (self.patches_per_tile + 1) # +1 for CLS token + # Mask will be block of 1s at the corresponding interval in the text. + # It is not a causal block because all the image tokens correspond + # to a single image, so text tokens attend to all the image's tokens + mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool) + mask[start:end, :] = True + masks.append(mask) + + sample.update({"encoder_mask": masks}) + return sample diff --git a/training/torchtune/modules/transforms/vision_utils/__init__.py b/training/torchtune/modules/transforms/vision_utils/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/training/torchtune/modules/transforms/vision_utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/training/torchtune/modules/transforms/vision_utils/get_canvas_best_fit.py b/training/torchtune/modules/transforms/vision_utils/get_canvas_best_fit.py new file mode 100644 index 0000000..77eabc5 --- /dev/null +++ b/training/torchtune/modules/transforms/vision_utils/get_canvas_best_fit.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from collections import defaultdict +from typing import List, Set, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def get_canvas_best_fit( + image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool +) -> Tuple[int, int]: + """ + Determines the best canvas possible from a list of possible resolutions to + resize an image to, without distortion. + + For each possible resolution, calculates the scaling factors for + width and height, and selects the smallest one, which is the limiting side. + E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x, + then the maximum upscaling without distortion is min(2, 1.5) = 1.5. + + If there are multiple canvases that satisfy the conditions, + we pick the one with the lowest area to minimize padding. + + Args: + image (torch.Tensor): The image we want to fit into a canvas. + possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible canvas. + resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling. + If False, pick the canvas that minimizes downscaling, including no downscaling at all. + + Returns: + Tuple[int, int]: The best resolution to fit the image into. + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> possible_resolutions = torch.tensor([ + ... [224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224] + ... ]) + >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False) + (224, 448) + + In the example above, we calculate the scaling factors for each possible resolution + + >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + + Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest + + >>> upscaling_options = torch.tensor([1.1200, 1.1200]) + >>> selected_scale = torch.tensor(1.1200) + + There are two possible options, so we pick the one with the smallest area + + >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively + >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area + """ + + original_height, original_width = image.shape[-2:] + + # possible resolutions heights/widths + target_heights, target_widths = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get limiting side scaling -> no distortion + scales = torch.where(scale_w > scale_h, scale_h, scale_w) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling factor, + # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion + chosen_canvas = possible_resolutions[scales == selected_scale] + + # if there are multiple resolutions, + # get the one with minimum area to reduce padding + if len(chosen_canvas) > 1: + areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] + optimal_idx = torch.argmin(areas) + optimal_canvas = chosen_canvas[optimal_idx] + else: + optimal_canvas = chosen_canvas[0] + + return tuple(optimal_canvas.tolist()) + + +def find_supported_resolutions( + max_num_tiles: int, tile_size: int +) -> List[Tuple[int, int]]: + """ + Computes all combinations of resolutions, multiple of tile_size, + that contain up to max_num_tiles. Useful for when dividing an image into tiles. + + For example, if we want at most 2 tiles per image, then we can support the + following resolutions: (1x1, 1x2, 2x1) * tile_size + + Args: + max_num_tiles (int): Maximum number of tiles. + tile_size (int): Size of the side of the tile. + + Returns: + List[Tuple[int, int]]: List of possible resolutions as tuples (height, width). + + Examples: + + >>> max_num_tiles = 4 + >>> tile_size = 224 + >>> find_supported_resolutions(max_num_tiles, tile_size) + [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)] + """ + + # create dictionary {aspect_ratio: [resolution1, ..., resolution n]} + # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]} + asp_dict = defaultdict(list) + for _tile_size in range(max_num_tiles, 0, -1): + factors = sorted(_get_factors(_tile_size)) + asp_ratios = [(factor, _tile_size // factor) for factor in factors] + for height, width in asp_ratios: + ratio_float = height / width + asp_dict[ratio_float].append((height, width)) + + # get the resolutions multiplied by the tile_size + possible_resolutions = [] + for ar, resolution in asp_dict.items(): + for height, width in resolution: + possible_resolutions.append((height * tile_size, width * tile_size)) + + return possible_resolutions + + +def _get_factors(n: int) -> Set[int]: + """ + Calculate all factors of a given number, i.e. a divisor that leaves no remainder. + + Args: + n (int): The number to find factors for. + + Returns: + set: A set containing all factors of the number. + + Examples: + >>> _get_factors(n=12) + {1, 2, 3, 4, 6, 12} + """ + factors_set = set() + + for i in range(1, int(n**0.5) + 1): + if n % i == 0: + factors_set.add(i) + factors_set.add(n // i) + return factors_set diff --git a/training/torchtune/modules/transforms/vision_utils/resize_with_pad.py b/training/torchtune/modules/transforms/vision_utils/resize_with_pad.py new file mode 100644 index 0000000..1092557 --- /dev/null +++ b/training/torchtune/modules/transforms/vision_utils/resize_with_pad.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import math +from typing import Optional, Tuple + +import torch + +import torchvision +from torchvision.transforms.v2 import functional as F + +logger = logging.getLogger(__name__) + + +def resize_with_pad( + image: torch.Tensor, + target_size: Tuple[int, int], + resample: torchvision.transforms.InterpolationMode, + max_upscaling_size: Optional[int] = None, +) -> torch.Tensor: + """ + Resizes and pads an image to target_size without causing distortion. + The user can set max_upscaling_size to limit upscaling when target_size exceeds image_size. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images. + Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT, + InterpolationMode.BILINEAR and InterpolationMode.BICUBIC. + max_upscaling_size (Optional[int]): The maximum size to upscale the image to. + If None, will upscale up to target_size. + + Returns: + torch.Tensor: The resized and padded image tensor in the format [..., H, W]. + + Examples: + + Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side, + and then padded from (448, 1194) to (448, 1344). + + >>> max_upscaling_size = None + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_upscaling_size) + + Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344). + + >>> max_upscaling_size = 600 + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_upscaling_size) + + Example 3: The image will be downscaled from (500, 1000) to (224, 448), + and padded from (224, 448) to (448, 448). + + >>> max_upscaling_size = 600 + >>> image = torch.rand([3, 500, 1000]) + >>> target_size = (448, 488) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_upscaling_size) + + """ + + image_height, image_width = image.shape[-2:] + image_size = (image_height, image_width) + + # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size + if max_upscaling_size is not None: + new_target_height = min(max(image_height, max_upscaling_size), target_size[0]) + new_target_width = min(max(image_width, max_upscaling_size), target_size[1]) + target_size_resize = (new_target_height, new_target_width) + else: + target_size_resize = target_size + + # resize to target_size while preserving aspect ratio + new_size_preserving_aspect_ratio = _get_max_res_without_distortion( + image_size=image_size, + target_size=target_size_resize, + ) + + image = F.resize( + inpt=image, + size=list(new_size_preserving_aspect_ratio), + interpolation=resample, + antialias=True, + ) + + image = _pad_image_top_left(image=image, target_size=target_size) + + return image + + +def _pad_image_top_left( + image: torch.Tensor, + target_size: Tuple[int, int], +) -> torch.Tensor: + """ + Places the image at the top left of the canvas and pads with 0 the right and bottom + to fit to the target resolution. If target_size < image_size, it will crop the image. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + + Returns: + torch.Tensor: The padded image tensor in the format [..., H, W]. + """ + + image_size = image.shape[-2:] + + height, width = image_size + target_height, target_width = target_size + + pad_x = target_width - width + pad_y = target_height - height + + padding = [0, 0, pad_x, pad_y] + return F.pad(inpt=image, padding=padding) + + +def _get_max_res_without_distortion( + image_size: Tuple[int, int], + target_size: Tuple[int, int], +) -> Tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. + + For example, if image_size = (200,400) and target_size = (600,800), + scale_h = 600/200 = 3 + scale_w = 800/400 = 2 + So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2 + + Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w + + Args: + image_size (Tuple[int, int]): The original resolution of the image. + target_size (Tuple[int, int]): The desired resolution to fit the image into. + Returns: + Tuple[int, int]: The optimal dimensions to which the image should be resized. + Examples: + >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200)) + (133, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300)) + (450, 337) + """ + + original_height, original_width = image_size + target_height, target_width = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) + + return new_height, new_width diff --git a/training/torchtune/modules/transforms/vision_utils/tile_crop.py b/training/torchtune/modules/transforms/vision_utils/tile_crop.py new file mode 100644 index 0000000..17e173c --- /dev/null +++ b/training/torchtune/modules/transforms/vision_utils/tile_crop.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor: + """ + Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size. + + Args: + image (torch.Tensor): Input image to crop into tiles. + tile_size (int): Size of each tile. + + Returns: + torch.Tensor: Tensor of shape [num_tiles, channel_size, tile_size, tile_size] + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> tiles = tile_crop(image, tile_size=50) + >>> tiles.shape # 4x6 = 24 tiles + torch.Size([24, 3, 50, 50]) + + >>> image = torch.rand(3, 400, 600) + >>> tiles = tile_crop(image, tile_size=200) + >>> tiles.shape # 2x3 = 6 tiles + torch.Size([6, 3, 200, 200]) + """ + + channel_size, height, width = image.shape + + # assert sizes are divisible + assert ( + height % tile_size == 0 and width % tile_size == 0 + ), f"Image size {height}x{width} is not divisible by tile size {tile_size}" + + # Reshape to split height and width into tile_size blocks + tiles_height = height // tile_size + tiles_width = width // tile_size + + reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size) + + # Transpose to bring tiles together + # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size] + transposed = reshaped.permute(1, 3, 0, 2, 4) + + # Flatten the tiles + tiles = transposed.contiguous().view( + tiles_height * tiles_width, channel_size, tile_size, tile_size + ) + + return tiles diff --git a/training/torchtune/modules/vision_transformer.py b/training/torchtune/modules/vision_transformer.py new file mode 100644 index 0000000..51801a1 --- /dev/null +++ b/training/torchtune/modules/vision_transformer.py @@ -0,0 +1,458 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from torchtune.modules import Fp32LayerNorm +from torchtune.modules.transformer import _get_clones + + +class VisionTransformer(nn.Module): + """ + Implementation of the ViT architecture (https://arxiv.org/abs/2010.11929), + with support for tile-cropped images, outputting of hidden layers and optional CLS projection. + + ViT is a transformer architecture that takes in images and outputs N embedded tokens that + represent this image. Each image is divided into **patches** by a convolution. + These patches are flattened and subsequently treated as **tokens** by the transformer. + + To further enhance the performance of ViT and avoid downscaling images, we support tile-cropped images, + which are images divided into **tiles** during the preprocessing stage. For example, instead of + downscaling an 800x400 image to fit 400x400, we may crop it into two 400x400 tiles, + if the ``tile_size=400``. For details on preprocessing, please refer to + :class:`torchtune.models.clip._transforms.CLIPImageTransform`. + + Each of these tiles is further broken down into patches by a convolution operation. For example, if + your ``patch_size=40``, then each (400, 400) tile will become a grid of 10x10 patches, and your whole image will have + num_tiles * n_tokens -> num_tiles * (10x10 patches + 1 CLS token) -> num_tiles * 101. + + Before the transformer layers, a CLS token is added to each tile as the first token. + In transformers, a token called CLS is a special token that is added to the beginning of each sequence. + This token can be used to represent the whole input, instead of using a pooling operation, for example. + + To help the model "see" the whole image, we use positional embeddings. If your image + was tile-cropped, then you need to use tile positional embeddings: + + - token_pos_embedding (tiled): :class:`torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding` + - pre_tile_pos_embed: :class:`torchtune.models.clip._position_embeddings.TilePositionalEmbedding` + - post_tile_pos_embed: :class:`torchtune.models.clip._position_embeddings.TilePositionalEmbedding` + + Otherwise, pre and post tile_pos_embed should be None and all you need is a simple + token positional embedding: + + - token_pos_embedding (not tiled): :class:`torchtune.models.clip._position_embeddings.TokenPositionalEmbedding` + + All images will be considered as a stack of tiles, even if your image was not tile-cropped. In such cases, + your image would be composed of a single tile. + + In summary: + + 1) An image is broken down into tiles during preprocessing. + 2) In the ViT, the tiles will be broken down into patches. + 3) The patches will be flattened and transformed. We call them tokens, because that's how the transformer sees them. + + + Image: shape (8x8) + + .. code-block:: text + + | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | + | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | + | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | + | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | + | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | + | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | + | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | + | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | + + Tiles: shape (4,4,4) # (num_tiles, tile_size, tile_size) + + .. code-block:: text + + | 1 | 2 | 3 | 4 | | 5 | 6 | 7 | 8 | + | 9 | 10 | 11 | 12 | | 13 | 14 | 15 | 16 | + | 17 | 18 | 19 | 20 | | 21 | 22 | 23 | 24 | + | 25 | 26 | 27 | 28 | | 29 | 30 | 31 | 32 | + + | 33 | 34 | 35 | 36 | | 37 | 38 | 39 | 40 | + | 41 | 42 | 43 | 44 | | 45 | 46 | 47 | 48 | + | 49 | 50 | 51 | 52 | | 53 | 54 | 55 | 56 | + | 57 | 58 | 59 | 60 | | 61 | 62 | 63 | 64 | + + Patches: shape (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size) + + .. code-block:: text + + | 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | + | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | + + | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | + | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | + + | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | + | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | + + | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | + | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 | + + token: shape (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim) + + .. code-block:: text + + | 1 | 2 | 9 | 10 | | 3 | 4 | 11 | 12 | | 17 | 18 | 25 | 26 | | 19 | 20 | 27 | 28 | + | ... continuation of data ... + | ... continuation of data ... + | 37 | 38 | 45 | 46 | | 39 | 40 | 47 | 48 | | 53 | 54 | 61 | 62 | | 55 | 56 | 63 | 64 | + + For the positional embeddings: + + Same for every tile, different for every token. + + - :class:`torchtune.models.clip._position_embeddings.TokenPositionalEmbedding` + - :class:`torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding` + + .. code-block:: text + + | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | + | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | + | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | + | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | + + | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | + | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | + | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | + | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | + + Different for every tile, different for every token. + + - :class:`torchtune.models.clip._position_embeddings.TiledTokenPositionalEmbedding` + + .. code-block:: text + + | 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | + | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | + + | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | + | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | + + | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | + | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | + + | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | + | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 | + + different for every tile, same for every token within a tile. + + - :class:`torchtune.models.clip._position_embeddings.TilePositionalEmbedding` + + .. code-block:: text + + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + + Args: + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + with shape (40, 40) each. + num_layers (int): The number of transformer layers. + embed_dim (int): The dimensionality of each patch embedding (token). + layer (nn.Module): The transformer layer module. + token_pos_embedding (nn.Module): The token positional embedding module. + pre_tile_pos_embed (Optional[nn.Module]): The pre-tile positional embedding module. It should be + None if your image was not tile-cropped in advance. + post_tile_pos_embed (Optional[nn.Module]): The post-tile positional embedding module. It should be + None if your image was not tile-cropped in advance. + cls_projection (Optional[nn.Module]): The CLS projection module. It should take an input tensor + of shape (bsz * n_tiles, n_tokens, embed_dim) and output a tensor of shape + (bsz * n_tiles, cls_output_dim). If provided, only the CLS token projection will be + outputted, instead of all tokens. + out_indices (Optional[List[int]]): The indices of hidden layers to return. + If provided, it will return the intermediate results of the transformer layers + before they go through a next layer. For example, ``out_indices=[0,3]`` will + return the tokens before they go through the first and fourth layers. + in_channels (int): The number of image input channels. + + Raises: + ValueError: If `tile_size` is not greater than 0. + ValueError: If `patch_size` is not greater than 0. + ValueError: If `len(out_indices)` is greater than `num_layers`. + """ + + def __init__( + self, + patch_size: int, + tile_size: int, + num_layers: int, + embed_dim: int, + layer: nn.Module, + token_pos_embedding: nn.Module, + pre_tile_pos_embed: Optional[nn.Module] = None, + post_tile_pos_embed: Optional[nn.Module] = None, + cls_projection: Optional[nn.Module] = None, + out_indices: Optional[List[int]] = None, + in_channels: int = 3, + ) -> None: + super().__init__() + + if tile_size <= 0: + raise ValueError("tile_size must be > 0") + if patch_size <= 0: + raise ValueError("patch_size must be > 0") + if out_indices and (len(out_indices) > num_layers): + raise ValueError( + f"len(out_indices) must be <= num_layers. Got {out_indices=} and {num_layers=}" + ) + + # constants + patch_grid_size = tile_size // patch_size + self.patches_per_tile = patch_grid_size**2 + self.out_indices = out_indices + if not out_indices: + self.out_indices = [] + + # input modules + self.pre_tile_pos_embed = pre_tile_pos_embed + self.post_tile_pos_embed = post_tile_pos_embed + self.token_pos_embedding = token_pos_embedding + + self.cls_projection = cls_projection + self.transformer_layers = _get_clones(layer, num_layers) + + # other modules + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + bias=False, + ) + + self.ln_post = Fp32LayerNorm(embed_dim) + self.ln_pre = Fp32LayerNorm(embed_dim) + + self.cls_token_embedding = CLSEmbedding(embed_dim) + + def get_image_tokens_per_tile(self): + return self.patches_per_tile + 1 # +1 for CLS token + + def forward( + self, images: torch.Tensor, aspect_ratio: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Processes images and returns the tokens and hidden states. + + Multiple images per sample: we add a dimension n_imgs to the input. This is useful when a single + sample constains multiple images, for example: + + - sample 1: " what animal is this?" + - sample 2: "I like more than " + + In this case, sample 1 has one image, and sample 2 has two images. max_n_imgs = max(2,1) = 2. + So your input should have shape (bsz=2, n_imgs=2, num_tiles, n_channels, tile_size, tile_size). + + Notice that to batch it, you will have to pad n_imgs to max_n_imgs and max_num_tiles. + + Args: + images (torch.Tensor): Tensor with shape (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size). + aspect_ratio (Optional[torch.Tensor]): Tensor with shape (bsz, n_imgs, 2). If all + images have a single tile, i.e. they were not tile-cropped, it should be None. + Used to calculate the positional embeddings for the tiles. + + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: A tuple: (x, hidden_states), + where x is a torch.tensor of shape (bsz, n_imgs, n_tiles, n_tokens, embed_dim) and + hidden_states has shape is a list of len(out_indices) torch.tensor with shape + (bsz, n_imgs, n_tiles, n_tokens, embed_dim). + + Raises: + ValueError: If aspect_ratio is None, but n_tiles > 1 in the batch. + + Examples: + + >>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop + >>> from torchtune.modules import VisionTransformer + >>> + >>> num_channels = 3 + >>> image_size = (800,400) + >>> tile_size = 400 + >>> patch_size=40 + >>> patch_grid_size = tile_size // patch_size + >>> + >>> # for details about preprocessing, please check + >>> # torchtune.models.clip._transforms.CLIPImageTransform + >>> + >>> # create a random image + >>> image = torch.rand(num_channels, image_size[0], image_size[1]) + >>> + >>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400) + >>> tile_cropped_image = tile_crop(image, tile_size) + >>> aspect_ratio = torch.tensor([2,1]) + >>> + >>> # make it a batch of 1 image + >>> batch_image = tile_cropped_image.unsqueeze(0) + >>> batch_aspect_ratio = aspect_ratio.unsqueeze(0) + >>> + >>> # make it have only 1 image per sample + >>> batch_image = tile_cropped_image.unsqueeze(1) + >>> batch_aspect_ratio = aspect_ratio.unsqueeze(1) + >>> + >>> # For a detailed example, please check + >>> # torchtune.models.clip._position_embeddings.clip_vision_encoder + >>> # model = VisionTransformer( + ... # out_indices = [1,2,3,4,5], + ... # patch_size=40, + ... # patch_grid_size = patch_grid_size, + ... # embed_dim = 32, + ... # num_layers = 6, + ... # in_channels = num_channels, + ... # ...) + >>> + >>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio) + >>> + >>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) + >>> print(x.shape) + torch.Size([1, 1, 2, 101, 32]) + >>> + >>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim) + >>> print(len(hidden_states)) + 5 + """ + hidden_states = [] + + # parse inputs + bsz, n_imgs, n_tiles, nch, w, h = images.shape + bsz_and_n_imgs = bsz * n_imgs + + # if aspect_ratio is not provided, it defaults to one tile [1,1] + if aspect_ratio is None: + aspect_ratio = torch.ones( + (bsz_and_n_imgs, 2), dtype=torch.int, device=images.device + ) + if n_tiles > 1: + raise ValueError( + f"aspect_ratio was not provided, but found n_tiles>1 for {images.shape=}. Please provide aspect_ratio." + ) + + images = images.reshape(bsz_and_n_imgs * n_tiles, nch, w, h) + aspect_ratio = aspect_ratio.reshape(bsz_and_n_imgs, 2) + + # patch embeddings (tokens) + # A tile becomes a grid of patch_grid_size X patch_grid_size patches + # these patches are flatenned, and called tokens from here on. + + # out: (bsz * n_imgs * n_tiles, embed_dim, patch_grid_size, patch_grid_size) + x = self.conv(images) + + # out: (bsz * n_imgs, n_tiles, n_tokens, embed_dim) + x = x.reshape(bsz_and_n_imgs, n_tiles, -1, self.patches_per_tile).permute( + 0, 1, 3, 2 + ) + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + + # pre_tile_pos_embed + if self.pre_tile_pos_embed: + x = self.pre_tile_pos_embed(x, aspect_ratio) + + # insert cls token + x = self.cls_token_embedding(x) + n_tokens += 1 + + # token_pos_embedding + x = self.token_pos_embedding(x, aspect_ratio) + + # norm + x = self.ln_pre(x) + + # transformer with optional hidden layer outputs + x = x.reshape(bsz_and_n_imgs, n_tiles * n_tokens, embed_dim) + for layer_idx, transformer_layer in enumerate(self.transformer_layers): + if layer_idx in self.out_indices: + hidden_states.append( + x.reshape(bsz, n_imgs, n_tiles, n_tokens, embed_dim) + ) + x = transformer_layer(x) + + # norm + x = self.ln_post(x) + + # post_tile_pos_embed + if self.post_tile_pos_embed: + x = x.reshape(bsz_and_n_imgs, n_tiles, n_tokens, embed_dim) + x = self.post_tile_pos_embed(x, aspect_ratio) + + # reshape output + x = x.reshape(bsz, n_imgs, n_tiles, n_tokens, embed_dim) + + # cls token projection. n_tokens becomes 1 + if self.cls_projection: + x = self.cls_projection(x) + + return x, hidden_states + + +class CLSEmbedding(nn.Module): + """ + Adds a CLS token to every tile in an image. + + Notice that tile is different from patch (token). An image is divided into tiles during pre-processing, + and patches are the outcome of the convolution in the ViT applied to each tile. + + Args: + embed_dim (int): The dimensionality of the input patch embedding. + """ + + def __init__(self, embed_dim: int) -> None: + super().__init__() + + scale = embed_dim**-0.5 + self.cls_embedding = nn.Parameter(scale * torch.randn(embed_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # add 1 CLS token to every tile + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + cls_emb = self.cls_embedding.broadcast_to(bsz_and_n_imgs, n_tiles, 1, embed_dim) + return torch.cat([cls_emb, x], dim=2) + + +class CLSProjection(nn.Module): + """ + Linear projection of the CLS token. + + Args: + embed_dim (int): The dimensionality of the input patch embedding. + cls_output_dim (int): The dimensionality of the output projection. + """ + + def __init__(self, embed_dim: int, cls_output_dim: int) -> None: + super().__init__() + + scale = embed_dim**-0.5 + self.cls_output_dim = cls_output_dim + self.projection = nn.Parameter(scale * torch.randn(embed_dim, cls_output_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bsz, n_imgs, n_tiles, n_tokens, embed_dim = x.shape + x = x.reshape(bsz * n_imgs * n_tiles, n_tokens, embed_dim) + + # out: (bsz * n_tiles, cls_output_dim) + x = x[:, 0, :] @ self.projection + + # num_tokens becomes 1 because we only return the CLS token projection + x = x.reshape(bsz, n_imgs, n_tiles, 1, self.cls_output_dim) + return x diff --git a/training/torchtune/recipe_interfaces.py b/training/torchtune/recipe_interfaces.py new file mode 100644 index 0000000..68e1dfa --- /dev/null +++ b/training/torchtune/recipe_interfaces.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Protocol + + +class FTRecipeInterface(Protocol): + """ + This class provides a loose structure which every LLM fine-tuning recipe + should follow. Please note that the interface itself should not be a vehicle for + code reuse. torchtune strictly prohibits implementation inheritance in the codebase. + + A few notes about the design and the need for this interface: + - This interface is meant to help recipe-writers organize their code in a way + which is easy to read, understand and extend. Minimizing code duplication is not + the goal. Recipe-writers are encouraged to copy-paste-modify. + + - This interface is not meant to add constraints. If the interface comes in the + way of doing stuff, it needs to be updated or a new interface should be + written to support what might be a new "family" of recipes. + """ + + def load_checkpoint(self, **kwargs) -> None: + """ + Responsible for loading ALL of the state for the recipe from the + checkpoint file, including state for the model, optimizer, dataloader and training + parameters such as the epoch and seed. + """ + ... + + def setup(self, **kwargs) -> None: + """ + Responsible for setting up all of the components necessary for training. This includes + model, optimizer, loss function and dataloader. + """ + ... + + def train(self, **kwargs) -> None: + """ + All of the training logic, including the core loop, loss computation, gradient + accumulation, and backward. + """ + ... + + def save_checkpoint(self, **kwargs) -> None: + """ + Responsible for saving ALL of the state for the recipe, + including state for the model, optimizer, dataloader and training + parameters such as the epoch and seed. + """ + ... + + def cleanup(self, **kwargs) -> None: + """ + Any cleaning up needed for the recipe. + """ + ... + + +class EvalRecipeInterface(Protocol): + """ + This class provides a loose structure which every LLM evaluation recipe + should follow. Please note that the interface itself should not be a vehicle for + code reuse. torchtune strictly prohibits implementation inheritance in the codebase. + """ + + def load_checkpoint(self, **kwargs) -> None: + """ + Responsible for loading ALL of the state for the recipe from the + checkpoint file. + """ + ... + + def setup(self, **kwargs) -> None: + """ + Responsible for setting up all of the components necessary for evaluation. + """ + ... + + def evaluate(self, **kwargs) -> None: + """ + All of the evaluation logic, including reporting. + """ + ... diff --git a/training/torchtune/utils/__init__.py b/training/torchtune/utils/__init__.py new file mode 100644 index 0000000..9414cc6 --- /dev/null +++ b/training/torchtune/utils/__init__.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._checkpointing import ( # noqa + Checkpointer, + FullModelHFCheckpointer, + FullModelMetaCheckpointer, + FullModelTorchTuneCheckpointer, + ModelType, +) + +from ._device import get_device +from ._distributed import ( # noqa + contains_fsdp, + FSDPPolicyType, + get_full_finetune_fsdp_wrap_policy, + get_full_model_state_dict, + get_full_optimizer_state_dict, + get_world_size_and_rank, + init_distributed, + is_distributed, + load_from_full_model_state_dict, + load_from_full_optimizer_state_dict, + lora_fsdp_wrap_policy, + prepare_model_for_fsdp_with_meta_device, + set_torch_num_threads, + validate_no_params_on_meta_device, +) +from ._generation import generate, generate_next_token # noqa +from ._profiler import ( + DEFAULT_PROFILE_DIR, + DEFAULT_PROFILER_ACTIVITIES, + DEFAULT_SCHEDULE, + DEFAULT_TRACE_OPTS, + DummyProfiler, + PROFILER_KEY, + setup_torch_profiler, +) +from ._version import torch_version_ge +from .argparse import TuneRecipeArgumentParser +from .collate import padded_collate +from .constants import ( # noqa + ADAPTER_CONFIG, + ADAPTER_KEY, + EPOCHS_KEY, + MAX_STEPS_KEY, + MODEL_KEY, + OPT_KEY, + RNG_KEY, + SEED_KEY, + STEPS_KEY, + TOTAL_EPOCHS_KEY, +) +from .logging import get_logger +from .memory import ( # noqa + cleanup_before_training, + create_optim_in_bwd_wrapper, + get_memory_stats, + log_memory_stats, + OptimizerInBackwardWrapper, + register_optim_in_bwd_hooks, + set_activation_checkpointing, +) +from .pooling import get_unmasked_sequence_lengths + +from .precision import get_dtype, set_default_dtype, validate_expected_param_dtype +from .quantization import get_quantizer_mode +from .seed import set_seed + +__all__ = [ + "get_memory_stats", + "FSDPPolicyType", + "log_memory_stats", + "get_device", + "get_dtype", + "get_logger", + "get_world_size_and_rank", + "init_distributed", + "is_distributed", + "lora_fsdp_wrap_policy", + "get_full_finetune_fsdp_wrap_policy", + "padded_collate", + "get_unmasked_sequence_lengths", + "set_activation_checkpointing", + "set_default_dtype", + "set_seed", + "validate_expected_param_dtype", + "TuneRecipeArgumentParser", + "torch_version_ge", + "OptimizerInBackwardWrapper", + "create_optim_in_bwd_wrapper", + "register_optim_in_bwd_hooks", + "DEFAULT_PROFILE_DIR", + "DEFAULT_PROFILER_ACTIVITIES", + "DEFAULT_SCHEDULE", + "DEFAULT_TRACE_OPTS", + "DummyProfiler", + "PROFILER_KEY", + "setup_torch_profiler", + "get_quantizer_mode", + "generate", + "generate_next_token", +] diff --git a/training/torchtune/utils/_checkpointing/__init__.py b/training/torchtune/utils/_checkpointing/__init__.py new file mode 100644 index 0000000..2c9a83d --- /dev/null +++ b/training/torchtune/utils/_checkpointing/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Union + +from ._checkpointer import ( # noqa + FullModelHFCheckpointer, + FullModelMetaCheckpointer, + FullModelTorchTuneCheckpointer, +) +from ._checkpointer_utils import ModelType # noqa + + +Checkpointer = Union[ + FullModelHFCheckpointer, + FullModelMetaCheckpointer, + FullModelTorchTuneCheckpointer, +] + +__all__ = [ + "FullModelHFCheckpointer", + "FullModelMetaCheckpointer", + "FullModelTorchTuneCheckpointer", + "ModelType", + "Checkpointer", +] diff --git a/training/torchtune/utils/_checkpointing/_checkpointer.py b/training/torchtune/utils/_checkpointing/_checkpointer.py new file mode 100644 index 0000000..8fa5520 --- /dev/null +++ b/training/torchtune/utils/_checkpointing/_checkpointer.py @@ -0,0 +1,810 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import gc +import json +import os + +from pathlib import Path +from typing import Any, Dict, List, Optional, Protocol + +import torch +from safetensors.torch import save_file +from torchtune import utils + +from torchtune.models import convert_weights +from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf +from torchtune.models.qwen2 import qwen2_hf_to_tune, qwen2_tune_to_hf +from torchtune.modules.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf +from torchtune.utils._checkpointing._checkpointer_utils import ( + get_path, + ModelType, + safe_torch_load, + save_config, +) +from torchtune.utils.logging import get_logger + +logger = get_logger("DEBUG") + + +class _CheckpointerInterface(Protocol): + """ + Interface implemented by Checkpointers in torchtune. + + torchtune checkpointers are designed to be composable components which can be plugged + into any training recipe. Each checkpointer supports a specific set of models and training + scenarios making these easy to understand, debug and extend. For example, the + ``FullModelCheckpointer``s are used for loading and saving all of the model weights. + This checkpointer can be used for Full-Finetuning scenarios or PEFT where the output is a + merged checkpoint. In case the current suite of checkpointers are inadequate, + users are encouraged to implement their own and contribute back to torchtune. + + torchtune is also designed to be "state-dict invariant". This means the checkpointer + ensures that the output checkpoint has the same format as the original checkpoint i.e. + the output checkpoint has the same keys split across the same number of files as the original + checkpoint. Being "state-dict invariant" allows users to seamlessly use torchtune checkpoints + with their favorite post-training tools from the open-source ecosystem without writing + torchtune-specific convertors. To be "state-dict invariant", the ``load_checkpoint`` and + ``save_checkpoint`` methods make use of the weight convertors available in + ``torchtune/models/``. + + torchtune Checkpointers support two checkpointing scenarios: + * End-of-training Checkpointing. The model weights at the end of a completed training + run are written out to file. The checkpointer ensures that the output checkpoint + files have the same keys as the input checkpoint file used to begin training. The + checkpointer also ensures that the keys are partitioned across the same number of + files as the original checkpoint. This ensures that the original metadata files can + be used as is, and the output checkpoint can be used with any tool that understands + the original checkpoint format. This includes popular inference engines such as + ``llama.cpp`` and ``gpt-fast``. The output state dict has the following format: + { + "key_1": weight + ... + } + + + Mid-training Chekpointing. In addition to the model checkpoint files, we output an + additional "recipe_state.pt" file for intermediate checkpoints. These are currently + output at the end of each epoch, and contain information such as optimizer state, + number of epochs completed etc which is needed to correctly resume a previously + interrupted training run. The recipe is responsible for constructing the state dict + with the information it needs. The checkpointer extracts the model state dict + (key = "model") and writes everything else out to "recipe_state.pt". To prevent us + from flooding ``output_dir`` with checkpoint files, the recipe state is overwritten + at the end of each epoch. The output state dicts have the following formats: + + Model: + { + "key_1": weight + ... + } + + Recipe State: + { + "optimizer": ..., + "epoch": ..., + ... + } + + """ + + def load_checkpoint(self, **kwargs) -> Dict[str, Any]: + ... + + def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None: + ... + + +class FullModelTorchTuneCheckpointer(_CheckpointerInterface): + """ + Checkpointer which reads and writes checkpoints in a format compatible with + torchtune. No conversion of weights is required. + + Currently this supports reading a single checkpoint file only. This will likely change as + we add support for larger models. + + Args: + checkpoint_dir (str): Directory containing the checkpoint files + checkpoint_files (List[str]): List of checkpoint files to load. Since the checkpointer takes care + of sorting by file ID, the order in this list does not matter + model_type (ModelType): Model type of the model for which the checkpointer is being loaded + output_dir (str): Directory to save the checkpoint files + adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None + recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to + resume training from a previous run. Default is False + + Raises: + ValueError: If more than one checkpoint file is provided + ValueError: If the checkpoint file does not have a .pt extension + ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None + + + """ + + def __init__( + self, + checkpoint_dir: str, + checkpoint_files: List[str], + model_type: ModelType, + output_dir: str, + adapter_checkpoint: Optional[str] = None, + recipe_checkpoint: Optional[str] = None, + resume_from_checkpoint: bool = False, + ) -> None: + # Fail fast if ``checkpoint_files`` is invalid + if len(checkpoint_files) != 1: + raise ValueError( + "Currently we only support reading from a single torchtune checkpoint file. " + f"Got {len(checkpoint_files)} files instead." + ) + + self._checkpoint_dir = Path(checkpoint_dir) + self._checkpoint_path = get_path(self._checkpoint_dir, checkpoint_files[0]) + + if not self._checkpoint_path.suffix == ".pt": + raise ValueError( + f"Checkpoint file {self._checkpoint_path} is not a valid checkpoint file. " + "Checkpointer expects a valid .pt file." + ) + + self._adapter_checkpoint = ( + get_path(self._checkpoint_dir, adapter_checkpoint) + if adapter_checkpoint + else None + ) + + self._resume_from_checkpoint = resume_from_checkpoint + self._model_type = model_type + self._output_dir = Path(output_dir) + + # recipe_checkpoint contains the recipe state. This should be available if + # resume_from_checkpoint is True + self._recipe_checkpoint = None + if self._resume_from_checkpoint: + if recipe_checkpoint is None: + raise ValueError( + "If resume_from_checkpoint is True, recipe_checkpoint file must be provided." + ) + self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint) + + def load_checkpoint(self, weights_only: bool = True) -> Dict[str, Any]: + """ + Load torchtune checkpoint from file. Currently only loading from a single file is supported. + + The output state_dict has the following format, with keys other than "model" only present if + ``resume_from_checkpoint`` is True: + + >>> { + >>> "model": { + >>> "key_1": weight + >>> ... + >>> }, + >>> "optimizer": {...}, + >>> ... + >>> } + + Args: + weights_only (bool): flag passed down to torch.load. We expose this, because quantized models + cannot be loaded with weights_only=True + + Returns: + Dict[str, Any]: state_dict from the input checkpoint + """ + state_dict: Dict[str:Any] = {} + state_dict[utils.MODEL_KEY] = safe_torch_load( + self._checkpoint_path, weights_only=weights_only + ) + + if self._adapter_checkpoint: + adapter_state_dict = safe_torch_load(self._adapter_checkpoint) + state_dict[utils.ADAPTER_KEY] = adapter_state_dict + + if self._resume_from_checkpoint: + recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) + state_dict.update(recipe_state) + return state_dict + + def save_checkpoint( + self, + state_dict: Dict[str, Any], + epoch: int, + intermediate_checkpoint: bool = False, + adapter_only: bool = False, + ) -> None: + """ + Save torchtune checkpoint to file. If ``intermediate_checkpoint`` is True, an additional + checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe + state. The output state dicts have the following formats: + + >>> # Model + >>> { + >>> "key_1": weight + >>> ... + >>> } + >>> + >>> # Recipe state + >>> { + >>> "optimizer": ..., + >>> "epoch": ..., + >>> ... + >>> } + + Args: + state_dict (Dict[str, Any]): State dict with model and (optionally) recipe state + epoch (int): Current epoch number. This is added to the checkpoint file name to ensure + we're not overwriting intermediate checkpoint files + intermediate_checkpoint (bool): If True, save an additional checkpoint file with the + recipe state + adapter_only (bool): If True, only save the adapter weights. Default is False + + + Raises: + ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict. + """ + self._output_dir.mkdir(exist_ok=True) + + # Output file is always a .pt file with the epoch number in the name + if not adapter_only: + checkpoint_file = Path.joinpath( + self._output_dir, f"torchtune_model_{epoch}" + ).with_suffix(".pt") + torch.save(state_dict[utils.MODEL_KEY], checkpoint_file) + logger.info( + "Model checkpoint of size " + f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB " + f"saved to {checkpoint_file}" + ) + + if utils.ADAPTER_KEY in state_dict: + output_path = Path.joinpath( + self._output_dir, f"adapter_{epoch}" + ).with_suffix(".pt") + torch.save(state_dict[utils.ADAPTER_KEY], output_path) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + elif adapter_only: + raise ValueError( + "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights." + ) + + # If the recipe state needs to be output, first remove the model state dict + if intermediate_checkpoint: + _ = state_dict.pop(utils.MODEL_KEY) + _ = state_dict.pop(utils.ADAPTER_KEY, None) + _ = state_dict.pop(utils.ADAPTER_CONFIG, None) + output_path = Path.joinpath(self._output_dir, "recipe_state.pt") + torch.save(state_dict, output_path) + logger.info( + "Recipe checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + else: + logger.info("Saving final epoch checkpoint.") + if adapter_only: + logger.info( + "Please note that you have set adapter_only=True, so only adapter weights will be saved." + "You need to merge the adapter weights into your base model for further use. " + f"See {self.__class__.__name__}.save_checkpoint for more details." + ) + else: + logger.info( + "The full model checkpoint, including all weights and configurations, has been saved successfully." + "You can now use this checkpoint for further training or inference." + ) + + +class FullModelHFCheckpointer(_CheckpointerInterface): + """ + Checkpointer which reads and writes checkpoints in HF's format. For LoRA models this includes + saving checkpoints in a format that can be loaded into PEFT via e.g. ``from_pretrained``. Examples include + the Llama-2-7b-hf model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b-hf). + + Note: + HF checkpoint names are usually ordered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure \ + we read the files in the right order, we sort the checkpoint file names before reading. + + Note: + Checkpoint conversion to and from HF's format requires access to model params which are \ + read directly from the ``config.json`` file. This helps ensure we either load the weights \ + correctly or error out in case of discrepancy between the HF checkpoint file and torchtune's \ + model implementations. + + Args: + checkpoint_dir (str): Directory containing the checkpoint files + checkpoint_files (List[str]): List of checkpoint files to load. Since the checkpointer takes care + of sorting by file ID, the order in this list does not matter + model_type (ModelType): Model type of the model for which the checkpointer is being loaded + output_dir (str): Directory to save the checkpoint files + adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None + recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to + resume training from a previous run. Default is False + safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors` + + Raises: + ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None + """ + + def __init__( + self, + checkpoint_dir: str, + checkpoint_files: List[str], + model_type: ModelType, + output_dir: str, + adapter_checkpoint: Optional[str] = None, + recipe_checkpoint: Optional[str] = None, + resume_from_checkpoint: bool = False, + safe_serialization: bool = False, + ) -> None: + self._checkpoint_dir = Path(checkpoint_dir) + self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files) + self._adapter_checkpoint = ( + get_path(self._checkpoint_dir, adapter_checkpoint) + if adapter_checkpoint + else None + ) + + self._model_type = ModelType[model_type] + self._output_dir = Path(output_dir) + self._resume_from_checkpoint = resume_from_checkpoint + self._safe_serialization = safe_serialization + + # weight_map contains the state_dict key -> checkpoint file mapping so we can correctly + # parition the state dict into output checkpoint files. This is updated during checkpoint + # load + self._weight_map: Dict[str, str] = None + + # the config.json file contains model params needed for state dict conversion + self._config = json.loads( + Path.joinpath(self._checkpoint_dir, "config.json").read_text() + ) + + # save config.json to output_dir + save_config(self._output_dir, self._config) + + # recipe_checkpoint contains the recipe state. This should be available if + # resume_from_checkpoint is True + self._recipe_checkpoint = None + if self._resume_from_checkpoint: + if recipe_checkpoint is None: + raise ValueError( + "If resume_from_checkpoint is True, recipe_checkpoint file must be provided." + ) + self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint) + + def _validate_hf_checkpoint_files(self, checkpoint_files: List[str]) -> List[Path]: + """ + Validates that the checkpoint files exist and sorts based on ID. + """ + checkpoint_paths: List[Path] = [] + for f in checkpoint_files: + checkpoint_path = get_path(self._checkpoint_dir, f) + checkpoint_paths.append(checkpoint_path) + return sorted(checkpoint_paths) + + def load_checkpoint(self) -> Dict[str, Any]: + """ + Load HF checkpoint from file. + + The keys and weights from across all checkpoint files are merged into a single state_dict. + We preserve the "state_dict key" <-> "checkpoint file" mapping in weight_map so we can + write the state dict correctly in ``save_checkpoint``. + + Before returning, the model state dict is converted to a torchtune-compatible format using + the appropriate convert_weights function (depending on ``self._model_type``). + + Returns: + state_dict (Dict[str, Any]): torchtune checkpoint state dict + + Raises: + ValueError: If the values in the input state_dict are not Tensors + """ + self._weight_map = {} + + # merged state_dict contains keys and weights from all the checkpoint files + merged_state_dict: Dict[str, torch.Tensor] = {} + + # converted_state_dict is the final state_dict passed to the recipe after the + # keys are converted into the torchtune format. This optionally also contains + # the recipe state and adapter weights + converted_state_dict: Dict[str, Dict[str, torch.Tensor]] = {} + + # _checkpoint_paths are already sorted so simply enumerate to generate the right id + for cpt_idx, cpt_path in enumerate(self._checkpoint_paths): + state_dict = safe_torch_load(cpt_path) + for key, value in state_dict.items(): + # Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption + # will break recipe code + if not isinstance(value, torch.Tensor): + raise ValueError( + f"Expected all values in the state dict to be torch.Tensor. " + f"Found {type(value)} instead." + ) + # idx is written in the 4 digit format (eg: 0001, 0002, etc.) + self._weight_map[key] = f"{cpt_idx + 1:04}" + merged_state_dict.update(state_dict) + + # delete the state_dict to free up memory; TODO check if this del is needed + del state_dict + gc.collect() + if self._model_type == ModelType.PHI3_MINI: + logger.warning( + "Converting Phi-3 Mini weights from HF format." + "Note that conversion of adapter weights into PEFT format is not supported." + ) + converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict) + elif self._model_type == ModelType.REWARD: + converted_state_dict[utils.MODEL_KEY] = reward_hf_to_tune( + merged_state_dict, + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + ) + elif self._model_type == ModelType.QWEN2: + converted_state_dict[utils.MODEL_KEY] = qwen2_hf_to_tune( + merged_state_dict, + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + tie_word_embeddings=self._config["tie_word_embeddings"], + ) + else: + converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune( + merged_state_dict, + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) + + if self._adapter_checkpoint: + adapter_state_dict = safe_torch_load(self._adapter_checkpoint) + converted_state_dict[utils.ADAPTER_KEY] = adapter_state_dict + + if self._resume_from_checkpoint: + recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) + converted_state_dict.update(recipe_state) + return converted_state_dict + + def save_checkpoint( + self, + state_dict: Dict[str, Any], + epoch: int, + intermediate_checkpoint: bool = False, + adapter_only: bool = False, + ) -> None: + """ + Save HF checkpoint to file. If ``intermediate_checkpoint`` is True, an additional + checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe + state. + + The state_dict is first converted back to the HF format and then partitioned based on the + ``_weight_map`` into separate checkpoint files. + + Args: + state_dict (Dict[str, Any]): Checkpoint state dict to be written out to file + epoch (int): Epoch number. Used to create the checkpoint file name + intermediate_checkpoint (bool): If True, an additional checkpoint files for recipe state + and (if applicable) adapter weights are created. Default is False + adapter_only (bool): If True, only save the adapter weights. Default is False + + Raises: + ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict. + """ + self._output_dir.mkdir(exist_ok=True) + + # convert the state_dict back to hf format; do this inplace + if not adapter_only: + if self._model_type == ModelType.PHI3_MINI: + state_dict[utils.MODEL_KEY] = phi3_tune_to_hf( + state_dict[utils.MODEL_KEY] + ) + elif self._model_type == ModelType.REWARD: + state_dict[utils.MODEL_KEY] = reward_tune_to_hf( + state_dict[utils.MODEL_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + ) + elif self._model_type == ModelType.QWEN2: + state_dict[utils.MODEL_KEY] = qwen2_tune_to_hf( + state_dict[utils.MODEL_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + tie_word_embeddings=self._config["tie_word_embeddings"], + ) + else: + state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf( + state_dict[utils.MODEL_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) + + # split the state_dict into separate dicts, one for each output checkpoint file + split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {} + for key, weight in state_dict[utils.MODEL_KEY].items(): + cpt_idx = self._weight_map[key] + if cpt_idx not in split_state_dicts: + split_state_dicts[cpt_idx] = {} + split_state_dicts[cpt_idx].update({key: weight}) + + # write the partitioned state dicts to the right checkpoint file + for cpt_idx, model_state_dict in split_state_dicts.items(): + if not self._safe_serialization: + output_path = Path.joinpath( + self._output_dir, f"hf_model_{cpt_idx}_{epoch}" + ).with_suffix(".pt") + torch.save(model_state_dict, output_path) + else: + output_path = Path.joinpath( + self._output_dir, + f"model-0{cpt_idx}-of-0{list(split_state_dicts.keys())[-1]}_{epoch}", + ).with_suffix(".safetensors") + save_file(model_state_dict, output_path, metadata={"format": "pt"}) + logger.info( + "Model checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + + if utils.ADAPTER_KEY in state_dict: + # Save torchtune format adapter weights even if we save PEFT format + # This way we can resume no matter what (and memory footprint of adapter weights is small) + output_path = Path.joinpath( + self._output_dir, f"adapter_{epoch}" + ).with_suffix(".pt") + torch.save(state_dict[utils.ADAPTER_KEY], output_path) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + + if self._model_type == ModelType.PHI3_MINI: + logger.warning( + "Saving Phi-3 Mini adapter weights to PEFT format is not supported, saving to torchtune format instead" + ) + else: + state_dict[ + utils.ADAPTER_KEY + ] = convert_weights.tune_to_peft_adapter_weights( + state_dict[utils.ADAPTER_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) + peft_output_path = Path.joinpath( + self._output_dir, "adapter_model" + ).with_suffix(".bin") + torch.save(state_dict[utils.ADAPTER_KEY], peft_output_path) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {peft_output_path}" + ) + elif adapter_only: + raise ValueError( + "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights." + ) + + if utils.ADAPTER_CONFIG in state_dict: + if self._model_type == ModelType.PHI3_MINI: + logger.warning( + "PEFT integration for Phi-3 Mini is not supported, skipping adapter config save" + ) + else: + state_dict[ + utils.ADAPTER_CONFIG + ] = convert_weights.tune_to_peft_adapter_config( + state_dict[utils.ADAPTER_CONFIG] + ) + output_path = Path.joinpath(self._output_dir, "adapter_config.json") + with open(output_path, "w") as f: + json.dump(state_dict[utils.ADAPTER_CONFIG], f) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + + # If the recipe state needs to be output, first remove the model state dict + # and if it exists, remove the adapter state dict as well + if intermediate_checkpoint: + _ = state_dict.pop(utils.MODEL_KEY) + _ = state_dict.pop(utils.ADAPTER_KEY, None) + _ = state_dict.pop(utils.ADAPTER_CONFIG, None) + output_path = Path.joinpath(self._output_dir, "recipe_state.pt") + torch.save(state_dict, output_path) + logger.info( + "Recipe checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + else: + logger.info("Saving final epoch checkpoint.") + if adapter_only: + logger.info( + "Please note that you have set adapter_only=True, so only adapter weights will be saved." + "You need to merge the adapter weights into your base model for further use. " + f"See {self.__class__.__name__}.save_checkpoint for more details." + ) + else: + logger.info( + "The full model checkpoint, including all weights and configurations, has been saved successfully." + "You can now use this checkpoint for further training or inference." + ) + + +class FullModelMetaCheckpointer(_CheckpointerInterface): + """ + Checkpointer which reads and writes checkpoints in Meta's format. Examples include + the Llama-2-7b model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b) + + Currently we support reading from a single checkpoint file only. Support for reading from + sharded checkpoints is WIP. + + Args: + checkpoint_dir (str): Directory containing the checkpoint files + checkpoint_files (List[str]): List of checkpoint files to load. Currently this checkpointer only + supports loading a single checkpoint file. + model_type (ModelType): Model type of the model for which the checkpointer is being loaded + output_dir (str): Directory to save the checkpoint files + adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None + recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to + resume training from a previous run. Default is False + + Raises: + ValueError: If ``checkpoint_files`` is not a list of length 1 + ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None + """ + + def __init__( + self, + checkpoint_dir: str, + checkpoint_files: List[str], + model_type: ModelType, + output_dir: str, + adapter_checkpoint: Optional[str] = None, + recipe_checkpoint: Optional[str] = None, + resume_from_checkpoint: bool = False, + ) -> None: + # Fail fast if ``checkpoint_files`` is invalid + if len(checkpoint_files) != 1: + raise ValueError( + "Currently we only support reading from a single torchtune checkpoint file. " + f"Got {len(checkpoint_files)} files instead." + ) + + self._checkpoint_dir = Path(checkpoint_dir) + self._checkpoint_path = get_path(self._checkpoint_dir, checkpoint_files[0]) + + self._adapter_checkpoint = ( + get_path(self._checkpoint_dir, adapter_checkpoint) + if adapter_checkpoint + else None + ) + + self._resume_from_checkpoint = resume_from_checkpoint + self._model_type = model_type + self._output_dir = Path(output_dir) + + # recipe_checkpoint contains the recipe state. This should be available if + # resume_from_checkpoint is True + self._recipe_checkpoint = None + if self._resume_from_checkpoint: + if recipe_checkpoint is None: + raise ValueError( + "If resume_from_checkpoint is True, recipe_checkpoint file must be provided." + ) + self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint) + + def load_checkpoint(self) -> Dict[str, Any]: + """ + Load Meta checkpoint from file. Currently only loading from a single file is supported. + """ + state_dict: Dict[str:Any] = {} + model_state_dict = safe_torch_load(self._checkpoint_path) + state_dict[utils.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict) + + if self._adapter_checkpoint: + adapter_state_dict = safe_torch_load(self._adapter_checkpoint) + state_dict[utils.ADAPTER_KEY] = adapter_state_dict + + if self._resume_from_checkpoint: + recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) + state_dict.update(recipe_state) + return state_dict + + def save_checkpoint( + self, + state_dict: Dict[str, Any], + epoch: int, + intermediate_checkpoint: bool = False, + adapter_only: bool = False, + ) -> None: + """ + Save Meta checkpoint to file. If ``intermediate_checkpoint`` is True, an additional + checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe + state. + + Args: + state_dict (Dict[str, Any]): Checkpoint state dict to be written out to file + epoch (int): Epoch number. Used to create the checkpoint file name + intermediate_checkpoint (bool): If True, an additional checkpoint files for recipe state + and (if applicable) adapter weights are created. Default is False + adapter_only (bool): If True, only save the adapter weights. Default is False + + Raises: + ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict. + """ + self._output_dir.mkdir(exist_ok=True) + + if not adapter_only: + model_state_dict = state_dict[utils.MODEL_KEY] + state_dict[utils.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict) + + # Output file is always a .pt file with the epoch number in the name + checkpoint_file = Path.joinpath( + self._output_dir, f"meta_model_{epoch}" + ).with_suffix(".pt") + torch.save(state_dict[utils.MODEL_KEY], checkpoint_file) + logger.info( + "Model checkpoint of size " + f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB " + f"saved to {checkpoint_file}" + ) + + if utils.ADAPTER_KEY in state_dict: + output_path = Path.joinpath( + self._output_dir, f"adapter_{epoch}" + ).with_suffix(".pt") + torch.save(state_dict[utils.ADAPTER_KEY], output_path) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + elif adapter_only: + raise ValueError( + "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights." + ) + + # If the recipe state needs to be output, first remove the model state dict + # and if it exists, remove the adapter state dict as well + if intermediate_checkpoint: + _ = state_dict.pop(utils.MODEL_KEY) + _ = state_dict.pop(utils.ADAPTER_KEY, None) + _ = state_dict.pop(utils.ADAPTER_CONFIG, None) + output_path = Path.joinpath(self._output_dir, "recipe_state.pt") + torch.save(state_dict, output_path) + logger.info( + "Recipe checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) + else: + logger.info("Saving final epoch checkpoint.") + if adapter_only: + logger.info( + "Please note that you have set adapter_only=True, so only adapter weights will be saved." + "You need to merge the adapter weights into your base model for further use. " + f"See {self.__class__.__name__}.save_checkpoint for more details." + ) + else: + logger.info( + "The full model checkpoint, including all weights and configurations, has been saved successfully." + "You can now use this checkpoint for further training or inference." + ) diff --git a/training/torchtune/utils/_checkpointing/_checkpointer_utils.py b/training/torchtune/utils/_checkpointing/_checkpointer_utils.py new file mode 100644 index 0000000..962dab9 --- /dev/null +++ b/training/torchtune/utils/_checkpointing/_checkpointer_utils.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from enum import Enum +from pathlib import Path +from typing import Any, Dict + +import torch +from safetensors import safe_open + + +class ModelType(Enum): + """ModelType is used by the checkpointer to distinguish between different model architectures. + + If you are adding a new model that follows a different format than those in the repo already, + you can add a new ModelType to gate on weight conversion logic unique to that model. + + Attributes: + GEMMA (str): Gemma family of models. See :func:`~torchtune.models.gemma.gemma` + LLAMA2 (str): Llama2 family of models. See :func:`~torchtune.models.llama2.llama2` + LLAMA3 (str): Llama3 family of models. See :func:`~torchtune.models.llama3.llama3` + MISTRAL (str): Mistral family of models. See :func:`~torchtune.models.mistral.mistral` + PHI3_MINI (str): Phi-3 family of models. See :func:`~torchtune.models.phi3.phi3` + REWARD (str): A Llama2, Llama3, or Mistral model with a classification head projecting + to a single class for reward modelling. + See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b` + QWEN2 (str): Qwen2 family of models. See :func:`~torchtune.models.qwen2.qwen2` + + Example: + >>> # Usage in a checkpointer class + >>> def load_checkpoint(self, ...): + >>> ... + >>> if self._model_type == MY_NEW_MODEL: + >>> state_dict = my_custom_state_dict_mapping(state_dict) + """ + + GEMMA: str = "gemma" + LLAMA2: str = "llama2" + LLAMA3: str = "llama3" + MISTRAL: str = "mistral" + PHI3_MINI: str = "phi3_mini" + REWARD: str = "reward" + QWEN2: str = "qwen2" + + +def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path: + """ + Utility to recover and validate the path for a given file within a given directory. + + Args: + input_dir (Path): Directory containing the file + filename (str): Name of the file + missing_ok (bool): Whether to raise an error if the file is missing. + + Returns: + Path: Path to the file + + Raises: + ValueError: If the file is missing and missing_ok is False. + """ + if not input_dir.is_dir(): + raise ValueError(f"{input_dir} is not a valid directory.") + + file_path = Path.joinpath(input_dir, filename) + + # If missing_ok is False, raise an error if the path is invalid + if not missing_ok and not file_path.is_file(): + raise ValueError(f"No file with name: {filename} found in {input_dir}.") + return file_path + + +def safe_torch_load( + checkpoint_path: Path, weights_only: bool = True, mmap: bool = True +) -> Dict[str, Any]: + """ + Utility to load a checkpoint file onto CPU in a safe manner. Provides separate handling for + safetensors files. + + Args: + checkpoint_path (Path): Path to the checkpoint file. + weights_only (bool): Whether to load only tensors, primitive types, and dictionaries + (passthrough to torch.load). Default: True + mmap (bool): Whether to mmap from disk into CPU memory. Default: True + + Returns: + Dict[str, Any]: State dict from the checkpoint file. + + Raises: + ValueError: If the checkpoint file is not found or cannot be loaded. + """ + try: + # convert the path into a string since pathlib Path and mmap don't work + # well together + is_safetensors_file = ( + True if str(checkpoint_path).endswith(".safetensors") else False + ) + if is_safetensors_file: + result = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + state_dict = result + else: + state_dict = torch.load( + str(checkpoint_path), + map_location="cpu", + mmap=mmap, + weights_only=weights_only, + ) + except Exception as e: + raise ValueError(f"Unable to load checkpoint from {checkpoint_path}. ") from e + return state_dict + + +def save_config(path: Path, config: Dict[str, Any]) -> None: + """ + Save a configuration dictionary to a file. + + Args: + path (Path): Path to save the configuration file. + config (Dict[str, Any]): Configuration dictionary to save. + """ + if not path.is_dir(): + path.mkdir(exist_ok=True) + file_path = Path.joinpath(path, "config.json") + if not file_path.exists(): + with open(file_path, "w") as f: + json.dump(config, f) diff --git a/training/torchtune/utils/_device.py b/training/torchtune/utils/_device.py new file mode 100644 index 0000000..4c3ffee --- /dev/null +++ b/training/torchtune/utils/_device.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Optional + +import torch + + +def _get_local_rank() -> Optional[int]: + """Function that gets the local rank from the environment. + + Returns: + local_rank int or None if not set. + """ + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + local_rank = int(local_rank) + return local_rank + + +def _setup_cuda_device(device: torch.device) -> torch.device: + """Function that sets the CUDA device and infers the cuda + index if not set. + + Args: + device (torch.device): The device to set. + + Raises: + RuntimeError: If device index is not available. + + Returns: + device + """ + local_rank = _get_local_rank() or 0 + if device.index is None: + device = torch.device(type="cuda", index=local_rank) + + # Ensure index is available before setting device + if device.index >= torch.cuda.device_count(): + raise RuntimeError( + "The local rank is larger than the number of available GPUs." + ) + + torch.cuda.set_device(device) + return device + + +def _get_device_type_from_env() -> str: + """Function that gets the torch.device based on the current machine. + + This currently only supports CPU, CUDA. + + Returns: + device + """ + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + return device + + +def _validate_device_from_env(device: torch.device) -> None: + """Function that validates the device is correct given the current machine. + This will raise an error if the device is not available or doesn't match the + assigned process device on distributed runs. + + Args: + device (torch.device): The device to validate. + + Raises: + RuntimeError: If the device is not available or doesn't match the assigned process device. + + Returns: + device + """ + local_rank = _get_local_rank() + + # Check if the device index is correct + if device.type == "cuda" and local_rank is not None: + # Ensure device index matches assigned index when distributed training + if device.index != local_rank: + raise RuntimeError( + f"You can't specify a device index when using distributed training. \ + Device specified is {device} but was assigned cuda:{local_rank}" + ) + + # Check if the device is available on this machine + try: + torch.empty(0, device=device) + except RuntimeError as e: + raise RuntimeError( + f"The device {device} is not available on this machine." + ) from e + + +def get_device(device: Optional[str] = None) -> torch.device: + """Function that takes an optional device string, verifies it's correct and available given the machine and + distributed settings, and returns a torch.device. If device string is not provided, this function will + infer the device based on the environment. + + If CUDA is available and being used, this function also sets the CUDA device. + + Args: + device (Optional[str]): The name of the device to use. + + Returns: + torch.device: device. + """ + if device is None: + device = _get_device_type_from_env() + device = torch.device(device) + if device.type == "cuda": + device = _setup_cuda_device(device) + _validate_device_from_env(device) + return device diff --git a/training/torchtune/utils/_distributed.py b/training/torchtune/utils/_distributed.py new file mode 100644 index 0000000..f582153 --- /dev/null +++ b/training/torchtune/utils/_distributed.py @@ -0,0 +1,562 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +from itertools import chain +from typing import Any, Callable, cast, Dict, Set, Tuple, Type + +import torch +import torch.distributed as dist +from packaging import version +from torch import nn + +from torch.distributed._tensor import distribute_tensor, DTensor +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_WRAPPED_MODULE, +) +from torch.distributed.checkpoint.state_dict import _init_optim_state +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.optim import Optimizer +from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 +from torchtune import modules +from torchtune.modules.peft.lora import ( + _lora_a_init_params, + _lora_b_init_params, + LoRALinear, +) + +from torchtune.utils._device import get_device +from torchtune.utils.logging import get_logger + +_log: logging.Logger = get_logger() + + +FSDPPolicyType: Type = Callable[[nn.Module, bool, int], bool] + +FSDPPolicyType.__doc__ = """ + +A datatype for a function that can be used as an FSDP wrapping policy. +In particular, this type denotes a function that can accept an nn.Module, a boolean flag, and an integer +and return a boolean indicating whether the module should be wrapped with FSDP. Objects of this type can +be directly passed into PyTorch FSDP's ``auto_wrap_policy`` argument to specify how FSDP wraps submodules. + +The below function serves as an example of creating and returning a function that obeys the contract of +``FSDPPolicyType``:: + + def get_fsdp_policy(module: nn.Module, modules_to_wrap: Set[Type], min_num_params: int): + + def my_fsdp_policy(module: nn.Module, modules_to_wrap: Set[Type], recurse: bool, min_num_params: int) -> bool: + if recurse: + return True + # Wrap layers that are of type in ``modules_to_wrap`` and layers with more than min_num_params + + return isinstance(module, tuple(modules_to_wrap)) or sum(p.numel() for p in module.parameters()) > 1000 + + return functools.partial(my_fsdp_policy, modules_to_wrap=modules_to_wrap) + +Please see documentation of ``auto_wrap_policy`` at https://pytorch.org/docs/stable/fsdp.html for additional details. + +""" + +_valid_distributed_single_node_nnodes = ["1:1", "1"] + + +def _get_sharding_strategy(strategy: str) -> ShardingStrategy: + """Helper function to convert sharding strategy strings to ShardingStrategy enum.""" + return getattr(ShardingStrategy, strategy) + + +def is_distributed() -> bool: + """Check if all environment variables required to initialize torch.distributed are set + and distributed is properly installed. This indicates a distributed run. + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + + Checks the following conditions: + + * torch.distributed is available + * master port and master address environment variables are set + * world size is >1 + * rank environment variable is set + + Returns: + bool: True if all of the above conditions hold, False otherwise. + """ + port = os.environ.get("MASTER_PORT", "") + addr = os.environ.get("MASTER_ADDR", "") + size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", -1)) + avlb = dist.is_available() + return bool(port and addr and size >= 1 and rank >= 0 and avlb) + + +def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcasts a tensor from a source to all other processes. + + Args: + tensor (torch.Tensor): Tensor to broadcast. + src (int, optional): Source rank. Defaults to 0. + + Returns: + torch.Tensor: Broadcasted tensor. + """ + if dist.is_available() and dist.is_initialized(): + device = tensor.device + if dist.get_backend() == "nccl": + tensor = tensor.to(get_device("cuda")) + dist.broadcast(tensor, src=src, group=None) + return tensor.to(device) + else: + return tensor + + +def init_distributed(**kwargs: Dict[str, Any]) -> bool: + """Initialize process group required for ``torch.distributed``. + + Args: + **kwargs (Dict[str, Any]): Additional arguments to pass to torch.distributed.init_process_group. + + Returns: + bool: True if torch.distributed is initialized. + + Raises: + RuntimeError: If torch.distributed is already initialized. + """ + if is_distributed(): + if dist.is_initialized(): + raise RuntimeError("torch.distributed already initialized.") + dist.init_process_group(**kwargs) + return True + else: + return False + + +def set_torch_num_threads() -> None: + """ + Sets the number of threads used by torch to utilize all physical CPU + cores for intra-op parallelism. Currently, this function sets num_threads + to be the number of physical CPU cores divided by the number of GPUs as we + use one process per GPU, and this avoids CPU oversubscription. Note that this is + currently a rough approximation, and doesn't take into account environments where + things like CPU affinity is set. + """ + num_threads = os.cpu_count() // ( + torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + ) + torch.set_num_threads(num_threads) + _log.info(f"Set intra op parallelism no. of threads to {num_threads}") + + +def get_world_size_and_rank() -> Tuple[int, int]: + """Function that gets the current world size (aka total number + of ranks) and rank number of the current process in the default process group. + + Returns: + Tuple[int, int]: world size, rank + """ + if dist.is_available() and dist.is_initialized(): + return torch.distributed.get_world_size(), torch.distributed.get_rank() + else: + return 1, 0 + + +def validate_no_params_on_meta_device(model: nn.Module) -> None: + """ + Utility to validate that model has no params or buffers on meta device. + If a meta param or buffer is found, an error indicating the param name will + be raised. + + Args: + model (nn.Module): model to check for meta params + + Raises: + RuntimeError: If meta params or buffers exist in model + """ + for n, p in chain(model.named_parameters(), model.named_buffers()): + if p.is_meta: + raise RuntimeError(f"Unexpected param or buffer {n} on meta device.") + + +def contains_fsdp(model: nn.Module) -> bool: + """ + Checks if the model contains FSDP. + + Args: + model (nn.Module): Model to check. + + Returns: + bool: True if the model contains FSDP, False otherwise. + """ + return any( + isinstance(m, torch.distributed.fsdp.FullyShardedDataParallel) + for m in model.modules() + ) + + +def _dummy_reset_params(x: nn.Module) -> None: + """ + Dummy method for patching no-op reset_parameters() when using + FSDP with meta device. + """ + return + + +def prepare_model_for_fsdp_with_meta_device(model: nn.Module) -> nn.Module: + """ + Dynamically define reset_parameters on every submodule of the model. For LoRA models, + ensure that the FSDP contract of reset_parameters only modifying a module's directly-owned + parameters is satisfied. More details here: https://github.com/pytorch/pytorch/issues/104187. + + Args: + model (nn.Module): model class to prepare for usage with FSDP and meta device. + + Returns: + nn.Module: Model with reset_parameters defined on every submodule. + In the case of a LoRA model, we override the default reset_parameters of nn.Linear. + + Raises: + RuntimeError: if model contains submodule with non-callable attribute reset_parameters + """ + for k, v in model.named_modules(): + # If the module does not have reset_parameters defined, we define + # a no-op reset_parameters method to satisfy FSDP's contract. + reset_params = getattr(v, "reset_parameters", None) + + if reset_params is not None and not callable(reset_params): + raise RuntimeError( + f"Cannot override existing reset_parameters variable for FSDP init in {k}" + ) + + if reset_params is None: + v.reset_parameters = _dummy_reset_params.__get__(v) + + # This will define reset_parameters for LoRA weight initialization + # directly on any LoRALinear submodules lora_a and lora_b. + if isinstance(v, LoRALinear): + v.lora_a.reset_parameters = _lora_a_init_params.__get__(v.lora_a) + v.lora_b.reset_parameters = _lora_b_init_params.__get__(v.lora_b) + + return model + + +def lora_fsdp_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType: + """ + A default policy for wrapping models trained with LoRA using FSDP. + + FSDP's default behavior is to allocate gradients at the level of FSDP-wrapped modules. + This means that if any parameter in a given FSDP-wrapped module requires gradients, then memory will be + allocated for gradients for the entire module. + + In the case of LoRA, where only LoRA A and B matrices are trainable, this means that + we need to wrap LoRA A and B submodules in their own FSDP units to + maximize memory savings. After this is done, model will also be hierarchically wrapped + based on nn.Module types specified in ``modules_to_wrap``. This function assumes that + (a) LoRA's A and B matrices are the only trainable weights in the entire model, and + (b) we have already set ``requires_grad = True`` on LoRA params. + + Args: + modules_to_wrap (Set[Type]): nn.Module types to recursively wrap + + Returns: + FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel``. Please see + documentation for :const:`~torchtune.utils.FSDPPolicyType` for additional details. + """ + + def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs): + if recurse: + return True + + # Assumes lora_a and lora_b are nn.Linears that are the + # only trainable modules in the entire network. Wraps + # these in separate FSDP unit to work around FSDP allocating + # extra gradient memory when wrapped with other modules. + if hasattr(module, "weight") and module.weight.requires_grad: + return True + + return isinstance(module, tuple(modules_to_wrap)) + + return lora_wrap_fsdp + + +def load_from_full_model_state_dict( + model: "FSDPModule", + full_sd: Dict[str, Any], + device: torch.device, + is_rank_zero: bool, +): + """ + Converting full state dict into a sharded state dict + and loading it into FSDP model + - 'full' means plain tensor + - 'sharded' means `DTensor` where reach rank has a shard of the plain tensor + - `is_rank_zero` matters if only rank 0 pass in non-empty `full_sd` and + we need to broadcast from rank 0 + """ + meta_sharded_sd = model.state_dict() + sharded_sd = {} + for param_name, full_tensor in full_sd.items(): + sharded_meta_param = meta_sharded_sd.get(param_name) + full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) + if isinstance(sharded_meta_param._local_tensor, NF4Tensor): + full_tensor = to_nf4(full_tensor) + # replicating logic from `_fsdp_param.py`` `_init_sharded_param` + # otherwise `distribute_tensor(DTensor(local=NF4))` + # requires dispatching `c10d.scatter_`` + # long-term solution is `swap_tensor` + mesh = sharded_meta_param.device_mesh + if mesh.ndim > 1: + raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}") + shard_mesh_dim = 0 + shard_world_size = mesh.size(shard_mesh_dim) + shard_rank = cast( + torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim) + ).rank() + chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank] + sharded_param = full_tensor.new_zeros(chunk.size()) + sharded_param[: chunk.size(0)].copy_(chunk) + # BC-breaking change to DTensor API in https://github.com/pytorch/pytorch/pull/128112 + # TODO: change to from_local API (need to add view support for NF4) + if version.parse(torch.__version__) >= version.parse("2.4.0.dev20240606"): + sharded_tensor = DTensor( + local_tensor=sharded_param, + spec=DTensorSpec( + mesh=sharded_meta_param.device_mesh, + placements=sharded_meta_param.placements, + tensor_meta=TensorMeta( + shape=sharded_meta_param.size(), + dtype=sharded_meta_param.dtype, + stride=sharded_meta_param.stride(), + ), + ), + requires_grad=sharded_meta_param.requires_grad, + ) + else: + sharded_tensor = DTensor( + sharded_param, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + shape=sharded_meta_param.size(), + dtype=sharded_meta_param.dtype, + requires_grad=sharded_meta_param.requires_grad, + stride=sharded_meta_param.stride(), + ) + + else: + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + sharded_sd[param_name] = nn.Parameter(sharded_tensor) + # choose `assign=True` since we cannot call `copy_` on meta tensor + return model.load_state_dict(sharded_sd, strict=False, assign=True) + + +def get_full_model_state_dict( + model: "FSDPModule", + is_rank_zero: bool, +) -> Dict[str, Any]: + """ + Converting sharded state dict into a full state dict on cpu + Returning non-empty result on rank0 to avoid peaking cpu memory + """ + sharded_sd = model.state_dict() + cpu_state_dict = {} + has_nf4 = any( + isinstance(param._local_tensor, NF4Tensor) for param in model.parameters() + ) + if has_nf4: + from torch.distributed._composable.fsdp.fully_shard import FSDPModule + + # Iterating from lowerer modules to higher + # Unsharding lora adapters before unsharding transformer block + for module_name, module in reversed(list(model.named_modules())): + if not isinstance(module, FSDPModule): + continue + module.unshard(async_op=False) + if is_rank_zero: + module_name = module_name.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "") + for local_fqn, param in module.named_parameters(): + local_fqn = local_fqn.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "") + if len(module_name) > 0: + full_fqn = module_name + "." + local_fqn + else: + full_fqn = local_fqn + if full_fqn in cpu_state_dict: + # Iterate over every param in every module bottoms-up + # When lower TransformerBlock gets unsharded, + # we insert (full_fqn, full_tensor) into cpu_state_dict. + # When higher Transformer gets unsharded, we avoid updating + # params from lower TransformerBlockonly again. Instead, only updating + # tok_embeddings etc that belongs to Transformer + continue + if isinstance(param, NF4Tensor): + # upcasting NF4 to original dtype + param = param.to(param.dtype) + if isinstance(param, DTensor): + raise AssertionError( + f"Internal error: expect unsharded {full_fqn} in plain torch.Tensor but got DTensor." + " Might be a bug in get_full_model_state_dict" + ) + cpu_state_dict[full_fqn] = param.cpu() + module.reshard() + else: + for param_name, sharded_param in sharded_sd.items(): + full_param = sharded_param.full_tensor() + if is_rank_zero: + cpu_state_dict[param_name] = full_param.cpu() + else: + del full_param + return cpu_state_dict + + +def get_full_optimizer_state_dict( + opt: Optimizer, + is_rank_zero: bool, +) -> Dict[str, Any]: + """ + Converting optimizer state from sharded to full + For example, "exp_avg" in AdamW is `DTensor`, + "exp_avg.full_tensor()" converts it to plain tensor on rank 0 + Returning non-empty cpu state dict on rank 0 + """ + sharded_sd = opt.state_dict() + sharded_state = sharded_sd["state"] + full_state = {} + for group_id, sharded_group in sharded_state.items(): + group_state = {} + for attr, sharded_tensor in sharded_group.items(): + if isinstance(sharded_tensor, DTensor): + # "exp_avg" in AdamW is `DTensor` + full_tensor = sharded_tensor.full_tensor() + else: + # "step" in AdamW is plain tensor + full_tensor = sharded_tensor + if is_rank_zero: + group_state[attr] = full_tensor.cpu() + else: + del full_tensor + if is_rank_zero: + full_state[group_id] = group_state + else: + del group_state + if is_rank_zero: + return { + "param_groups": sharded_sd["param_groups"], + "state": full_state, + } + else: + return {} + + +def load_from_full_optimizer_state_dict( + opt: Optimizer, + full_sd: Dict[str, Any], + device: torch.device, +) -> Dict[str, Any]: + """ + Converting full optimizer state to sharded state dict + and loading it into optimizer + """ + PARAMS = "params" # noqa: N806 + _init_optim_state(opt) + param_groups = opt.state_dict()["param_groups"] + state = opt.state_dict()["state"] + + full_param_groups = full_sd["param_groups"] + full_state = full_sd["state"] + + for param_group, full_param_group in zip(param_groups, full_param_groups): + for key, value in full_param_group.items(): + if key == PARAMS: + continue + param_group[key] = value + for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]): + if pid not in state: + continue + param_state = state[pid] + full_param_state = full_state[full_pid] + for attr, full_tensor in full_param_state.items(): + sharded_tensor = param_state[attr] + if isinstance(sharded_tensor, DTensor): + # exp_avg is DTensor + param_state[attr] = distribute_tensor( + full_tensor, + sharded_tensor.device_mesh, + sharded_tensor.placements, + ) + else: + # step is plain tensor + param_state[attr] = full_tensor + opt.load_state_dict( + { + "param_groups": param_groups, + "state": state, + } + ) + + +def get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap: bool, modules_to_wrap: Set[Type] +) -> FSDPPolicyType: + """ + Retrieves an FSDP wrapping policy based on the specified flags ``memory_efficient_fsdp_wrap`` and + ``modules_to_wrap``. Specifically, if ``memory_efficient_fsdp_wrap`` is set to ``True``, the returned + policy will wrap the model's token embedding and output projection in addition to the modules specified + to maximize memory savings. + + Args: + memory_efficient_fsdp_wrap (bool): If ``True``, will also wrap embedding and output projection layers with FSDP. + modules_to_wrap (Set[Type]): Set of module types to wrap. + + Note: + ``memory_efficient_fsdp_wrap`` memory improvements have currently only been verified on llama3 workloads + where they provide ~15% memory improvement (when used alongside AC memory efficient wrapping). Other workloads + have not been verified and may not see the same improvements. + + Returns: + FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel`` as the ``auto_wrap_policy`` + argument. Please see documentation for :const:`~torchtune.utils.FSDPPolicyType` for additional details. + """ + if memory_efficient_fsdp_wrap: + return _memory_efficient_wrap_policy(modules_to_wrap=modules_to_wrap) + else: + return ModuleWrapPolicy(modules_to_wrap) + + +def _memory_efficient_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType: + """ + A default policy for memory efficient wrapping for full finetuning using FSDP. Specifically, + this will wrap the model's token embedding and output projection into their own FSDP units to + maximize memory savings. This helps especially if these layers are particularly large, + such as due to a large embedding size. + After this is done, model will also be hierarchically wrapped + based on nn.Module types specified in ``modules_to_wrap``. This function assumes that the + input model has an attribute ``output`` that is a nn.Linear which is the model's output projection. + Args: + modules_to_wrap (Set[Type]): nn.Module types to recursively wrap + Returns: + FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel``. + """ + modules_to_wrap.add(torch.nn.Embedding) + + def llama3_wrap(module: nn.Module, recurse: bool, **kwargs): + # Label that output_proj should be wrapped individually. + if isinstance(module, modules.TransformerDecoder): + module.output._wrap = True + if recurse: + return True + + # Wrap output_proj individually. + if getattr(module, "_wrap", False): + return True + + return isinstance(module, tuple(modules_to_wrap)) + + return llama3_wrap diff --git a/training/torchtune/utils/_generation.py b/training/torchtune/utils/_generation.py new file mode 100644 index 0000000..b60fe09 --- /dev/null +++ b/training/torchtune/utils/_generation.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional + +import torch +from torchtune.modules import TransformerDecoder + + +def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor: + """Samples from a multinomial distribution.""" + q = torch.empty_like(probs).exponential_(1) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample( + logits: torch.Tensor, temperature: float = 1.0, top_k: int = None +) -> torch.Tensor: + """Generic sample from a probability distribution.""" + # scale the logits based on temperature + logits = logits / max(temperature, 1e-5) + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # select the very last value from the top_k above as the pivot + pivot = v.select(-1, -1).unsqueeze(-1) + # set everything smaller than pivot value to inf since these + # should be pruned + logits = torch.where(logits < pivot, -float("Inf"), logits) + # change logits into probabilities + probs = torch.nn.functional.softmax(logits, dim=-1) + return multinomial_sample_one(probs) + + +def generate_next_token( + model: TransformerDecoder, + input_pos: torch.Tensor, + x: torch.Tensor, + temperature: float = 1.0, + top_k: int = None, +) -> torch.Tensor: + """Generates the next tokens.""" + # model produces logits in [bsz, seq_length, vocab_size] + # we want to take the last token's logits as the input to the next model call + logits = model(x, input_pos=input_pos)[:, -1] + return sample(logits, temperature, top_k) + + +def update_stop_tokens_tracker( + tokens: torch.Tensor, stop_tokens: torch.Tensor, stop_token_reached: torch.Tensor +) -> torch.Tensor: + """Updates which sequences have reached a stop token.""" + # tokens: [bsz, 1] + # stop_tokens: [num_stop_tokens] + # stop_token_reached: [bsz] + stop_token_reached_curr = torch.isin(tokens, stop_tokens).flatten() + stop_token_reached |= stop_token_reached_curr + return stop_token_reached + + +@torch.inference_mode() +def generate( + model: TransformerDecoder, + prompt: torch.Tensor, + *, + max_generated_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + stop_tokens: Optional[List[int]] = None, + custom_generate_next_token: Optional[Callable] = None, +) -> List[List[int]]: + """ + Generates tokens from a model conditioned on a prompt. + + Args: + model (TransformerDecoder): model used for generation + prompt (torch.Tensor): tensor with the token IDs associated with the given prompt, + with shape either [seq_length] or [bsz x seq_length] + max_generated_tokens (int): number of tokens to be generated + temperature (float): value to scale the predicted logits by, default 1.0. + top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities, + default None. + stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated, + default None. + custom_generate_next_token (Optional[Callable]): If specified, we'll use the ``custom_generate_next_token function``. + This is generally only useful if you want to specify a ``torch.compile`` version of the generate next token for + performance reasons. If None, we use the default ``generate_next_token`` function. Default is None. + + Examples: + >>> model = torchtune.models.llama3.llama3_8b() + >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() + >>> prompt = tokenizer("Hi my name is") + >>> output = generate(model, prompt, max_generated_tokens=100) + >>> print(tokenizer.decode(output[0])) + Hi my name is Jeremy and I'm a friendly language model assistant! + + Returns: + List[List[int]]: collection of lists of generated tokens + """ + prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt + # convert stop tokens to tensor for easy matching + stop_tokens = ( + torch.tensor(stop_tokens, device=prompt.device) if stop_tokens else None + ) + bsz, prompt_length = prompt.size() + generated_tokens = prompt.clone() + # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop + stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) + # everything in stop_token_mask starts as 1s, and we'll set them to 0 for sequences + # that already hit a stop token + stop_token_mask = torch.ones( + (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device + ) + + if custom_generate_next_token is None: + custom_generate_next_token = generate_next_token + + # generate the first tokens conditioned on the prompt + input_pos = torch.arange(0, model.max_seq_len, device=prompt.device) + tokens = generate_next_token( + model, + input_pos=input_pos[:prompt_length], + x=prompt, + temperature=temperature, + top_k=top_k, + ) + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + + # stop early if we reach a stop token in every seq + if stop_tokens is not None: + stop_token_reached = update_stop_tokens_tracker( + tokens, stop_tokens, stop_token_reached + ) + if stop_token_reached.all().item(): + return generated_tokens.tolist() + + curr_pos = prompt_length + # if key value caches are enabled, we can incrementally decode + incremental_decoding = model.caches_are_enabled() + for _ in range(max_generated_tokens - 1): + # update stop_token_mask if we reached a stop token in a previous step + # by appending the logical not of stop_token_reached to the end of the mask + # reshaped to be bsz first + if stop_tokens is not None: + stop_token_mask = torch.cat( + [stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1 + ) + + # if incremental decoding is enabled, we can use the current position + # otherwise, we take the whole sequence up to the current position + if incremental_decoding: + curr_input_pos = input_pos[curr_pos].unsqueeze(0) + else: + curr_input_pos = input_pos[: curr_pos + 1] + tokens = generated_tokens.clone() + + tokens = custom_generate_next_token( + model, + input_pos=curr_input_pos, + x=tokens, + temperature=temperature, + top_k=top_k, + ) + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + curr_pos += 1 + + if stop_tokens is not None: + stop_token_reached = update_stop_tokens_tracker( + tokens, stop_tokens, stop_token_reached + ) + if stop_token_reached.all().item(): + break + + # mask out generated tokens in seqs that already hit a stop token + if stop_tokens is not None: + generated_tokens = generated_tokens * stop_token_mask + + return generated_tokens.tolist() diff --git a/training/torchtune/utils/_profiler.py b/training/torchtune/utils/_profiler.py new file mode 100644 index 0000000..5d3b32f --- /dev/null +++ b/training/torchtune/utils/_profiler.py @@ -0,0 +1,374 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import time +from functools import partial +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch.distributed + +from omegaconf import DictConfig +from torch._C._profiler import _ExperimentalConfig +from torch.profiler import tensorboard_trace_handler +from torchtune.utils import get_world_size_and_rank + +from torchtune.utils.logging import get_logger + +log = get_logger("INFO") + +PROFILER_KEY = "profiler" +DEFAULT_PROFILER_ACTIVITIES = { + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, +} + +DEFAULT_SCHEDULE: dict = { + "wait_steps": 5, + "warmup_steps": 5, + "active_steps": 2, + "num_cycles": 1, +} + +DEFAULT_TRACE_OPTS: dict = { + "profile_memory": False, + "with_stack": False, + "record_shapes": True, + "with_flops": False, +} + +DEFAULT_PROFILE_DIR: str = "profiler_output" + + +def _warn(msg: str): + _, rank = get_world_size_and_rank() + if rank == 0: + log.warning(msg) + + +def trace_handler( + prof: torch.profiler.profile, + output_dir: str, + metric: str = "self_cuda_time_total", + row_limit: int = 25, +): + """ + Handles export of artifacts from ``torch.profiler.profile``. + + The following artifacts are exported: + - chrome / tensorboard trace - viewable through tensorboard or perfetto.dev / chrome::/tracing + - trace event table + - memory timeline if ``profile_memory`` + - stacks if ``with_stack`` (note that ``profile_memory`` requires ``with_stack`` to be ``True``), + viewable as a flamegraph see (https://pytorch.org/docs/stable/profiler.html#torch.profiler._KinetoProfile.export_stacks). + + Notes: + - Each profiling cycle is exported as a sub-directory in output_dir + - E.g., profiling in 5-step cycle (wait=2, warmup=2, active=1, repeat=0) will result in + sub-directories iteration_5, iteration_10, etc. + - If profiling in a distributed setting, each artifact will be prefixed with rank. + - Memory timeline is only exported for rank 0 (error if exporting from multiple ranks on single node) + + See profiler documentation (https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for more details + + Args: + prof (torch.profiler.profile): instance of torch profiler to use + output_dir (str): directory to store artifacts + metric (str): metric to order trace event table by, see ``torch.profiler.profile.key_averages().table`` for + row_limit (int): number of rows to display in trace event table + + """ + world_size, rank = get_world_size_and_rank() + curr_trace_dir_name = "iteration_" + str(prof.step_num) + curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) + if not os.path.exists(curr_trace_dir): + os.makedirs(curr_trace_dir, exist_ok=True) + + # Export chrome / tensorboard trace + if rank == 0: + log.info(f"Dumping traces at step {prof.step_num}") + begin = time.monotonic() + + # Use tensorboard trace handler rather than directly exporting chrome traces since + # tensorboard doesn't seem to be able to parse traces with prof.export_chrome_trace + exporter = tensorboard_trace_handler( + curr_trace_dir, worker_name=f"rank{rank}", use_gzip=True + ) + exporter(prof) + + if rank == 0: + log.info(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds") + + # Memory timeline sometimes fails to export + if prof.profile_memory: + if rank == 0: + try: + prof.export_memory_timeline( + f"{curr_trace_dir}/rank{rank}_memory-timeline.html" + ) + except Exception as e: + log.warn(f" Failed to export memory timeline: {e}") + + # Dump stack traces + if prof.with_stack: + prof.export_stacks(f"{curr_trace_dir}/rank{rank}_stacks.txt", metric=metric) + + # Export event averages + key_avgs = prof.key_averages( + group_by_input_shape=prof.record_shapes, group_by_stack_n=5 + ).table(sort_by=metric, row_limit=row_limit) + with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f: + print(key_avgs, file=f) + if rank == 0: + log.info(f"Saving profiling results to {curr_trace_dir}") + + # TODO: Is this necessary? + # see https://github.com/pytorch/torchtitan/blob/3050098dcee4901d88c712f9e8e9703d1735a29b/torchtitan/profiling.py#L48 + if world_size > 1: + torch.distributed.barrier() + + +class DummyProfiler: + """ + Drop-in replacement for torch.profiler.profile that functions as a nullcontext / object + with no-op methods for ``start``, ``stop``, and ``step``. + + This is helpful for instrumenting profiling in a recipe without requiring changes to the + code independent of whether profiling is on / off. + + E.g., + ``` + profiler = DummyProfiler() + #profiler = torch.profiler.profile() + + # Below is same regardless of profiler object type + with profiler as prof: + for epoch in epochs: + for batch in batches: + train.step() + prof.step() + + """ + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def start(self): + pass + + def stop(self): + pass + + def step(self): + pass + + +def setup_torch_profiler( + enabled: bool = False, + cpu: bool = True, + cuda: bool = True, + profile_memory: bool = DEFAULT_TRACE_OPTS["profile_memory"], + with_stack: bool = DEFAULT_TRACE_OPTS["with_stack"], + record_shapes: bool = DEFAULT_TRACE_OPTS["record_shapes"], + with_flops: bool = DEFAULT_TRACE_OPTS["with_flops"], + # `torch.profiler.schedule` args - note we defer setting these to enable more fine-grained + # warnings within this setup function + wait_steps: Optional[int] = None, + warmup_steps: Optional[int] = None, + active_steps: Optional[int] = None, + num_cycles: Optional[int] = None, + output_dir: Optional[str] = None, +) -> Tuple[torch.profiler.profile, DictConfig]: + """ + Sets up :class:`~torch.profiler.profile` and returns the profiler config with post-setup updates. + + The profiler config can be provided in configs under the ``profiler`` key with the following layout: + + .. code-block:: yaml + + profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: bool + # Output directory of trace artifacts + output_dir: str + + # torch.profiler.ProfilerActivity types to trace + cpu: bool + cuda: bool + + # Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # torch.profiler.schedule args + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + + The profiler schedule updates with respect to an optimizer step (e.g., if + ``gradient_accumulation = 2``, then the profiler will step every 2 batches). + + Sensible defaults will be chosen if the config is missing options: + + - If no activities are specified, profiler will default to CPU + CUDA + - If no schedule is specified, profiler will default to ``DEFAULT_SCHEDULE`` + - Certain options will be overridden (``with_stack`` and ``record_shapes``) \ + depending on requirements of other options (e.g., ``profile_memory`` requires \ + ``with_stack`` and ``record_shapes``). + + + Note: + - Enabling the profiler will result in training speed reduction. + - Setting ``profile_memory: True`` will generate large trace files. + - The profiler schedule is context dependent. Calling ``profiler.step()`` \ + at each batch iteration but **outside** the gradient accumulation scope will \ + ``step`` the profiler each forward / backward step. Calling ``profiler.step()`` \ + each batch iteration but **within** the gradient accumulation scope will ``step`` \ + the profiler each optimizer update step such that each ``step`` contains multiple \ + forward / backward passes. + + Args: + enabled (bool): Enable pytorch profiler. Default is False. + cpu (bool): Enable cpu profiling. Default is True. + cuda (bool): Enable cuda profiling. Default is True. + profile_memory (bool): Profile memory usage. Default is False. + with_stack (bool): Profile stack. Default is False. + record_shapes (bool): Record shapes. Default is True. + with_flops (bool): Profile flops. Default is False. + wait_steps (Optional[int]): Wait time in steps. Maps to ``wait`` kwarg of ``torch.profiler.schedule``. + warmup_steps (Optional[int]): Warmup time in steps. Maps to ``warmup`` kwarg of ``torch.profiler.schedule``. + active_steps (Optional[int]): Active time in steps. Maps to ``active`` kwarg of ``torch.profiler.schedule``. + num_cycles (Optional[int]): Number of profiling cycles. Maps to ``repeat`` kwarg of ``torch.profiler.schedule``. + output_dir (Optional[str]): Tracing file output path. + + Returns: + Tuple[torch.profiler.profile, DictConfig] + """ + + if not enabled: + _warn(" Profiling disabled.") + return DummyProfiler(), DictConfig({"enabled": False}) + + # Set up profiler activities + activities = [] + if cpu: + activities.append(torch.profiler.ProfilerActivity.CPU) + if cuda: + activities.append(torch.profiler.ProfilerActivity.CUDA) + if len(activities) == 0: + _warn("No activities specified, defaulting to CPU + CUDA") + activities = DEFAULT_PROFILER_ACTIVITIES + cpu = cuda = True + + # Check for schedule + # 1) If no schedule is provided, set to DEFAULT_SCHEDULE + # 2) else check for missing keys and warn if any are missing, setting these to defaults + # Note that this might result in code duplication if these checks are already done in the `recipe` + # However, we retain this checks in the case that the _setup_profiler section of the `recipe` does not implement these checks + + # Set up profiler schedule + use_default_schedule = not any( + [ + wait_steps is not None, + warmup_steps is not None, + active_steps is not None, + num_cycles is not None, + ] + ) + + # Use default schedule if None, else validate that schedule is valid and can be passed to `instantiate` + if use_default_schedule: + schedule_args = DEFAULT_SCHEDULE + _warn( + " No schedule found in config, defaulting to {}".format( + ", ".join(f"{k} = {schedule_args[k]}" for k in schedule_args.keys()) + ) + ) + else: + schedule_args = { + "wait_steps": wait_steps, + "warmup_steps": warmup_steps, + "active_steps": active_steps, + "num_cycles": num_cycles, + } + missing_keys = [k for k in schedule_args.keys() if schedule_args[k] is None] + if len(missing_keys) > 0: + for k in missing_keys: + schedule_args[k] = DEFAULT_SCHEDULE[k] + _warn( + " Missing keys in torch profiler schedule {}: defaulting to {}".format( + ", ".join(missing_keys), + ", ".join(f"{k} = {schedule_args[k]}" for k in missing_keys), + ) + ) + schedule = torch.profiler.schedule( + wait=schedule_args["wait_steps"], + warmup=schedule_args["warmup_steps"], + active=schedule_args["active_steps"], + repeat=schedule_args["num_cycles"], + ) + + # profile_memory requires with_stack and record_shapes, hence we override these if profile_memory is True + # See torch.profiler.profiler._memory_profile + if profile_memory: + _warn( + "`profile_memory` requires `with_stack` and `record_shapes`, these will be enabled since `profile_memory` is True" + ) + with_stack = with_stack or profile_memory + record_shapes = record_shapes or profile_memory + # experimental config is needed to export stacks: see https://github.com/pytorch/pytorch/issues/100253 + experimental_config = _ExperimentalConfig(verbose=True) if with_stack else None + + # Handle exporting of trace, memory timeline and other profiler artifacts + if output_dir is None: + _warn( + f" No output directory found in profiler config, defaulting to {DEFAULT_PROFILE_DIR}" + ) + output_dir = DEFAULT_PROFILE_DIR + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_dir = str(output_dir) + + # trace_handler manages the export of profiler artifacts + # this callback will be triggered after **each** profiling cycle + callback = partial(trace_handler, output_dir=output_dir) + + profiler = torch.profiler.profile( + activities=activities, + profile_memory=profile_memory, + with_stack=with_stack, + record_shapes=record_shapes, + with_flops=with_flops, + schedule=schedule, + experimental_config=experimental_config, + on_trace_ready=callback, + ) + + profiler_cfg = DictConfig( + { + "enabled": enabled, + "output_dir": output_dir, + "cpu": cpu, + "cuda": cuda, + "profile_memory": profile_memory, + "with_stack": with_stack, + "record_shapes": record_shapes, + "with_flops": with_flops, + **schedule_args, + } + ) + + return (profiler, profiler_cfg) diff --git a/training/torchtune/utils/_version.py b/training/torchtune/utils/_version.py new file mode 100644 index 0000000..830a8ba --- /dev/null +++ b/training/torchtune/utils/_version.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + + +def torch_version_ge(version: str) -> bool: + """ + Check if torch version is greater than or equal to the given version. + + Args: + version (str): The torch version to compare against + + Returns: + bool: True if torch version is greater than or equal to the given version. + + Example: + >>> print(torch.__version__) + 2.4.0 + >>> torch_version_ge("2.0") + True + """ + return version in torch.__version__ or torch.__version__ >= version diff --git a/training/torchtune/utils/activations.py b/training/torchtune/utils/activations.py new file mode 100644 index 0000000..37aab68 --- /dev/null +++ b/training/torchtune/utils/activations.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Union + +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, + CheckpointImpl, +) +from torch.utils.checkpoint import checkpoint + + +# Uses PTD FSDP AC wrapper +# currently selective per layer checkpointing are supported +def checkpoint_wrapper(module, ac_mode, ac_style): + + if ac_mode == "full": + return ptd_checkpoint_wrapper( + module, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + checkpoint_fn=checkpoint, + use_reentrant=False, + preserve_rng_state=False, + ) + + # selective layer checkpointing...some checks in case we receive '2' or 2... + elif ac_mode == "selective": + """enables selective checkpointing of candidate layers. + Usage: + 'selective_ac_option' with a positive 'int' value in config controls which layers to checkpoint. + 1 == checkpointing every one (all). + 2 == checkpoint every 2nd one + """ + every_x_layer = int(ac_style) + + if not (every_x_layer >= 0): + raise ValueError( + f"Selective layer AC policy (every_x_layer) expects a positive integer, received {every_x_layer}" + ) + + checkpoint_wrapper.__dict__.setdefault("_count", 0) + + checkpoint_wrapper._count += 1 + if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0: + return ptd_checkpoint_wrapper( + module, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + checkpoint_fn=checkpoint, + use_reentrant=False, + preserve_rng_state=False, + ) + # skip activation checkpointing and store activations for this layer + else: + return module + + else: + raise NotImplementedError( + "Unknown AC type or AC config. Only selective op and selective layer ac implemented currently." + ) + + +def apply_selective_activation_checkpointing( + model: nn.Module, + ac_mode: str, + ac_option: Optional[Union[int, str]], +) -> None: + """Utility to setup activation checkpointing and wrap the model for checkpointing. + + Args: + model (nn.Module): Model to setup activation checkpointing. + ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective'] + ac_option (Optional[Union[int, str]]): Activation checkpointing option. + - If ac_mode is 'selective', ac_option can be an integer or a string + representing the number of layers to checkpoint. + - If ac_mode is 'selective' and ac_option is 'op', then selective op ac is run. + - If ac_mode is 'none' or 'full, ac_option is ignored. + """ + + for layer_id, transformer_block in enumerate(model.layers): + if ac_mode in ("full", "selective"): + + transformer_block = checkpoint_wrapper( + transformer_block, + ac_mode, + ac_option, + ) + model.layers[layer_id] = transformer_block diff --git a/training/torchtune/utils/argparse.py b/training/torchtune/utils/argparse.py new file mode 100644 index 0000000..466832d --- /dev/null +++ b/training/torchtune/utils/argparse.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from argparse import Namespace +from typing import List, Tuple + +from omegaconf import OmegaConf + + +class TuneRecipeArgumentParser(argparse.ArgumentParser): + """ + A helpful utility subclass of the ``argparse.ArgumentParser`` that + adds a builtin argument "config". The config argument takes a file path to a YAML file + and loads in argument defaults from said file. The YAML file must only contain + argument names and their values and nothing more, it does not have to include all of the + arguments. These values will be treated as defaults and can still be overridden from the + command line. Everything else works the same as the base ArgumentParser and you should + consult the docs for more info: https://docs.python.org/3/library/argparse.html. + + Note: + This class uses "config" as a builtin argument so it is not available to use. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + super().add_argument( + "--config", + type=str, + help="Path/name of a yaml file with recipe args", + required=True, + ) + + def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]: + """This acts the same as the base parse_known_args but will first load in defaults from + from the config yaml file if it is provided. The command line args will always take + precident over the values in the config file. All other parsing method, such as parse_args, + internally call this method so they will inherit this property too. For more info see + the docs for the base method: https://docs.python.org/3/library/argparse.html#the-parse-args-method. + """ + namespace, unknown_args = super().parse_known_args(*args, **kwargs) + + unknown_flag_args = [arg for arg in unknown_args if arg.startswith("--")] + if unknown_flag_args: + raise ValueError( + f"Additional flag arguments not supported: {unknown_flag_args}. Please use --config or key=value overrides" + ) + + config = OmegaConf.load(namespace.config) + assert "config" not in config, "Cannot use 'config' within a config file" + self.set_defaults(**config) + + namespace, unknown_args = super().parse_known_args(*args, **kwargs) + del namespace.config + + return namespace, unknown_args diff --git a/training/torchtune/utils/collate.py b/training/torchtune/utils/collate.py new file mode 100644 index 0000000..e0b779f --- /dev/null +++ b/training/torchtune/utils/collate.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, List + +import torch + +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX + + +def padded_collate( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, + ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX, +) -> Dict[str, torch.Tensor]: + """Pad a batch of sequences to the longest sequence length in the batch, and + convert integer lists to tensors. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs. + padding_idx (int): Padding index for input ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + + Returns: + Dict[str, torch.Tensor]: Collated input and label tensors. + + Example: + >>> token_pairs = [ + >>> {"tokens": [1, 2, 3], "labels": [4, 5, 6]}, + >>> {"tokens": [7,], "labels": [10,]}, + >>> ] + >>> collated = padded_collate( + >>> batch=token_pairs, + >>> padding_idx=padding_idx, + >>> ignore_idx=ignore_idx, + >>> ) + >>> collated["tokens"] + >>> tensor([[1, 2, 3], [7, 0, 0]]) + >>> collated["labels"] + >>> tensor([[4, 5, 6], [10, -100, -100]]) + """ + input_ids = pad_sequence( + [torch.tensor(x["tokens"]) for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + labels = pad_sequence( + [torch.tensor(x["labels"]) for x in batch], + batch_first=True, + padding_value=ignore_idx, + ) + + input_ids_seq_len = input_ids.shape[-1] + labels_seq_len = labels.shape[-1] + + # Hack to pad correctly and not use max_seq_len, which is costly + if input_ids_seq_len > labels_seq_len: + labels = F.pad( + labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx + ) + elif labels_seq_len > input_ids_seq_len: + input_ids = F.pad( + input_ids, + (0, labels_seq_len - input_ids_seq_len), + value=padding_idx, + ) + return {"tokens": input_ids.long(), "labels": labels.long()} diff --git a/training/torchtune/utils/constants.py b/training/torchtune/utils/constants.py new file mode 100644 index 0000000..30b4a7a --- /dev/null +++ b/training/torchtune/utils/constants.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +""" +Keys used during checkpoint load and checkpoint save. +""" + +# adapter config containing info about LoRA modules, rank, alpha +ADAPTER_CONFIG = "adapter_config" +# key used for adapter weights such as LoRA weights +ADAPTER_KEY = "adapter" +# number of epochs completed thus far +EPOCHS_KEY = "epochs_run" +MAX_STEPS_KEY = "max_steps_per_epoch" +MODEL_KEY = "model" +OPT_KEY = "optimizer" +SEED_KEY = "seed" +# total number of epochs for training; resumed training runs for +# (total_epochs - epochs_run) number of epochs +TOTAL_EPOCHS_KEY = "total_epochs" +# number of steps completed thus far - for PPO +STEPS_KEY = "steps_run" +# rng state for ensuring correct training resuming in PPO +RNG_KEY = "rng_state" diff --git a/training/torchtune/utils/logging.py b/training/torchtune/utils/logging.py new file mode 100644 index 0000000..4a5b946 --- /dev/null +++ b/training/torchtune/utils/logging.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + + +def get_logger(level: Optional[str] = None) -> logging.Logger: + """ + Get a logger with a stream handler. + + Args: + level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels. + + Returns: + logging.Logger: The logger. + """ + logger = logging.getLogger(__name__) + if not logger.hasHandlers(): + logger.addHandler(logging.StreamHandler()) + if level is not None: + level = getattr(logging, level.upper()) + logger.setLevel(level) + return logger diff --git a/training/torchtune/utils/memory.py b/training/torchtune/utils/memory.py new file mode 100644 index 0000000..41ad9de --- /dev/null +++ b/training/torchtune/utils/memory.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import gc +import logging + +from typing import Any, Callable, Dict, Set, Type, Union + +import torch + +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torchtune.utils.logging import get_logger + +_log: logging.Logger = get_logger() + +ACWrapPolicyType: Type = Union[Set[Type], Callable[[nn.Module, bool, int], bool]] + + +def set_activation_checkpointing( + model: nn.Module, auto_wrap_policy: ACWrapPolicyType, **kwargs +) -> None: + """Utility to apply activation checkpointing to the passed-in model. + + Args: + model (nn.Module): Model to apply activation checkpointing to. + auto_wrap_policy (ACWrapPolicyType): Policy to wrap module. + This can either be a set of ``nn.Module`` types, in which case, modules of the specified type(s) + will be wrapped individually with activation checkpointing, or a ``callable`` policy describing + how to wrap the model with activation checkpointing. For more information on authoring custom + policies, please see this tutorial: + https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html#transformer-wrapping-policy. + **kwargs: additional arguments to pass to ``torch.distributed`` activation checkpointing. + """ + if isinstance(auto_wrap_policy, set): + auto_wrap_policy = ModuleWrapPolicy(auto_wrap_policy) + apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs) + + +def cleanup_before_training() -> None: + """ + Call gc collect, empty CUDA cache, and reset peak memory stats. + """ + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + +class OptimizerInBackwardWrapper: + """ + A bare-bones class meant for checkpoint save and load for optimizers running + in backward. Usage is limited to the following: + + Note: + This wrapper is only meant to be used for single-device use cases. + Distributed use cases such as FSDP, which require specialized optimizer state checkpointing, are not supported. + + Args: + optim_map (Dict[str, torch.optim.Optimizer]): Mapping from parameter names to optimizers. + + Example: + >>> optim_dict = { + >>> p: config.instantiate(cfg_optimizer, [p]) + >>> for p in self._model.parameters() + >>> } + >>> + >>> # Save checkpoint + >>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict() + >>> torch.save("/tmp/optim_ckpt", ckpt) + >>> + >>> # Load checkpoint + >>> placeholder_optim_dict = { + >>> p: config.instantiate(cfg_optimizer, [p]) + >>> for p in self._model.parameters() + >>> } + >>> + >>> wrapper = OptimInBackwardWrapper(placeholder_optim_dict) + >>> + >>> # load_state_dict expects a dict produced by this class's + >>> # state_dict method. + >>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt")) + >>> # placeholder_optim_dict now has updated optimizer states. + + """ + + def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]): + self.optim_map = optim_map + + def state_dict(self) -> Dict[str, Any]: + """ + Returns a state dict mapping parameter names to optimizer states. This + state_dict is only loadable by this same class. + + Returns: + Dict[str, Any]: state dict mapping parameter names to optimizer states. + """ + return {p: opt.state_dict() for p, opt in self.optim_map.items()} + + def load_state_dict(self, optim_ckpt_map: Dict[str, Any]): + """ + Load optimizer states from a state dict produced by this class's + state_dict method. + + Args: + optim_ckpt_map (Dict[str, Any]): state dict mapping parameter names to optimizer states. + + Raises: + RuntimeError: If the optimizer state dict does not contain all the expected parameters. + """ + params_covered = set() + for param_name in optim_ckpt_map.keys(): + if param_name not in self.optim_map: + raise RuntimeError( + f"Trying to load optimizer state for unexpected param {param_name}" + ) + self.optim_map[param_name].load_state_dict(optim_ckpt_map[param_name]) + params_covered.add(param_name) + # Ensure all params have been loaded into, report missing params + missing_params = set(self.optim_map.keys()) - params_covered + if missing_params: + raise RuntimeError( + f"Expected to load optimizer state for params {missing_params}!" + ) + + def get_optim_key(self, key: str) -> Any: + """ + Returns value of key from an arbitrary optimizer running in backward. Note that + this assumes all optimizer in backwards have the same value for the key, i.e., + are initialized with the same hyperparameters. + """ + return list(self.optim_map.values())[0].param_groups[0][key] + + +def create_optim_in_bwd_wrapper( + model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer] +) -> OptimizerInBackwardWrapper: + """ + Create a wrapper for optimizer step running in backward. + + Args: + model (torch.nn.Module): Model that contains parameters that are being optimized. For now, + it is assumed that all parameters being optimized belong to a single top-level model. + ``named_parameters`` attribute of ``model`` will be accessed to look up parameter names for + parameters being optimized. + optim_dict (Dict[torch.nn.Parameter, torch.optim.Optimizer]): Mapping from + parameters to optimizers. + + Returns: + ``OptimizerInBackwardWrapper``: Wrapper for optimizer states running in backward. + """ + return OptimizerInBackwardWrapper( + {n: optim_dict[p] for n, p in model.named_parameters()} + ) + + +def register_optim_in_bwd_hooks( + model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer] +) -> None: + """ + Register hooks for optimizer step running in backward. + + When fusing the optimizer step into backward, we need to call ``.step()`` on the optimizer + for a given parameter as soon as its gradient is ready. This utility registers post-accumulate-grad + hooks on all parameters in the model to achieve this. + + Args: + model (torch.nn.Module): Model whose parameters will be optimized. Note that currently + hooks for ALL parameters in the model will be registered. + optim_dict (Dict[torch.nn.Parameter, torch.optim.Optimizer]): Mapping from + parameters to optimizers. + """ + + def optim_step(param) -> None: + optim_dict[param].step() + optim_dict[param].zero_grad() + + for p in model.parameters(): + p.register_post_accumulate_grad_hook(optim_step) + + +def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict: + """ + Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will + also reset CUDA's peak memory tracking. This is useful to get data around relative use of peak + memory (e.g. peak memory during model init, during forward, etc) and optimize memory for + individual sections of training. + + Args: + device (torch.device): Device to get memory summary for. Only CUDA devices are supported. + reset_stats (bool): Whether to reset CUDA's peak memory tracking. + + Returns: + Dict[str, float]: A dictionary containing the peak memory active, peak memory allocated, + and peak memory reserved. This dict is useful for logging memory stats. + + Raises: + ValueError: If the passed-in device is not CUDA. + """ + if device.type != "cuda": + raise ValueError( + f"Logging memory stats is only supported on CUDA devices, got {device}" + ) + + peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9 + peak_mem_alloc = torch.cuda.max_memory_allocated(device) / 1e9 + peak_mem_reserved = torch.cuda.max_memory_reserved(device) / 1e9 + + if reset_stats: + torch.cuda.reset_peak_memory_stats(device) + + memory_stats = { + "peak_memory_active": peak_memory_active, + "peak_memory_alloc": peak_mem_alloc, + "peak_memory_reserved": peak_mem_reserved, + } + return memory_stats + + +def log_memory_stats(stats: Dict[str, float]) -> None: + """ + Logs a dict containing memory stats to the logger. ``stats`` should contain the fields + ``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` as + returned by :func:`torchtune.utils.get_memory_stats`. + + Args: + stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory + allocated, and peak memory reserved stats. + """ + _log.info( + "Memory stats after model init:" + f"\n\tGPU peak memory allocation: {stats['peak_memory_alloc']:.2f} GB" + f"\n\tGPU peak memory reserved: {stats['peak_memory_reserved']:.2f} GB" + f"\n\tGPU peak memory active: {stats['peak_memory_active']:.2f} GB" + ) diff --git a/training/torchtune/utils/metric_logging.py b/training/torchtune/utils/metric_logging.py new file mode 100644 index 0000000..59c9f83 --- /dev/null +++ b/training/torchtune/utils/metric_logging.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import sys +import time +from pathlib import Path + +from typing import Mapping, Optional, Union + +from numpy import ndarray +from omegaconf import DictConfig, OmegaConf +from torch import Tensor + +from torchtune.utils import get_logger +from torchtune.utils._distributed import get_world_size_and_rank +from typing_extensions import Protocol + +Scalar = Union[Tensor, ndarray, int, float] + +log = get_logger("DEBUG") + + +class MetricLoggerInterface(Protocol): + """Abstract metric logger.""" + + def log( + self, + name: str, + data: Scalar, + step: int, + ) -> None: + """Log scalar data. + + Args: + name (str): tag name used to group scalars + data (Scalar): scalar data to log + step (int): step value to record + """ + pass + + def log_config(self, config: DictConfig) -> None: + """Logs the config + + Args: + config (DictConfig): config to log + """ + pass + + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + """Log multiple scalar values. + + Args: + payload (Mapping[str, Scalar]): dictionary of tag name and scalar value + step (int): step value to record + """ + pass + + def close(self) -> None: + """ + Close log resource, flushing if necessary. + Logs should not be written after `close` is called. + """ + pass + + +class DiskLogger(MetricLoggerInterface): + """Logger to disk. + + Args: + log_dir (str): directory to store logs + filename (Optional[str]): optional filename to write logs to. + Default: None, in which case log_{unixtimestamp}.txt will be used. + **kwargs: additional arguments + + Warning: + This logger is not thread-safe. + + Note: + This logger creates a new file based on the current time. + """ + + def __init__(self, log_dir: str, filename: Optional[str] = None, **kwargs): + self.log_dir = Path(log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + if not filename: + unix_timestamp = int(time.time()) + filename = f"log_{unix_timestamp}.txt" + self._file_name = self.log_dir / filename + self._file = open(self._file_name, "a") + print(f"Writing logs to {self._file_name}") + + def path_to_log_file(self) -> Path: + return self._file_name + + def log(self, name: str, data: Scalar, step: int) -> None: + self._file.write(f"Step {step} | {name}:{data}\n") + self._file.flush() + + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + self._file.write(f"Step {step} | ") + for name, data in payload.items(): + self._file.write(f"{name}:{data} ") + self._file.write("\n") + self._file.flush() + + def __del__(self) -> None: + self._file.close() + + def close(self) -> None: + self._file.close() + + +class StdoutLogger(MetricLoggerInterface): + """Logger to standard output.""" + + def log(self, name: str, data: Scalar, step: int) -> None: + print(f"Step {step} | {name}:{data}") + + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + print(f"Step {step} | ", end="") + for name, data in payload.items(): + print(f"{name}:{data} ", end="") + print("\n", end="") + + def __del__(self) -> None: + sys.stdout.flush() + + def close(self) -> None: + sys.stdout.flush() + + +class WandBLogger(MetricLoggerInterface): + """Logger for use w/ Weights and Biases application (https://wandb.ai/). + For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init. + + Args: + project (str): WandB project name. Default is `torchtune`. + entity (Optional[str]): WandB entity name. If you don't specify an entity, + the run will be sent to your default entity, which is usually your username. + group (Optional[str]): WandB group name for grouping runs together. If you don't + specify a group, the run will be logged as an individual experiment. + log_dir (Optional[str]): WandB log directory. If not specified, use the `dir` + argument provided in kwargs. Else, use root directory. + **kwargs: additional arguments to pass to wandb.init + + Example: + >>> from torchtune.utils.metric_logging import WandBLogger + >>> logger = WandBLogger(project="my_project", entity="my_entity", group="my_group") + >>> logger.log("my_metric", 1.0, 1) + >>> logger.log_dict({"my_metric": 1.0}, 1) + >>> logger.close() + + Raises: + ImportError: If ``wandb`` package is not installed. + + Note: + This logger requires the wandb package to be installed. + You can install it with `pip install wandb`. + In order to use the logger, you need to login to your WandB account. + You can do this by running `wandb login` in your terminal. + """ + + def __init__( + self, + project: str = "torchtune", + entity: Optional[str] = None, + group: Optional[str] = None, + log_dir: Optional[str] = None, + **kwargs, + ): + try: + import wandb + except ImportError as e: + raise ImportError( + "``wandb`` package not found. Please install wandb using `pip install wandb` to use WandBLogger." + "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'." + ) from e + self._wandb = wandb + + # Use dir if specified, otherwise use log_dir. + self.log_dir = kwargs.pop("dir", log_dir) + + _, self.rank = get_world_size_and_rank() + + if self._wandb.run is None and self.rank == 0: + # we check if wandb.init got called externally, + run = self._wandb.init( + project=project, + entity=entity, + group=group, + dir=self.log_dir, + **kwargs, + ) + + if self._wandb.run: + self._wandb.run._label(repo="torchtune") + + # define default x-axis (for latest wandb versions) + if getattr(self._wandb, "define_metric", None): + self._wandb.define_metric("global_step") + self._wandb.define_metric("*", step_metric="global_step", step_sync=True) + + self.config_allow_val_change = kwargs.get("allow_val_change", False) + + def log_config(self, config: DictConfig) -> None: + """Saves the config locally and also logs the config to W&B. The config is + stored in the same directory as the checkpoint. You can + see an example of the logged config to W&B in the following link: + https://wandb.ai/capecape/torchtune/runs/6053ofw0/files/torchtune_config_j67sb73v.yaml + + Args: + config (DictConfig): config to log + """ + if self._wandb.run: + resolved = OmegaConf.to_container(config, resolve=True) + self._wandb.config.update( + resolved, allow_val_change=self.config_allow_val_change + ) + try: + output_config_fname = Path( + os.path.join( + config.checkpointer.checkpoint_dir, + "torchtune_config.yaml", + ) + ) + OmegaConf.save(config, output_config_fname) + + log.info(f"Logging {output_config_fname} to W&B under Files") + self._wandb.save( + output_config_fname, base_path=output_config_fname.parent + ) + + except Exception as e: + log.warning( + f"Error saving {output_config_fname} to W&B.\nError: \n{e}." + "Don't worry the config will be logged the W&B workspace" + ) + + def log(self, name: str, data: Scalar, step: int) -> None: + if self._wandb.run: + self._wandb.log({name: data, "global_step": step}) + + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + if self._wandb.run: + self._wandb.log({**payload, "global_step": step}) + + def __del__(self) -> None: + if self._wandb.run: + self._wandb.finish() + + def close(self) -> None: + if self._wandb.run: + self._wandb.finish() + + +class TensorBoardLogger(MetricLoggerInterface): + """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html). + + Args: + log_dir (str): TensorBoard log directory + organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current + run. Having sub-directories allows you to compare logs across runs. When TensorBoard is + passed a logdir at startup, it recursively walks the directory tree rooted at logdir looking for + subdirectories that contain tfevents data. Every time it encounters such a subdirectory, + it loads it as a new run, and the frontend will organize the data accordingly. + Recommended value is `True`. Run `tensorboard --logdir my_log_dir` to view the logs. + **kwargs: additional arguments + + Example: + >>> from torchtune.utils.metric_logging import TensorBoardLogger + >>> logger = TensorBoardLogger(log_dir="my_log_dir") + >>> logger.log("my_metric", 1.0, 1) + >>> logger.log_dict({"my_metric": 1.0}, 1) + >>> logger.close() + + Note: + This utility requires the tensorboard package to be installed. + You can install it with `pip install tensorboard`. + In order to view TensorBoard logs, you need to run `tensorboard --logdir my_log_dir` in your terminal. + """ + + def __init__(self, log_dir: str, organize_logs: bool = True, **kwargs): + from torch.utils.tensorboard import SummaryWriter + + self._writer: Optional[SummaryWriter] = None + _, self._rank = get_world_size_and_rank() + + # In case organize_logs is `True`, update log_dir to include a subdirectory for the + # current run + self.log_dir = ( + os.path.join(log_dir, f"run_{self._rank}_{time.time()}") + if organize_logs + else log_dir + ) + + # Initialize the log writer only if we're on rank 0. + if self._rank == 0: + self._writer = SummaryWriter(log_dir=self.log_dir) + + def log(self, name: str, data: Scalar, step: int) -> None: + if self._writer: + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + for name, data in payload.items(): + self.log(name, data, step) + + def __del__(self) -> None: + if self._writer: + self._writer.close() + self._writer = None + + def close(self) -> None: + if self._writer: + self._writer.close() + self._writer = None diff --git a/training/torchtune/utils/pooling.py b/training/torchtune/utils/pooling.py new file mode 100644 index 0000000..ffbf5fa --- /dev/null +++ b/training/torchtune/utils/pooling.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + + +def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor: + """ + Returns the sequence lengths for each batch element, excluding masked tokens. + Args: + mask (torch.Tensor): Boolean mask with shape [b x s], where True indicates a value to be masked out + - this is usually a mask for padding tokens, where True indicates a padding token + + Returns: + Tensor: Sequence indexes logits with shape [b] + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + + Example: + >>> input_ids = torch.tensor([ + >>> [2, 4, 0, 0], + >>> [2, 4, 6, 0], + >>> [2, 4, 6, 9] + >>> ]) + >>> get_last_unmasked_token_idx(input_ids == 0) + >>> tensor([1, 2, 3]) + """ + # calculate per-batch-element sequence lengths by finding last valid tokens + if mask.any(): + sequence_lengths = ( + (~mask).sum(-1).sub(1).clip(0).to(mask.device, dtype=torch.long) + ) + else: + sequence_lengths = torch.full( + (mask.shape[0],), mask.shape[1] - 1, dtype=torch.long, device=mask.device + ) + + return sequence_lengths diff --git a/training/torchtune/utils/precision.py b/training/torchtune/utils/precision.py new file mode 100644 index 0000000..244565c --- /dev/null +++ b/training/torchtune/utils/precision.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Dict, Generator, Iterable, Optional, Tuple + +import torch + +from torchtune.utils.logging import get_logger + +log = get_logger() + +PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + "fp64": torch.float64, +} + + +def _set_float32_precision(precision: str = "high") -> None: + """Sets the precision of float32 matrix multiplications and convolution operations. + + For more information, see the PyTorch docs: + - https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html + - https://pytorch.org/docs/stable/backends.html#torch.backends.cudnn.allow_tf32 + + Args: + precision (str): The setting to determine which datatypes to use for matrix multiplication and convolution operations. + """ + if not torch.cuda.is_available(): # Not relevant for non-CUDA devices + return + # set precision for matrix multiplications + torch.set_float32_matmul_precision(precision) + # set precision for convolution operations + if precision == "highest": + torch.backends.cudnn.allow_tf32 = False + else: + torch.backends.cudnn.allow_tf32 = True + + +def verify_bf16_support() -> bool: + """ + Check that bf16 is available on this hardware. Requirements: + - CUDA is available and supports bf16 + - CUDA version >= 11 + - CUDA compute capability >= 8 + - NCCL is available and version >= 2.10 + + Returns: + bool: True if bf16 is available, False otherwise. + + """ + return ( + torch.cuda.is_available() + and torch.cuda.is_bf16_supported() + and torch.distributed.is_nccl_available() + and torch.cuda.nccl.version() >= (2, 10) + ) + + +def get_dtype( + dtype: Optional[str] = None, device: Optional[torch.device] = None +) -> torch.dtype: + """Get the torch.dtype corresponding to the given precision string. If no string is passed, + we will default to torch.float32. + + Note: + If bf16 precision is requested with a CUDA device, we verify whether the device indeed supports + bf16 kernels. If not, a ``RuntimeError`` is raised. + + Args: + dtype (Optional[str]): The precision dtype. Default: ``None``, in which we default to torch.float32 + device (Optional[torch.device]): Device in use for training. Only CUDA and CPU + devices are supported. If a CUDA device is passed in, additional checking is done + to ensure that the device supports the requested precision. Default: ``None``, in which case + a CUDA device is assumed. + Raises: + ValueError: if precision isn't supported by the library + RuntimeError: if bf16 precision is requested but not available on this hardware. + + Returns: + torch.dtype: The corresponding torch.dtype. + + """ + + # None defaults to float32 + if dtype is None: + return torch.float32 + + # Convert to torch.dtype + torch_dtype = PRECISION_STR_TO_DTYPE.get(dtype, dtype) + + # dtype must be one of the supported precisions + if torch_dtype not in PRECISION_STR_TO_DTYPE.values(): + raise ValueError( + f"Dtype {torch_dtype} must be one of {', '.join(list(PRECISION_STR_TO_DTYPE.keys()))} for finetuning." + ) + + if ( + torch_dtype == torch.bfloat16 + and device != torch.device("cpu") + and not verify_bf16_support() + ): + raise RuntimeError( + "bf16 precision was requested but not available on this hardware. Please use fp32 precision instead." + ) + + return torch_dtype + + +@contextlib.contextmanager +def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]: + """ + Context manager to set torch's default dtype. + + Args: + dtype (torch.dtype): The desired default dtype inside the context manager. + + Returns: + ContextManager: context manager for setting default dtype. + + Example: + >>> with set_default_dtype(torch.bfloat16): + >>> x = torch.tensor([1, 2, 3]) + >>> x.dtype + torch.bfloat16 + + + """ + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(old_dtype) + + +def validate_expected_param_dtype( + named_params: Iterable[Tuple[str, torch.nn.Parameter]], dtype: torch.dtype +) -> None: + """ + Validates that all input parameters have the expected dtype. + + Args: + named_params (Iterable[Tuple[str, torch.nn.Parameter]]): Iterable of named parameters. + dtype (torch.dtype): Expected dtype. + + Raises: + ValueError: If any parameter has a different dtype than `dtype`. + """ + for name, param in named_params: + if param.dtype != dtype: + raise ValueError( + f"Parameter {name} has dtype {param.dtype}, but expected {dtype}" + ) diff --git a/training/torchtune/utils/quantization.py b/training/torchtune/utils/quantization.py new file mode 100644 index 0000000..7158332 --- /dev/null +++ b/training/torchtune/utils/quantization.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +# importing TORCH_VERSION_AFTER_2_3 because `Int8DynActInt4WeightQuantizer` +# is only available after 2.3 so we have to guard the pytorch versions to decide +# the list of supported quantizers +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 + +__all__ = [ + "get_quantizer_mode", +] + + +_quantizer_to_mode = {} +_quantizer_mode_to_disable_fake_quant = {} +_quantizer_mode_to_enable_fake_quant = {} + + +if TORCH_VERSION_AFTER_2_3: + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + + __all__.append("Int8DynActInt4WeightQuantizer") + _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w" + + +if TORCH_VERSION_AFTER_2_4: + from torchao.quantization.prototype.qat import ( + disable_8da4w_fake_quant, + enable_8da4w_fake_quant, + Int8DynActInt4WeightQATQuantizer, + ) + + __all__.append("Int8DynActInt4WeightQATQuantizer") + _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat" + _quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant + _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant + + +def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]: + """Given a quantizer object, returns a string that specifies the type of quantization. + + For example, in the case of int4 weight only quantization, we'll return "4w". + If the quantizer is not recognized as a known quantizer, we'll return None. + + Currently supported: + + - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``) + - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``) + + Args: + quantizer (Optional[Callable]): A callable object that implements the `quantize` method. + + Returns: + Optional[str]: The quantization mode. + """ + return _quantizer_to_mode.get(type(quantizer), None) + + +def _get_disable_fake_quant(quantizer_mode: str) -> Callable: + """Given a quantizer mode, return the corresponding function for disabling fake + quantize in a model prepared by the quantizer. + If the quantizer is not recognized as a known QAT quantizer, return None. + """ + return _quantizer_mode_to_disable_fake_quant.get(quantizer_mode, None) + + +def _get_enable_fake_quant(quantizer_mode: str) -> Callable: + """Given a quantizer mode, return the corresponding function for enabling fake + quantize in a model prepared by the quantizer. + If the quantizer is not recognized as a known QAT quantizer, return None. + """ + return _quantizer_mode_to_enable_fake_quant.get(quantizer_mode, None) diff --git a/training/torchtune/utils/seed.py b/training/torchtune/utils/seed.py new file mode 100644 index 0000000..533dbfe --- /dev/null +++ b/training/torchtune/utils/seed.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import random +from typing import Optional, Union + +import numpy as np +import torch + +from torchtune.utils._distributed import _broadcast_tensor, get_world_size_and_rank +from torchtune.utils.logging import get_logger + +_log: logging.Logger = get_logger() + + +def set_seed( + seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None +) -> int: + """Function that sets seed for pseudo-random number generators across commonly used libraries. + + This seeds PyTorch, NumPy, and the python.random module. For distributed jobs, each local process + sets its own seed, computed seed + rank. + For more details, see https://pytorch.org/docs/stable/notes/randomness.html. + + Args: + seed (Optional[int]): the integer value seed. If `None`, a random seed will be generated and set. + debug_mode (Optional[Union[str, int]]): Controls debug_mode settings for deterministic operations within PyTorch. + + * If `None`, don't set any PyTorch global values. + * If "default" or 0, don't error or warn on nondeterministic operations and additionally enable PyTorch CuDNN benchmark. + * If "warn" or 1, warn on nondeterministic operations and disable PyTorch CuDNN benchmark. + * If "error" or 2, error on nondeterministic operations and disable PyTorch CuDNN benchmark. + * For more details, see :func:`torch.set_deterministic_debug_mode` and + https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms. + + Returns: + int: the current seed + + Raises: + ValueError: If the input seed value is outside the required range. + """ + world_size, rank = get_world_size_and_rank() + max_val = np.iinfo(np.uint32).max - world_size + 1 + min_val = np.iinfo(np.uint32).min + if seed is None: + rand_seed = torch.randint(min_val, max_val, (1,)) + seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks + if seed < min_val or seed > max_val: + raise ValueError( + f"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]" + ) + local_seed = seed + rank + if rank == 0: + _log.debug( + f"Setting manual seed to local seed {local_seed}. Local seed is seed + rank = {seed} + {rank}" + ) + + torch.manual_seed(local_seed) + np.random.seed(local_seed) + random.seed(local_seed) + + if debug_mode is not None: + _log.debug(f"Setting deterministic debug mode to {debug_mode}") + torch.set_deterministic_debug_mode(debug_mode) + deterministic_debug_mode = torch.get_deterministic_debug_mode() + if deterministic_debug_mode == 0: + _log.debug("Disabling cuDNN deterministic mode") + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + else: + _log.debug("Enabling cuDNN deterministic mode") + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + return seed diff --git a/training/version.txt b/training/version.txt new file mode 100644 index 0000000..0c62199 --- /dev/null +++ b/training/version.txt @@ -0,0 +1 @@ +0.2.1