diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bbdd4259cd..a586a31068 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,6 +2,10 @@ # This includes setup.py, the README, and the CODEOWNERS file itself! /* @mosaicml/composer-team-admins +# Require team approval for code changes +/llmfoundry/ @mosaicml/composer-team-eng +/scripts/ @mosaicml/composer-team-eng + # Require admin approval to change the CI build configuration # All CI Changes should be reviewed for security /.ci/ @mosaicml/composer-team-admins diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..468099c849 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: +- package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" + open-pull-requests-limit: 5 diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 54500b674c..062aa41bf4 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -19,7 +19,7 @@ defaults: working-directory: . jobs: code-quality: - runs-on: ubuntu-20.04 + runs-on: linux-ubuntu-latest timeout-minutes: 30 strategy: matrix: @@ -34,7 +34,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.0.5 + ref: v0.0.9 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml deleted file mode 100644 index 3d704a29de..0000000000 --- a/.github/workflows/codeql-analysis.yml +++ /dev/null @@ -1,50 +0,0 @@ -# For most projects, this workflow file will not need changing; you simply need -# to commit it to your repository. -# -# You may wish to alter this file to override the set of languages analyzed, -# or to provide custom queries or build logic. -# -# ******** NOTE ******** -# We have attempted to detect the languages in your repository. Please check -# the `language` matrix defined below to confirm you have the correct set of -# supported CodeQL languages. 
-# -name: "CodeQL" - -on: - push: - branches: [main] - pull_request: - # The branches below must be a subset of the branches above - branches: [main] - schedule: - - cron: "0 9 * * 1" # Every Monday at 09:00 (9:00 AM) - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: ["python"] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] - # Learn more about CodeQL language support at https://git.io/codeql-language-support - - steps: - - name: Checkout repository - uses: actions/checkout@v2 - - name: Get composite run steps repository - uses: actions/checkout@v3 - with: - repository: mosaicml/ci-testing - ref: v0.0.5 - path: ./ci-testing - - uses: ./ci-testing/.github/actions/codeql-analysis - with: - language: ${{ matrix.language }} diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 131ec1195c..cf3581f716 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -8,7 +8,7 @@ on: jobs: coverage: timeout-minutes: 5 - runs-on: ubuntu-latest + runs-on: linux-ubuntu-latest steps: - name: Checkout Repo uses: actions/checkout@v3 @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.0.5 + ref: v0.0.9 path: ./ci-testing - uses: ./ci-testing/.github/actions/coverage with: diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 6ca10fcd47..0bb0b4087a 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -17,20 +17,13 @@ jobs: strategy: matrix: include: - - name: "2.3.0_cu121_flash2" - base_image: mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04 - dep_groups: "[gpu-flash2]" - - name: "2.3.0_cu121_flash2_aws" - base_image: mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04-aws - dep_groups: "[gpu-flash2]" + - name: "2.3.1_cu121" + base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + dep_groups: "[all]" + - name: "2.3.1_cu121_aws" + base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws + dep_groups: "[all]" steps: - - name: Maximize Build Space on Worker - uses: easimon/maximize-build-space@v4 - with: - overprovision-lvm: true - remove-dotnet: true - remove-android: true - remove-haskell: true - name: Checkout uses: actions/checkout@v3 @@ -47,6 +40,13 @@ jobs: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_PASSWORD }} + - name: Login to GHCR + uses: docker/login-action@v2 + with: + username: ${{ secrets.GHCR_USERNAME }} + password: ${{ secrets.GHCR_TOKEN }} + registry: ghcr.io + - name: Calculate Docker Image Variables run: | set -euxo pipefail @@ -60,13 +60,17 @@ jobs: if [ "${{ github.event_name }}" == "pull_request" ]; then echo "Triggered by pull_request event." 
STAGING_REPO="mosaicml/ci-staging" - IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" + GHCR_STAGING_REPO="ghcr.io/databricks-mosaic/ci-staging" + GHCR_IMAGE_TAG="${GHCR_STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" + IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_IMAGE_TAG}" IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache" else # Triggered by push or workflow_dispatch event echo "Triggered by ${{ github.event_name }} event, releasing to prod" PROD_REPO="mosaicml/llm-foundry" - IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" + GHCR_PROD_REPO="ghcr.io/databricks-mosaic/llm-foundry" + GHCR_IMAGE_TAG="${GHCR_PROD_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_PROD_REPO}:${{matrix.name}}-latest" + IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest,${GHCR_IMAGE_TAG}" IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" fi diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 93612b7983..2c85719756 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -15,23 +15,28 @@ concurrency: cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: pytest-cpu: - uses: mosaicml/ci-testing/.github/workflows/pytest-cpu.yaml@v0.0.5 + name: ${{ matrix.name }} + runs-on: ubuntu-latest strategy: matrix: include: - - name: "cpu-2.3.0" - container: mosaicml/pytorch:2.3.0_cpu-python3.11-ubuntu20.04 + - name: "cpu-2.3.1" + pip_deps: "[all-cpu]" + container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 markers: "not gpu" pytest_command: "coverage run -m pytest" - name: ${{ matrix.name }} - if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - name: ${{ matrix.name }} - pip_deps: "[all-cpu]" - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - safe_directory: llm-foundry + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Run PR CPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.1.0 + with: + name: ${{ matrix.name }} + container: ${{ matrix.container }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + safe_directory: llm-foundry coverage: uses: ./.github/workflows/coverage.yaml name: Coverage Results diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 31af66e51f..ba1a4f9ba4 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -9,31 +9,95 @@ on: - main - release/** workflow_dispatch: +# Cancel old runs when a new commit is pushed to the same branch if not on main or dev concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: - pytest-gpu: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.5 + pytest-gpu-1: + name: ${{ matrix.name }} + if: github.repository_owner == 'mosaicml' + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: include: - - name: "gpu-2.3.0" - container: mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04 + - name: "gpu-2.3.1-1" + container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" + pip_deps: "[all]" pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + 
mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 1 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} + pytest-gpu-2: + name: ${{ matrix.name }} + if: github.repository_owner == 'mosaicml' + runs-on: linux-ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - name: "gpu-2.3.1-2" + container: mosaicml/llm-foundry:2.3.1_cu121-latest + markers: "gpu" pip_deps: "[all]" + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 2 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} + pytest-gpu-4: name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - python-version: 3.9 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} + runs-on: linux-ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - name: "gpu-2.3.1-4" + container: mosaicml/llm-foundry:2.3.1_cu121-latest + markers: "gpu" + pip_deps: "[all]" + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 4 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 144e3f1ad3..c09f9bb7a5 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -14,7 +14,7 @@ jobs: name: Build and Publish llm-foundry PyPI Package needs: - code-quality - runs-on: ubuntu-latest + runs-on: linux-ubuntu-latest steps: - name: Checkout source uses: actions/checkout@v3 diff --git a/.github/workflows/smoketest.yaml b/.github/workflows/smoketest.yaml index ae8d5911a8..d38849cddc 100644 --- a/.github/workflows/smoketest.yaml +++ b/.github/workflows/smoketest.yaml @@ -18,7 +18,7 @@ defaults: working-directory: . 
jobs: smoketest: - runs-on: ubuntu-20.04 + runs-on: linux-ubuntu-latest timeout-minutes: 20 strategy: matrix: @@ -32,7 +32,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.0.5 + ref: v0.0.9 path: ./ci-testing - uses: ./ci-testing/.github/actions/smoketest with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc2e3f55cd..b45021dd8c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,17 +77,6 @@ repos: hooks: - id: docformatter args: [--in-place, --wrap-summaries=80, --wrap-descriptions=80] -- repo: https://github.com/PyCQA/pydocstyle - hooks: - - id: pydocstyle - name: pydocstyle - entry: pydocstyle - language: python - types: [python] - exclude: (.ci|.github) - additional_dependencies: - - toml - rev: 6.1.1 - repo: https://github.com/adrienverge/yamllint.git rev: v1.28.0 hooks: diff --git a/Dockerfile b/Dockerfile index 683ab6dfb0..cee7063cdd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,13 +7,15 @@ FROM $BASE_IMAGE ARG BRANCH_NAME ARG DEP_GROUPS +ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" + # Check for changes in setup.py. # If there are changes, the docker cache is invalidated and a fresh pip installation is triggered. ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py setup.py RUN rm setup.py # Install TransformerEngine -RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@05eb6deb31c1b48e9f4380d18fe95f3c38e84335 +RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@b5a7c9f # Install and uninstall foundry to cache foundry requirements RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git diff --git a/README.md b/README.md index 70436271dd..e8a6708c5a 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ DBRX is a state-of-the-art open source LLM trained by Databricks Mosaic team. It | DBRX Base | 32768 | https://huggingface.co/databricks/dbrx-base | | DBRX Instruct | 32768 | https://huggingface.co/databricks/dbrx-instruct | -Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). +Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/blob/main/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). For more information about the DBRX models, see https://github.com/databricks/dbrx. @@ -113,8 +113,8 @@ If you have success/failure using LLM Foundry on other systems, please let us kn | Device | Torch Version | Cuda Version | Status | | -------------- | ------------- | ------------ | ---------------------------- | -| A100-40GB/80GB | 2.3.0 | 12.1 | :white_check_mark: Supported | -| H100-80GB | 2.3.0 | 12.1 | :white_check_mark: Supported | +| A100-40GB/80GB | 2.3.1 | 12.1 | :white_check_mark: Supported | +| H100-80GB | 2.3.1 | 12.1 | :white_check_mark: Supported | ## MosaicML Docker Images We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories. 
@@ -122,15 +122,15 @@ We highly recommend using our prebuilt Docker images. You can find them here: ht The `mosaicml/pytorch` images are pinned to specific PyTorch and CUDA versions, and are stable and rarely updated. The `mosaicml/llm-foundry` images are built with new tags upon every commit to the `main` branch. -You can select a specific commit hash such as `mosaicml/llm-foundry:2.3.0_cu121_flash2-36ab1ba` or take the latest one using `mosaicml/llm-foundry:2.3.0_cu121_flash2-latest`. +You can select a specific commit hash such as `mosaicml/llm-foundry:2.3.1_cu121-36ab1ba` or take the latest one using `mosaicml/llm-foundry:2.3.1_cu121-latest`. **Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPi or from source. | Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? | | ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- | -| `mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04` | 2.3.0 | 12.1 (Infiniband) | No | -| `mosaicml/llm-foundry:2.3.0_cu121_flash2-latest` | 2.3.0 | 12.1 (Infiniband) | Yes | -| `mosaicml/llm-foundry:2.3.0_cu121_flash2_aws-latest` | 2.3.0 | 12.1 (EFA) | Yes | +| `mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04` | 2.3.1 | 12.1 (Infiniband) | No | +| `mosaicml/llm-foundry:2.3.1_cu121-latest` | 2.3.1 | 12.1 (Infiniband) | Yes | +| `mosaicml/llm-foundry:2.3.1_cu121_aws-latest` | 2.3.1 | 12.1 (EFA) | Yes | # Installation @@ -230,7 +230,7 @@ python data_prep/convert_dataset_hf.py \ # Train an MPT-125m model for 10 batches composer train/train.py \ train/yamls/pretrain/mpt-125m.yaml \ - data_local=my-copy-c4 \ + variables.data_local=my-copy-c4 \ train_loader.dataset.split=train_small \ eval_loader.dataset.split=val_small \ max_duration=10ba \ @@ -264,7 +264,7 @@ Note: the `composer` command used above to train the model refers to the [Compos If you have a write-enabled [HuggingFace auth token](https://huggingface.co/docs/hub/security-tokens), you can optionally upload your model to the Hub! Just export your token like this: ```bash -export HUGGING_FACE_HUB_TOKEN=your-auth-token +export HF_TOKEN=your-auth-token ``` and uncomment the line containing `--hf_repo_for_upload ...` in the above call to `inference/convert_composer_to_hf.py`. @@ -282,6 +282,8 @@ We provide two commands currently: Use `--help` on any of these commands for more information. +These commands can also help you understand what each registry is composed of, as each registry contains a docstring that will be printed out. The general concept is that each registry defines an interface, and components registered to that registry must implement that interface. If there is a part of the library that is not currently extendable, but you think it should be, please open an issue! + ## How to register There are a few ways to register a new component: @@ -289,8 +291,9 @@ There are a few ways to register a new component: ### Python entrypoints You can specify registered components via a Python entrypoint if you are building your own package with registered components. +This would be the expected usage if you are building a large extension to LLM Foundry, and going to be overriding many components. Note that things registered via entrypoints will override components registered directly in code. 
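Editorial aside on the registry prose above: the in-code registration path mirrors what this diff itself does for the `system_metrics_monitor` callback later on (`callbacks.register(...)`). Below is a minimal, illustrative sketch of that direct path for the `llmfoundry.registry.loggers` registry; `MyLogger` and the `my_logger` key are placeholder names, and this block is an editorial example rather than part of the patch:

```python
# Illustrative sketch only: direct in-code registration of a custom logger,
# following the registry.register(key, func=...) pattern used in this diff.
from composer.loggers import LoggerDestination

from llmfoundry.registry import loggers


class MyLogger(LoggerDestination):
    """Placeholder logger; it implements the LoggerDestination interface."""


loggers.register('my_logger', func=MyLogger)
```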
-For example, the following would register the `WandBLogger` class, under the key `wandb`, in the `llm_foundry.loggers` registry: +For example, the following would register the `MyLogger` class, under the key `my_logger`, in the `llm_foundry.loggers` registry: ```yaml @@ -306,10 +309,15 @@ dependencies = [ "llm-foundry", ] +# Note: Even though in python code, this would be llmfoundry.registry.loggers, +# when specified in the entry_points, it has to be "llmfoundry_loggers". That is, +# the segments of the name should be joined by an _ in the entry_points section. [project.entry-points."llmfoundry_loggers"] my_logger = "foundry_registry.loggers:MyLogger" ``` +If developing new components via entrypoints, it is important to note that Python entrypoints are global to the Python environment. This means that if you have multiple packages that register components with the same key, the last one installed will be the one used. This can be useful for overriding components in LLM Foundry, but can also lead to unexpected behavior if not careful. Additionally, if you change the pyproject.toml, you will need to reinstall the package for the changes to take effect. You can do this quickly by installing with `pip install -e . --no-deps` to avoid reinstalling dependencies. + ### Direct call to register You can also register a component directly in your code: @@ -359,6 +367,7 @@ code_paths: ... ``` +One of these would be the expected usage if you are building a small extension to LLM Foundry, only overriding a few components, and thus don't want to create an entire package. # Learn more about LLM Foundry! diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index c9666566bf..b851aaa559 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -9,6 +9,23 @@ import logging +try: + from flash_attn import flash_attn_func + del flash_attn_func +except ImportError as e: + if 'undefined symbol' in str(e): + raise ImportError( + 'The flash_attn package is not installed correctly. Usually this means that your runtime version' + + + ' of PyTorch is different from the version that flash_attn was installed with, which can occur when your' + + + ' workflow has resulted in PyTorch being reinstalled. This probably happened because you are using an old Docker image' + + + ' with the latest version of LLM Foundry. Check that the PyTorch version in your Docker image matches the PyTorch version' + + + ' in LLM Foundry setup.py and update accordingly. 
The latest Docker image can be found in the README.', + ) from e + from llmfoundry.utils.logging_utils import SpecificWarningFilter # Filter out Hugging Face warning for not using a pinned revision of the model @@ -33,6 +50,7 @@ tokenizers, utils, ) +from llmfoundry._version import __version__ from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric from llmfoundry.models.hf import ComposerHFCausalLM @@ -46,6 +64,7 @@ from llmfoundry.optim import DecoupledLionW __all__ = [ + '__version__', 'StreamingFinetuningDataset', 'StreamingTextDataset', 'InContextLearningDataset', @@ -70,5 +89,3 @@ 'tokenizers', 'utils', ] - -__version__ = '0.9.0.dev0' diff --git a/llmfoundry/_version.py b/llmfoundry/_version.py new file mode 100644 index 0000000000..4c11746b43 --- /dev/null +++ b/llmfoundry/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""The LLM Foundry Version.""" + +__version__ = '0.11.0.dev' diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index 4712de5d5e..496e905e13 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -11,6 +11,7 @@ OptimizerMonitor, RuntimeEstimator, SpeedMonitor, + SystemMetricsMonitor, ) from llmfoundry.callbacks.async_eval_callback import AsyncEval @@ -35,6 +36,7 @@ from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector from llmfoundry.registry import callbacks, callbacks_with_config +callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor) callbacks.register('lr_monitor', func=LRMonitor) callbacks.register('memory_monitor', func=MemoryMonitor) callbacks.register('memory_snapshot', func=MemorySnapshot) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 646d86c8d3..1b3c31e861 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -557,7 +557,8 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: installation_path = i['path'] if not found_llm_foundry: - from llmfoundry import __version__ as latest_foundry_version + from llmfoundry._version import \ + __version__ as latest_foundry_version # If github integration is not found, foundry is likely installed # through the run command. In this case, we'll add the integration diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 961bf1cae1..449ab338bc 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -7,47 +7,228 @@ the future. 
""" +import copy import logging -from typing import Any, Dict +from typing import Any -from composer.core import State +from composer import DataSpec +from composer.core import State, Time, TimeUnit, ensure_time from composer.loggers import Logger from streaming import StreamingDataset +from streaming.base.util import clean_stale_shared_memory from torch.utils.data import DataLoader from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.utils.warnings import experimental_class +from llmfoundry.utils.exceptions import ( + BaseContextualError, + TrainDataLoaderLocation, +) log = logging.getLogger(__name__) __all__ = ['CurriculumLearning'] -@experimental_class('CurriculumLearning callback') class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. + Example schedule: + [ + { + 'duration': tok, + 'train_loader': , # matches top level train_loader + }, + { + 'duration': tok, + 'train_loader': , + }, + { + 'duration': tok, + 'train_loader': , + ], + ] + Args: - train_config (Dict): The configuration of the dataset currently + train_config (dict): The configuration of the dataset currently being used. Note that this is the full train config and must - contain the 'train_loader' key. - dataset_index (int): The index of the dataset currently being used. + contain the 'train_loader', 'device_train_batch_size', and + 'tokenizer' keys. + schedule (list[dict[str, Any]]): The list of datamixes to use and their + durations. Duration units must match max_duration and be in terms of + a TimeUnit that is supported by Iteration. The duration values must + be positive. There must be at least one datamix in the schedule. The + first datamix in the schedule must match the train_loader in the + train_config. On resumption, previously trained on datamixes and + durations cannot be changed. The duration of the current datamix + must be greater than the saved timestamp. The dataset must be a + StreamingDataset. """ - def __init__(self, train_config: Dict, dataset_index: int): - self.dataset_index = dataset_index - self.saved_dataset_index = 0 - self.all_dataset_configs = [] - self.current_dataset_state = {} - # The current dataset config is resolved and passed in train.py - self.current_dataset_config = train_config['train_loader'] + def __init__( + self, + train_config: dict[str, Any], + schedule: list[dict[str, Any]], + ): + # Ensure all duration units are in epochs or tokens and values are positive + self._schedule = schedule + if len(self._schedule) == 0: + raise ValueError('The schedule must have at least one datamix.') + for index, datamix in enumerate(self._schedule): + self._validate_datamix(datamix) + + if ( + index == 0 and + train_config['train_loader'] != datamix['train_loader'] + ): + raise ValueError(( + 'The first datamix in the schedule must match the ' + 'train_loader in the train_config.' 
+ )) + + self._schedule_index = 0 + self.device_train_batch_size = train_config['device_train_batch_size'] + self.tokenizer = None + + def init(self, state: State, logger: Logger): + del logger # unused + + if not hasattr(state.model, 'tokenizer'): + raise ValueError('state.model must have a tokenizer attribute.') + self.tokenizer = state.model.tokenizer def before_load(self, state: State, logger: Logger): - del logger + del logger # unused + + # Ensure all duration units are the same as max_duration + datamix_units = [datamix['duration'].unit for datamix in self._schedule] + assert state.max_duration is not None, 'max_duration should have beeen set.' + if any(state.max_duration.unit != unit for unit in datamix_units): + raise ValueError(( + f'All durations in the schedule must have the same units as ' + f'the max_duration. Expected {state.max_duration.unit}, but ' + f'got {datamix_units}.' + )) + + # Ensure schedule duration is equal to max_duration + schedule_duration = Time(0, state.max_duration.unit) + for datamix in self._schedule: + assert isinstance(datamix['duration'], Time) + schedule_duration += datamix['duration'] + if schedule_duration != state.max_duration: + raise ValueError(( + 'The sum of all durations in the schedule must be equal to the ' + 'max_duration.' + )) + + self._validate_dataloader(state.train_dataloader) + + def after_load(self, state: State, logger: Logger): + del logger # unused - # Save the current dataset state so we can restore it correctly - # if we are resuming with a new dataset. - train_loader = state.train_dataloader + self._validate_dataloader(state.train_dataloader) + + # If checkpoint was saved before iteration was incremented, we need to increment it now + duration = self._schedule[self._schedule_index]['duration'] + if (( + duration.unit == TimeUnit.TOKEN and + state.timestamp.token_in_iteration >= duration.value + ) or ( + duration.unit == TimeUnit.EPOCH and + state.timestamp.epoch_in_iteration >= duration.value + )): + log.warning(( + 'The CurriculumLearning callback has detected that the ' + 'previous run did not correctly increment the iteration.' 
+ )) + self._schedule_index += 1 + state.timestamp = state.timestamp.to_next_iteration() + + def iteration_start(self, state: State, logger: Logger): + # Swap the dataset if starting a new iteration that's not the original datamix + if self._schedule_index > 0: + # TODO: trainer._train_data_spec should be updated whenever the dataloader is updated + # Dataloaders with the same prefix access the same shared memory + # which is stale + clean_stale_shared_memory() + datamix = copy.deepcopy(self._schedule[self._schedule_index]) + data_spec = self._build_train_loader( + train_loader_config=datamix['train_loader'], + logger=logger, + ) + state.set_dataloader( + dataloader=data_spec.dataloader, + dataloader_label='train', + ) + state.train_dataloader = state.dataloader + self._validate_dataloader(state.train_dataloader) + + # Set the length of the new iteration + state._iteration_length = self._schedule[self._schedule_index + ]['duration'] + + def iteration_end(self, state: State, logger: Logger): + del state, logger # unused + + self._schedule_index += 1 + + def state_dict(self): + return { + 'schedule': self._schedule, + 'schedule_index': self._schedule_index, + } + + def load_state_dict(self, state: dict[str, Any]): + self._schedule_index = state['schedule_index'] + + # Ensure that the schedule has not changed on previously trained datamixes + for idx in range(state['schedule_index']): + if self._schedule[idx] != state['schedule'][idx]: + raise ValueError(( + f'Previous datamixes must stay the same across ', + f'resumptions. Expected {state["schedule"][idx]} but got ', + f'{self._schedule[idx]}', + )) + + # Ensure that the datamix has not changed on the current datamix + current_loader = self._schedule[self._schedule_index]['train_loader'] + saved_loader = state['schedule'][self._schedule_index]['train_loader'] + if current_loader != saved_loader: + raise ValueError(( + f'The current datamix must stay the same across resumptions. ', + f'Expected {saved_loader} but got {current_loader}', + )) + + # Ensure that the current datamix duration is in the correct units + duration = self._schedule[self._schedule_index]['duration'] + if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH: + raise ValueError(( + f'Duration must be in terms of tokens or epochs, but got ', + f'{duration.unit}.', + )) + + def _build_train_loader( + self, + train_loader_config: dict[str, Any], + logger: Logger, + ) -> DataSpec: + from llmfoundry.data.dataloader import build_dataloader + + # Copied from scripts/train/train.py + log.info( + f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}', + ) + assert self.tokenizer is not None + try: + return build_dataloader( + train_loader_config, + self.tokenizer, + self.device_train_batch_size, + ) + except BaseContextualError as e: + e.location = TrainDataLoaderLocation + raise e + + def _validate_dataloader(self, train_loader: Any): # Check if we are using a DataLoader and StreamingDataset if not isinstance(train_loader, DataLoader): raise ValueError( @@ -61,54 +242,23 @@ def before_load(self, state: State, logger: Logger): f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}', ) - assert isinstance(dataset, StreamingDataset) - # Save the current dataset state so we can restore it if needed. 
- self.current_dataset_state = dataset.state_dict( # type: ignore - num_samples=0, from_beginning=False) - def after_load(self, state: State, logger: Logger): - del logger - - # As saved_dataset_index is loaded from state_dict, this only runs when - # a user explicitly increments the dataset_index and not on any other - # resumption, including autoresume. - train_loader = state._train_dataloader - assert isinstance( - train_loader, - DataLoader, - ), 'CurriculumLearning callback requires a DataLoader.' - dataset = train_loader.dataset - assert isinstance( - dataset, - StreamingDataset, - ), 'CurriculumLearning callback requires a StreamingDataset.' - if self.saved_dataset_index < self.dataset_index: - # Ignore the dataset state that was read in from the checkpoint, and - # replace with the new dataset state. This preserves resumption info. - if self.current_dataset_state['epoch'] < 0: - # Make sure the epoch in the loaded state dict is not negative. - # Since `__iter__` has not yet been called on the dataset, the - # epoch index in the dataset will still be -1. We need to ensure - # that we set the epoch correctly to 0 in this case. - self.current_dataset_state['epoch'] = 0 - dataset.load_state_dict( # type: ignore - self.current_dataset_state) - # Start a new epoch since we are using a new dataset. - # This will also reset the sample_in_epoch written to checkpoint, - # making sure that subsequent resumptions proceed correctly. - state.timestamp = state.timestamp.to_next_epoch() - # Append the new dataset config to the list of all dataset configs. - self.all_dataset_configs.append(self.current_dataset_config) - elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0: - # Make sure to track our current dataset config if we are just starting training. - self.all_dataset_configs.append(self.current_dataset_config) - - def state_dict(self): - return { - 'dataset_index': self.dataset_index, - 'all_dataset_configs': self.all_dataset_configs, - } + def _validate_datamix(self, datamix: dict[str, Any]): + if 'duration' not in datamix: + raise ValueError('Each datamix must have a duration.') + datamix['duration'] = ensure_time( + datamix['duration'], + TimeUnit.EPOCH, + ) + if datamix['duration'].value <= 0: + raise ValueError('The duration must be positive.') + if ( + datamix['duration'].unit != TimeUnit.EPOCH and + datamix['duration'].unit != TimeUnit.TOKEN + ): + raise ValueError( + 'Schedules can only be defined in terms of epochs or tokens.', + ) - def load_state_dict(self, state: Dict[str, Any]): - self.saved_dataset_index = state.get('dataset_index', 0) - self.all_dataset_configs = state.get('all_dataset_configs', []) + if 'train_loader' not in datamix: + raise ValueError('Each datamix must have a train_loader.') diff --git a/llmfoundry/callbacks/eval_output_logging_callback.py b/llmfoundry/callbacks/eval_output_logging_callback.py index edcd6ed336..b84ea063d1 100644 --- a/llmfoundry/callbacks/eval_output_logging_callback.py +++ b/llmfoundry/callbacks/eval_output_logging_callback.py @@ -5,11 +5,12 @@ import warnings from copy import deepcopy -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import torch from composer.core import Callback, State from composer.loggers import ConsoleLogger, Logger +from composer.models import HuggingFaceModel from composer.utils.dist import all_gather_object @@ -24,51 +25,85 @@ class EvalOutputLogging(Callback): into `batch_keys_to_log`. It will do so after every eval batch. 
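Editorial aside on the `log_output_text` option introduced in the hunk that follows: a minimal construction sketch, assuming the callback is then passed to a Composer `Trainer` via its `callbacks` argument (not shown in this diff):

```python
# Sketch of constructing EvalOutputLogging with the new log_output_text flag.
# log_output_text=None (the default) auto-detects support from the model and
# dataloader; True forces output-text logging and raises if it is unsupported.
from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging

callback = EvalOutputLogging(log_tokens=False, log_output_text=True)
```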
""" - def __init__(self, log_tokens: bool = False, *args: Any, **kwargs: Any): + def __init__( + self, + log_tokens: bool = False, + log_output_text: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ): super().__init__(self, *args, **kwargs) self.log_tokens = log_tokens self.columns = None self.name = None self.rows = [] + self.log_output_text = log_output_text + + def init(self, state: State, logger: Logger) -> None: + if self.log_output_text is False: + return + + has_output_text = ( + isinstance(state.model, HuggingFaceModel) + and state.dataloader is not None + and hasattr( + state.dataloader.dataset, # pyright: ignore[reportGeneralTypeIssues] + 'tokenizer', + ) + ) + if self.log_output_text is True and has_output_text is False: + raise ValueError( + '`log_output_text=True` is only supported for HuggingFace models and datasets with tokenizers.', + ) + elif self.log_output_text is None: + self.log_output_text = has_output_text def eval_batch_end(self, state: State, logger: Logger) -> None: if not isinstance(state.batch, Dict): warnings.warn( - f'''EvalOutputLogging only supports batches that are dictionary. \ + f"""EvalOutputLogging only supports batches that are dictionary. \ Found batch for type {type(state.batch)}. \ - Not logging eval outputs.''', + Not logging eval outputs.""", ) return assert state.outputs is not None assert state.metric_outputs is not None - logging_dict: Dict[str, Union[List[Any], torch.Tensor, - Sequence[torch.Tensor]]] = deepcopy( - state.metric_outputs, - ) - - # If batch mode is not generate, outputs will be logits - if state.batch['mode'] == 'generate': + logging_dict: Dict[str, + Union[List[Any], torch.Tensor, + Sequence[torch.Tensor]], + ] = deepcopy( + state.metric_outputs, + ) + + if state.batch.get('mode') == 'generate': # Outputs are already detokenized logging_dict['outputs'] = state.outputs + elif self.log_output_text and isinstance(state.outputs, torch.Tensor): + # If batch mode is not generate, outputs will be logits + logging_dict['outputs'] = state.outputs.argmax(dim=-1) input_ids = state.batch['input_ids'] logged_input = [] assert state.dataloader is not None + dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues] + tokenizer = dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues] + pad_token_id = getattr( + dataset, + 'pad_tok_id', + dataset.tokenizer.pad_token_id, + ) # Depad and decode input_ids for input_list in input_ids.tolist(): - dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues] - depadded_input = [ - tok for tok in input_list if tok != dataset.pad_tok_id - ] - logged_input.append(dataset.tokenizer.decode(depadded_input)) + depadded_input = [tok for tok in input_list if tok != pad_token_id] + logged_input.append(tokenizer.decode(depadded_input)) logging_dict['input'] = logged_input # Log token indices if toggled if self.log_tokens: logging_dict['input_tokens'] = input_ids.tolist() - if not state.batch['mode'] == 'generate': + if not state.batch.get('mode') == 'generate': if isinstance(state.outputs, torch.Tensor): # pyright logging_dict['label_tokens'] = state.outputs.tolist() @@ -85,15 +120,9 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: for key, value in logging_dict.items(): # All types in list are the same if isinstance(value[0], torch.Tensor): - logging_dict[key] = [ - state.dataloader.dataset. 
# pyright: ignore[reportGeneralTypeIssues] - tokenizer.decode( # pyright: ignore[reportGeneralTypeIssues] - t, - ) for t in value - ] + logging_dict[key] = [tokenizer.decode(t) for t in value] elif isinstance(value[0], list): if isinstance(value[0][0], torch.Tensor): - tokenizer = state.dataloader.dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues] logging_dict[key] = [[ tokenizer.decode(choice) for choice in t ] for t in value] diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 28b33b43d8..79dc73de98 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -17,8 +17,7 @@ import numpy as np import torch import torch.nn as nn -from composer.core import Callback, Event, State, Time, TimeUnit -from composer.core.state import fsdp_state_dict_type_context +from composer.core import Callback, Event, Precision, State, Time, TimeUnit from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel from composer.utils import ( @@ -29,14 +28,29 @@ ) from composer.utils.misc import create_interval_scheduler from mlflow.transformers import _fetch_model_card, _write_license_information -from packaging import version -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import ( + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.models.utils import init_empty_weights from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility +try: + import transformer_engine.pytorch as te + is_te_imported = True +except ModuleNotFoundError: + is_te_imported = False + log = logging.getLogger(__name__) __all__ = ['HuggingFaceCheckpointer'] @@ -169,6 +183,7 @@ def __init__( 'bfloat16': torch.bfloat16, }[precision] self.flatten_imports = flatten_imports + self.using_peft = False # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name @@ -202,6 +217,14 @@ def __init__( ) self.mlflow_logging_config = mlflow_logging_config + if 'metadata' in self.mlflow_logging_config: + self.pretrained_model_name = self.mlflow_logging_config[ + 'metadata'].get( + 'pretrained_model_name', + None, + ) + else: + self.pretrained_model_name = None self.huggingface_folder_name_fstr = os.path.join( 'huggingface', @@ -264,6 +287,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '1GB', ) + + # Check if the model is using PEFT + if state.is_model_ddp: + composer_model = state.model.module + elif isinstance(state.model.model, FSDP): + composer_model = state.model + else: + composer_model = state.model + self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. timeout = 3600 @@ -332,6 +364,54 @@ def transform_model_and_tokenizer( """ return model, tokenizer + def transform_config( + self, + original_config: PretrainedConfig, + ) -> PretrainedConfig: + """Transform the model config before saving. + + Args: + original_config (Any): The original model config. + + Returns: + The transformed model config. 
+ """ + copied_config = copy.deepcopy(original_config) + if copied_config.model_type == 'mpt': + copied_config.attn_config['attn_impl'] = 'torch' + copied_config.init_device = 'cpu' + if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}): + copied_config.ffn_config['moe_world_size'] = 1 + return copied_config + + def pre_register_edit(self, local_save_path: str): + """Edit the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + local_save_path (str): The path to the model to be transformed. + """ + pass + + def transform_model_pre_registration( + self, + model: PreTrainedModel, + ) -> PreTrainedModel: + """Transform the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + model (PreTrainedModel): The model to be transformed. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -358,118 +438,90 @@ def _save_checkpoint(self, state: State, logger: Logger): temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if state.is_model_ddp: - composer_model = state.model.module original_model: PreTrainedModel = state.model.module.model state_dict_model = state.model.module.model original_tokenizer = state.model.module.tokenizer elif isinstance(state.model.model, FSDP): - composer_model = state.model original_model: PreTrainedModel = state.model.model.module state_dict_model = state.model.model original_tokenizer = state.model.tokenizer else: - composer_model = state.model original_model: PreTrainedModel = state.model.model state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - if version.parse(torch.__version__) > version.parse('2.2.9'): - from torch.distributed._tensor import DTensor - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - ) - cpu_offload = True - - # Add a dtensor->cpu tensor hook to avoid CUDA OOM - def dtensor_to_tensor_hook( - module: nn.Module, - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> Dict[str, Any]: - dtensor_fqns = [] - for fqn in state_dict.keys(): - tensor = state_dict[fqn] - if isinstance(tensor, DTensor): - dtensor_fqns.append(fqn) - tensor = tensor.full_tensor() # type: ignore - if dist.get_global_rank() == 0: - if cpu_offload: - tensor = tensor.cpu() - state_dict[fqn] = tensor - if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] - return state_dict - - hooks = [] - for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append( - module. 
- _register_state_dict_hook(dtensor_to_tensor_hook), - ) - - state_dict = get_model_state_dict( - state_dict_model, - options=StateDictOptions( - full_state_dict=True, - cpu_offload=cpu_offload, - ), - ) - for hook in hooks: - hook.remove() - else: - state_dict_context = fsdp_state_dict_type_context( - original_model, - state_dict_type='full', - ) if ((not state.is_model_ddp) and - isinstance(state_dict_model, - FSDP)) else contextlib.nullcontext() - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # Convert the state dict to the requested precis - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) + cpu_offload = True + + # Add hook to move tensors to cpu to avoid CUDA OOM + def tensor_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + dtensor_fqns = [] + for fqn in state_dict.keys(): + tensor = state_dict[fqn] + if isinstance(tensor, DTensor): + dtensor_fqns.append(fqn) + tensor = tensor.full_tensor() # type: ignore + if dist.get_global_rank() == 0: + # Offload any DTensors to CPU + if cpu_offload: + tensor = tensor.cpu() + state_dict[fqn] = tensor + else: + state_dict[fqn] = None + + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) + del tensor + if dist.get_global_rank() != 0: + state_dict = {} + return state_dict + + hooks = [] + for _, module in state_dict_model.named_modules(): + hooks.append(module._register_state_dict_hook(tensor_hook),) + + state_dict = get_model_state_dict( + state_dict_model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload, + ), + ) + for hook in hooks: + hook.remove() new_model_instance = None # Need this for pyright because variable could be unbound if dist.get_global_rank() == 0: log.debug('Saving Hugging Face checkpoint in global rank 0') - # Edit HF config before building 2nd model copy - copied_config = copy.deepcopy(original_model.config) - if copied_config.model_type == 'mpt': - copied_config.attn_config['attn_impl'] = 'torch' - copied_config.init_device = 'cpu' - if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}): - copied_config.ffn_config['moe_world_size'] = 1 + # Transform HF config before building 2nd model copy + new_config = self.transform_config( + original_config=original_model.config, + ) log.debug(f'Creating new model instance') - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(copied_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter], - ) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): - new_model_instance = type(original_model)(copied_config) + # First create the model instance on meta device to avoid the + # initialization cost. 
+ with init_empty_weights(): + if self.using_peft: + active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(new_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter], + ) + else: + new_model_instance = type(original_model)(new_config) new_model_instance.generation_config.update( **original_model.generation_config.to_dict(), ) @@ -485,14 +537,34 @@ def dtensor_to_tensor_hook( original_tokenizer, ) + # Ensure that the pretrained model name is correctly set on the saved HF checkpoint. + if self.pretrained_model_name is not None: + new_model_instance.name_or_path = self.pretrained_model_name + if self.using_peft: + new_model_instance.base_model.name_or_path = self.pretrained_model_name + for k in new_model_instance.peft_config.keys(): + new_model_instance.peft_config[ + k + ].base_model_name_or_path = self.pretrained_model_name + log.debug('Saving Hugging Face checkpoint to disk') - new_model_instance.save_pretrained(temp_save_dir) + # This context manager casts the TE extra state in io.BytesIO format to tensor format + # Needed for proper hf ckpt saving. + context_manager = te.onnx_export( + True, + ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( + ) + with context_manager: + new_model_instance.save_pretrained(temp_save_dir) if original_tokenizer is not None: - assert isinstance(original_tokenizer, PreTrainedTokenizerBase) + assert isinstance( + original_tokenizer, + PreTrainedTokenizerBase, + ) original_tokenizer.save_pretrained(temp_save_dir) # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': + if new_model_instance.config.model_type == 'mpt': log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility( temp_save_dir, @@ -519,6 +591,11 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if self.mlflow_registered_model_name and self._is_last_batch(state): + + new_model_instance = self.transform_model_pre_registration( + new_model_instance, + ) + components = {'model': new_model_instance} if original_tokenizer is not None: components['tokenizer'] = original_tokenizer @@ -538,7 +615,7 @@ def dtensor_to_tensor_hook( model_saving_kwargs: Dict[str, Any] = { 'path': local_save_path, } - if composer_model.using_peft: + if self.using_peft: model_saving_kwargs['flavor'] = 'peft' model_saving_kwargs['save_pretrained_dir' ] = temp_save_dir @@ -549,15 +626,23 @@ def dtensor_to_tensor_hook( model_saving_kwargs['transformers_model'] = components model_saving_kwargs.update(self.mlflow_logging_config) - mlflow_logger.save_model(**model_saving_kwargs) + context_manager = te.onnx_export( + True, + ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( + ) + with context_manager: + # Add the pip requirements directly to avoid mlflow + # attempting to run inference on the model + model_saving_kwargs['pip_requirements'] = [ + 'transformers', + 'torch', + ] + mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. 
license_filename = _maybe_get_license_filename( local_save_path, - self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', - None, - ), + self.pretrained_model_name, ) if license_filename is not None: mlflow_logger._mlflow_client.log_artifact( @@ -565,6 +650,8 @@ def dtensor_to_tensor_hook( os.path.join(local_save_path, license_filename), ) + self.pre_register_edit(local_save_path,) + # Spawn a new process to register the model. process = SpawnProcess( target=_register_model_with_run_id_multiprocess, diff --git a/llmfoundry/callbacks/loss_perp_v_len_callback.py b/llmfoundry/callbacks/loss_perp_v_len_callback.py index aa9519c255..ebb9583224 100644 --- a/llmfoundry/callbacks/loss_perp_v_len_callback.py +++ b/llmfoundry/callbacks/loss_perp_v_len_callback.py @@ -262,9 +262,15 @@ def update( self.sum_length += valid_labels_mask.sum(dim=0) if sequence_id is not None: - seq_id_expanded = torch.nn.functional.one_hot( - sequence_id, - ).transpose(-1, -2) + seq_id_mask = (sequence_id != -1) + sequence_id = torch.where(seq_id_mask, sequence_id, 0) + seq_id_expanded = torch.nn.functional.one_hot(sequence_id,) + seq_id_expanded = torch.where( + torch.unsqueeze(seq_id_mask, dim=-1), + seq_id_expanded, + 0, + ) + seq_id_expanded = seq_id_expanded.transpose(-1, -2) seq_lens = seq_id_expanded.sum(dim=-1) max_num_seq = seq_lens.shape[1] seq_tok_ids = torch.arange(seq_len, device=sequence_id.device)[ diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py index eb8051240d..4c791f37d7 100644 --- a/llmfoundry/callbacks/run_timeout_callback.py +++ b/llmfoundry/callbacks/run_timeout_callback.py @@ -8,17 +8,12 @@ from typing import Optional from composer import Callback, Logger, State -from composer.loggers import MosaicMLLogger - -from llmfoundry.utils.exceptions import RunTimeoutError log = logging.getLogger(__name__) -def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None): +def _timeout(timeout: int): log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',) - if mosaicml_logger is not None: - mosaicml_logger.log_exception(RunTimeoutError(timeout=timeout)) os.kill(os.getpid(), signal.SIGINT) @@ -29,14 +24,8 @@ def __init__( timeout: int = 1800, ): self.timeout = timeout - self.mosaicml_logger: Optional[MosaicMLLogger] = None self.timer: Optional[threading.Timer] = None - def init(self, state: State, logger: Logger): - for callback in state.callbacks: - if isinstance(callback, MosaicMLLogger): - self.mosaicml_logger = callback - def _reset(self): if self.timer is not None: self.timer.cancel() @@ -47,7 +36,7 @@ def _timeout(self): self.timer = threading.Timer( self.timeout, _timeout, - [self.timeout, self.mosaicml_logger], + [self.timeout], ) self.timer.daemon = True self.timer.start() diff --git a/llmfoundry/cli/cli.py b/llmfoundry/cli/cli.py index 25c1a6d230..6c4a2d12c4 100644 --- a/llmfoundry/cli/cli.py +++ b/llmfoundry/cli/cli.py @@ -1,12 +1,53 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import typer +from typing import Annotated, Optional -from llmfoundry.cli import registry_cli +from typer import Argument, Typer -app = typer.Typer(pretty_exceptions_show_locals=False) +from llmfoundry.cli import ( + data_prep_cli, + registry_cli, +) +from llmfoundry.command_utils import ( + eval_from_yaml, + train_from_yaml, +) + +app = Typer(pretty_exceptions_show_locals=False) app.add_typer(registry_cli.app, name='registry') +app.add_typer(data_prep_cli.app, name='data_prep') 
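Editorial aside on the new `train`/`eval` CLI commands defined just below: they are thin wrappers over `llmfoundry.command_utils`, so the same code path can be driven from Python. A hedged sketch, reusing the placeholder config path and override that appear in the README quickstart earlier in this diff:

```python
# Sketch of calling the entry point that the new `train` CLI command wraps.
from llmfoundry.command_utils import train_from_yaml

train_from_yaml(
    'train/yamls/pretrain/mpt-125m.yaml',  # placeholder config path from the README quickstart
    ['max_duration=10ba'],                 # optional CLI-style overrides
)
```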
+ + +@app.command(name='train') +def train( + yaml_path: Annotated[str, + Argument( + ..., + help='Path to the YAML configuration file', + )], + args_list: Annotated[ + Optional[list[str]], + Argument(help='Additional command line arguments')] = None, +): + """Run the training with optional overrides from CLI.""" + train_from_yaml(yaml_path, args_list) + + +@app.command(name='eval') +def eval( + yaml_path: Annotated[str, + Argument( + ..., + help='Path to the YAML configuration file', + )], + args_list: Annotated[ + Optional[list[str]], + Argument(help='Additional command line arguments')] = None, +): + """Run the eval with optional overrides from CLI.""" + eval_from_yaml(yaml_path, args_list) + if __name__ == '__main__': app() diff --git a/llmfoundry/cli/data_prep_cli.py b/llmfoundry/cli/data_prep_cli.py new file mode 100644 index 0000000000..130e0a6585 --- /dev/null +++ b/llmfoundry/cli/data_prep_cli.py @@ -0,0 +1,268 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Annotated, Optional + +import psutil +from typer import Option, Typer + +from llmfoundry.command_utils import ( + convert_dataset_hf_from_args, + convert_dataset_json_from_args, + convert_delta_to_json_from_args, + convert_finetuning_dataset_from_args, + convert_text_to_mds_from_args, +) + +app = Typer(pretty_exceptions_show_locals=False) + + +@app.command(name='convert_dataset_hf') +def convert_dataset_hf( + dataset: Annotated[str, Option(..., help='Name of the dataset')], + out_root: Annotated[str, Option(..., help='Output root directory')], + data_subset: Annotated[ + Optional[str], + Option(help='Subset of the dataset (e.g., "all" or "en")'), + ] = None, + splits: Annotated[str, + Option(help='Comma-separated list of dataset splits',), + ] = 'train, train_small, val, val_small, val_xsmall', + compression: Annotated[Optional[str], + Option(help='Compression type')] = None, + concat_tokens: Annotated[ + Optional[int], + Option(help='Concatenate tokens up to this many tokens')] = None, + tokenizer: Annotated[Optional[str], + Option(help='Tokenizer name')] = None, + tokenizer_kwargs: Annotated[ + Optional[str], + Option(help='Tokenizer keyword arguments in JSON format')] = None, + bos_text: Annotated[Optional[str], Option(help='BOS text')] = None, + eos_text: Annotated[Optional[str], Option(help='EOS text')] = None, + no_wrap: Annotated[ + bool, + Option(help='Do not wrap text across max_length boundaries'), + ] = False, + num_workers: Annotated[Optional[int], + Option(help='Number of workers')] = None, +): + """Converts dataset from HuggingFace into JSON files.""" + # Convert comma-separated splits into a list + splits_list = splits.split(',') if splits else [] + convert_dataset_hf_from_args( + dataset=dataset, + data_subset=data_subset, + splits=splits_list, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + tokenizer_kwargs=tokenizer_kwargs, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) + + +@app.command(name='convert_dataset_json') +def convert_dataset_json( + path: Annotated[str, Option(..., help='Path to the input data file')], + out_root: Annotated[str, Option(..., help='Output root directory')], + concat_tokens: Annotated[ + int, + Option( + ..., + help='Convert text to tokens and concatenate up to this many tokens', + )], + tokenizer: Annotated[str, Option(..., help='Tokenizer name')], + compression: Annotated[Optional[str], + 
Option(help='Compression type, if any')] = 'zstd', + split: Annotated[str, Option(help='Dataset split to process')] = 'train', + bos_text: Annotated[ + Optional[str], + Option(help='Text to insert at the beginning of each sequence')] = None, + eos_text: Annotated[ + Optional[str], + Option(help='Text to insert at the end of each sequence')] = None, + no_wrap: Annotated[ + bool, + Option(help='Do not wrap text across max_length boundaries')] = False, + num_workers: Annotated[ + Optional[int], + Option(help='Number of workers for data loading')] = None, +): + """Convert a dataset from JSON to MDS streaming format.""" + convert_dataset_json_from_args( + path=path, + split=split, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) + + +@app.command(name='convert_finetuning_dataset') +def convert_finetuning_dataset_cli( + dataset: Annotated[ + str, + Option( + ..., + help= + 'Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`).', + )], + data_subset: Annotated[ + Optional[str], + Option(help='(Optional) subset of data to use.',)] = None, + splits: Annotated[str, + Option(help='Comma-separated list of dataset splits'), + ] = 'train,validation', + preprocessor: Annotated[ + Optional[str], + Option( + help= + 'Name or import path of function used to preprocess (reformat) the dataset.', + )] = None, + data_files: Annotated[ + str, Option(help='Data file for each split. Comma-separated.')] = '', + skip_preprocessing: Annotated[ + bool, Option(help='Whether to skip preprocessing.')] = False, + out_root: Annotated[ + str, + Option( + ..., + help= + 'Root path of output directory where MDS shards will be stored. 
Can be a remote URI.', + )] = '', + local: Annotated[ + Optional[str], + Option( + help= + '(Optional) root path of local directory if you want to keep a local copy when out_root is remote.', + )] = None, + compression: Annotated[ + Optional[str], + Option(help='(Optional) name of compression algorithm to use.')] = None, + num_workers: Annotated[Optional[int], + Option(help='Number of workers.')] = None, + tokenizer: Annotated[Optional[str], + Option(help='Tokenizer used for processing.')] = None, + tokenizer_kwargs: Annotated[ + Optional[str], + Option( + help= + 'Keyword arguments for tokenizer initialization in JSON format.', + )] = None, + max_seq_len: Annotated[int, Option(help='Maximum sequence length.')] = 2048, + target_prompts: Annotated[ + str, + Option(help='Policy for when to use prompts as training targets.'), + ] = 'none', + target_responses: Annotated[ + str, + Option(help='Policy for which responses to treat as training targets.'), + ] = 'last', + encoder_decoder: Annotated[ + bool, + Option( + help= + 'Set if the data are intended to be used to train an encoder-decoder model.', + )] = False, +): + """Convert a Finetuning Dataset to MDS streaming format.""" + # Convert comma-separated args + splits_list = splits.split(',') if splits else [] + data_files_list = data_files.split(',') if data_files else [] + convert_finetuning_dataset_from_args( + dataset=dataset, + data_subset=data_subset, + splits=splits_list, + preprocessor=preprocessor, + data_files=data_files_list, + skip_preprocessing=skip_preprocessing, + out_root=out_root, + local=local, + compression=compression, + num_workers=num_workers, + tokenizer=tokenizer, + tokenizer_kwargs=tokenizer_kwargs, + max_seq_len=max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + encoder_decoder=encoder_decoder, + ) + + +@app.command(name='convert_text_to_mds') +def convert_text_to_mds( + output_folder: Annotated[str, Option(..., help='The folder to write output to')], + input_folder: Annotated[str, Option(..., help='The folder with text files to convert to MDS')], + concat_tokens: Annotated[int, Option(..., help='Convert text to tokens and concatenate up to this many tokens')], + tokenizer: Annotated[str, Option(..., help='The name of the tokenizer to use')], + bos_text: Annotated[Optional[str], Option(help='The text to prepend to each example to separate concatenated examples')] = None, + eos_text: Annotated[Optional[str], Option(help='The text to append to each example to separate concatenated examples')] = None, + compression: Annotated[str, Option(help='The compression algorithm to use for MDS writing')] = 'zstd', + use_tokenizer_eos: Annotated[bool, Option(help='Use the EOS text from the tokenizer')] = False, + no_wrap: Annotated[bool, Option(help='Whether to let text examples wrap across multiple training examples')] = False, + processes: Annotated[int, Option( + help='The number of processes to use to download and convert the dataset', + )] = min(max(psutil.cpu_count() - 2, 1), 32), # type: ignore + reprocess: Annotated[bool, Option( + help= + 'If true, reprocess the input_folder to MDS format. Otherwise, only reprocess upon changes to the input folder or dataset creation parameters.', + )] = False, + trust_remote_code: Annotated[bool, Option( + help='If true, allows custom code to be executed to load the tokenizer', + )] = False, + logging_level: Annotated[str, Option( + help='Logging level for the script. 
Default is INFO.', + )] = 'INFO', + +): + """Convert text files to MDS streaming format.""" + convert_text_to_mds_from_args( + output_folder=output_folder, + input_folder=input_folder, + compression=compression, + concat_tokens=concat_tokens, + tokenizer_name=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + use_tokenizer_eos=use_tokenizer_eos, + no_wrap=no_wrap, + processes=processes, + reprocess=reprocess, + trust_remote_code=trust_remote_code, + logging_level=logging_level, + ) + + +@app.command(name='convert_delta_to_json') +def convert_delta_to_json_cli( + delta_table_name: Annotated[str, Option(..., help='UC table ..')], + json_output_folder: Annotated[str, Option(..., help='Local path to save the converted json')], + http_path: Annotated[Optional[str], Option(help='If set, dbsql method is used')] = None, + batch_size: Annotated[int, Option(help='Row chunks to transmit a time to avoid OOM')] = 1 << 30, + processes: Annotated[int, Option(help='Number of processes allowed to use')] = os.cpu_count(), # type: ignore + cluster_id: Annotated[Optional[str], Option(help='Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.')] = None, + use_serverless: Annotated[bool, Option(help='Use serverless or not. Make sure the workspace is entitled with serverless')] = False, + json_output_filename: Annotated[str, Option(help='The name of the combined final jsonl that combines all partitioned jsonl')] = 'train-00000-of-00001.jsonl', +): + """Convert a Delta table into JSON files.""" + convert_delta_to_json_from_args( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + batch_size=batch_size, + processes=processes, + cluster_id=cluster_id, + use_serverless=use_serverless, + json_output_filename=json_output_filename, + ) diff --git a/llmfoundry/cli/registry_cli.py b/llmfoundry/cli/registry_cli.py index 38ada51fd9..db090cd3aa 100644 --- a/llmfoundry/cli/registry_cli.py +++ b/llmfoundry/cli/registry_cli.py @@ -3,15 +3,15 @@ from typing import Optional -import typer from rich.console import Console from rich.table import Table +from typer import Typer from llmfoundry import registry from llmfoundry.utils.registry_utils import TypedRegistry console = Console() -app = typer.Typer(pretty_exceptions_show_locals=False) +app = Typer(pretty_exceptions_show_locals=False) def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]: diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py new file mode 100644 index 0000000000..0226c4f408 --- /dev/null +++ b/llmfoundry/command_utils/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from llmfoundry.command_utils.data_prep.convert_dataset_hf import ( + convert_dataset_hf, + convert_dataset_hf_from_args, +) +from llmfoundry.command_utils.data_prep.convert_dataset_json import ( + convert_dataset_json, + convert_dataset_json_from_args, +) +from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( + convert_delta_to_json_from_args, + fetch_DT, +) +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import ( + convert_finetuning_dataset, + convert_finetuning_dataset_from_args, +) +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( + convert_text_to_mds, + convert_text_to_mds_from_args, +) +from llmfoundry.command_utils.eval import ( + eval_from_yaml, + evaluate, +) +from llmfoundry.command_utils.train 
import ( + TRAIN_CONFIG_KEYS, + TrainConfig, + train, + train_from_yaml, + validate_config, +) + +__all__ = [ + 'train', + 'train_from_yaml', + 'TrainConfig', + 'TRAIN_CONFIG_KEYS', + 'validate_config', + 'evaluate', + 'eval_from_yaml', + 'convert_dataset_hf', + 'convert_dataset_hf_from_args', + 'convert_dataset_json', + 'convert_dataset_json_from_args', + 'convert_finetuning_dataset_from_args', + 'convert_finetuning_dataset', + 'convert_text_to_mds', + 'convert_text_to_mds_from_args', + 'convert_delta_to_json_from_args', + 'fetch_DT', +] diff --git a/llmfoundry/command_utils/data_prep/__init__.py b/llmfoundry/command_utils/data_prep/__init__.py new file mode 100644 index 0000000000..80950cb7b4 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_hf.py b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py new file mode 100644 index 0000000000..f9bbe6b0cf --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py @@ -0,0 +1,489 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming dataset conversion scripts for C4 and The Pile.""" +import json +import os +import platform +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, Optional, Union + +import datasets as hf_datasets +import psutil +import torch +from numpy.typing import NDArray +from streaming import MDSWriter +from torch.utils.data import DataLoader, Dataset, IterableDataset +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase + +from llmfoundry.data import ConcatTokensDataset, NoConcatDataset +from llmfoundry.utils.builders import build_tokenizer + + +class ConcatMode(Enum): + NO_CONCAT = 'NO_CONCAT' + CONCAT_TOKENS = 'CONCAT_TOKENS' + + +@dataclass +class DataSplitConstants: + hf_split: str + folder_split: str + raw_samples: Optional[int] + truncated_samples: Union[int, None] + + +@dataclass +class DatasetConstants: + chars_per_sample: int + chars_per_token: int + splits: Dict[str, DataSplitConstants] = field(default_factory=dict) + + def __iter__(self): + for v in self.splits.values(): + yield v + + +class TrainSmallConstants(DataSplitConstants): + + def __init__( + self, + hf_split: str = 'train', + folder_split: str = 'train_small', + raw_samples: int = 100000, + truncated_samples: int = 100000, + ): + super().__init__(hf_split, folder_split, raw_samples, truncated_samples) + + +class ValSmallConstants(DataSplitConstants): + + def __init__( + self, + hf_split: str = 'validation', + folder_split: str = 'val_small', + raw_samples: int = 10000, + truncated_samples: int = 10000, + ): + super().__init__(hf_split, folder_split, raw_samples, truncated_samples) + + +class ValXSmallConstants(DataSplitConstants): + + def __init__( + self, + hf_split: str = 'validation', + folder_split: str = 'val_xsmall', + raw_samples: int = 3000, + truncated_samples: int = 3000, + ): + super().__init__(hf_split, folder_split, raw_samples, truncated_samples) + + +pileconstants = DatasetConstants( + chars_per_sample=6212, # Computed over validation set + chars_per_token=4, # OpenAI estimate +) +pileconstants.splits['train'] = DataSplitConstants( + hf_split='train', + folder_split='train', + raw_samples=210607728, + truncated_samples=None, +) +pileconstants.splits['train_small'] = DataSplitConstants( + hf_split='train', + 
folder_split='train_small', + raw_samples=100000, + truncated_samples=100000, +) +pileconstants.splits['val'] = DataSplitConstants( + hf_split='validation', + folder_split='val', + raw_samples=214670, + truncated_samples=None, +) +pileconstants.splits['val_small'] = DataSplitConstants( + hf_split='validation', + folder_split='val_small', + raw_samples=10000, + truncated_samples=10000, +) +pileconstants.splits['val_xsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xsmall', + raw_samples=3000, + truncated_samples=3000, +) + +c4constants = DatasetConstants( + chars_per_sample=2163, # Computed over validation set + chars_per_token=4, # OpenAI estimate +) +c4constants.splits['train'] = DataSplitConstants( + hf_split='train', + folder_split='train', + raw_samples=364868892, + truncated_samples=None, +) +c4constants.splits['train_small'] = DataSplitConstants( + hf_split='train', + folder_split='train_small', + raw_samples=100000, + truncated_samples=100000, +) +c4constants.splits['val'] = DataSplitConstants( + hf_split='validation', + folder_split='val', + raw_samples=364608, + truncated_samples=None, +) +c4constants.splits['val_small'] = DataSplitConstants( + hf_split='validation', + folder_split='val_small', + raw_samples=10000, + truncated_samples=10000, +) +c4constants.splits['val_xsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xsmall', + raw_samples=3000, + truncated_samples=3000, +) +c4constants.splits['val_xxsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xxsmall', + raw_samples=100, + truncated_samples=100, +) + +CONSTS = {'c4': c4constants, 'the_pile': pileconstants} + + +def build_hf_dataset( + dataset_name: str, + split: str, + mode: ConcatMode, + max_length: Optional[int] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + tokenizer: PreTrainedTokenizerBase = None, + data_subset: Union[str, None] = None, +) -> IterableDataset: + """Build an IterableDataset over the HF C4 or pile source data. + + Args: + dataset_name (str): Dataset name + split (str): Split name. + mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS + max_length (int): The length of concatenated tokens + bos_text (str): text to insert at the beginning of each sequence + eos_text (str): text to insert at the end of each sequence + no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries + tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use + data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. + Typically "all" (The Pile) or "en" (c4). + + Returns: + An IterableDataset. + """ + hf_dataset = hf_datasets.load_dataset( + path=dataset_name, + name=data_subset, + split=split, + streaming=True, + ) + if mode == ConcatMode.NO_CONCAT: + dataset = NoConcatDataset(hf_dataset) + else: + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) + if max_length is None: + raise ValueError(f'max_length must be set.') + if bos_text + eos_text == '': + test_tokens = tokenizer('test') + if test_tokens['input_ids'][ + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: + tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' + tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' + tok_error_msg += 'attached without a separating token. 
Please use another tokenizer, ' + tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' + tok_error_msg += '--bos_text=<|endoftext|>.' + raise ValueError(tok_error_msg) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) + return dataset + + +def _est_progress_denominator( + total_samples: int, + chars_per_sample: int, + chars_per_token: int, + mode: ConcatMode, + max_length: int, +): + est_tokens_per_sample = chars_per_sample // chars_per_token + if mode == ConcatMode.NO_CONCAT: + return total_samples + elif mode == ConcatMode.CONCAT_TOKENS: + return total_samples * est_tokens_per_sample // max_length + + +def build_dataloader( + dataset: Dataset, + batch_size: int, + num_workers: Optional[int], +) -> DataLoader: + if num_workers is None: + # Multiple workers is only supported on linux machines + if 'linux' or 'macos' in platform.platform().lower(): + num_workers = max(1, psutil.cpu_count()) + else: + num_workers = 0 + + # If using multiple workers, configure each worker to prefetch as many samples as it can, up to + # the aggregate device batch size + # If not using workers, the torch DataLoader expects the default value for prefetch_factor, + # which non-intuitively must be 2. + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 + + return DataLoader( + dataset=dataset, + sampler=None, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + +def generate_samples( + loader: DataLoader, + truncate_num_samples: Optional[int] = None, +) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]: + """Generator over samples of a dataloader. + + Args: + loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} + truncate_num_samples (Optional[int]): An optional # of samples to stop at. + + Yields: + Sample dicts. + """ + n_samples = 0 + for batch in loader: + keys = list(batch.keys()) + current_bs = len(batch[keys[0]]) + for idx in range(current_bs): + if truncate_num_samples is not None and n_samples == truncate_num_samples: + return + n_samples += 1 + yield { + k: + v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx] + for k, v in batch.items() + } + + +def convert_dataset_hf( + dataset: str, + data_subset: Optional[str], + splits: list[str], + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: dict[str, Any], + bos_text: str, + eos_text: str, + no_wrap: bool, + num_workers: Optional[int], +) -> None: + """Converts HuggingFace datasets to MDS format. 
+ + Args: + dataset (str): Name of the dataset + data_subset (Optional[str]): Subset of the dataset (e.g., "all" or "en") + splits (list[str]): Comma-separated list of dataset splits + out_root (str): Output root directory + compression (Optional[str]): Compression type + concat_tokens (Optional[int]): Concatenate tokens up to this many tokens + tokenizer (Optional[str]): Tokenizer name + tokenizer_kwargs (dict[str, Any]): Tokenizer keyword arguments + bos_text (str): BOS text + eos_text (str): EOS text + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers + + Raises: + KeyError: If constants are not defined for the split + """ + try: + dataset_constants = CONSTS[dataset] + except KeyError: + raise ValueError( + f'Constants for dataset "{dataset}" not found. Currently only "the_pile" and "c4" are supported.', + ) + + if concat_tokens is not None and tokenizer is not None: + mode = ConcatMode.CONCAT_TOKENS + built_tokenizer = build_tokenizer(tokenizer, tokenizer_kwargs) + # we will enforce length, so suppress warnings about sequences too long for the model + built_tokenizer.model_max_length = int(1e30) + columns = {'tokens': 'ndarray:int32'} + else: + mode = ConcatMode.NO_CONCAT + built_tokenizer = None + columns = {'text': 'str'} + + for split_name in splits: + try: + split = dataset_constants.splits[split_name] + except KeyError: + raise KeyError(f'Constants not defined for split {split_name}.') + hf_split = split.hf_split + folder_split = split.folder_split + expected_num_samples = split.raw_samples + truncate_num_samples = split.truncated_samples + # Only generate the splits requested + if folder_split not in splits: + continue + + # Get samples + hf_dataset = build_hf_dataset( + dataset_name=dataset, + data_subset=data_subset, + split=hf_split, + mode=mode, + max_length=concat_tokens, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + tokenizer=built_tokenizer, + ) + loader = build_dataloader( + dataset=hf_dataset, + batch_size=512, + num_workers=num_workers, + ) + samples = generate_samples( + loader, + truncate_num_samples=truncate_num_samples, + ) + + if expected_num_samples is not None and concat_tokens is not None: + denominator = truncate_num_samples if truncate_num_samples is not None else _est_progress_denominator( + total_samples=expected_num_samples, + chars_per_sample=dataset_constants.chars_per_sample, + chars_per_token=dataset_constants.chars_per_token, + mode=mode, + max_length=concat_tokens, + ) + else: + denominator = None + + # Write samples + print(f'Converting {folder_split} to MDS format...') + print( + f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.', + ) + with MDSWriter( + columns=columns, + out=os.path.join(out_root, folder_split), + compression=compression, + ) as out: + if denominator is not None: + for sample in tqdm( + samples, + desc=folder_split, + total=denominator, + ): + out.write(sample) + else: + for sample in tqdm(samples, desc=folder_split): + out.write(sample) + + +def convert_dataset_hf_from_args( + dataset: str, + data_subset: Optional[str], + splits: list[str], + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: Optional[str], + bos_text: Optional[str], + eos_text: Optional[str], + no_wrap: bool, + num_workers: Optional[int], +) -> None: + """A wrapper for `convert_dataset_hf` that parses arguments. 
+ + Args: + dataset (str): Name of the dataset + data_subset (Optional[str]): Subset of the dataset (e.g., "all" or "en") + splits (list[str]): Comma-separated list of dataset splits + out_root (str): Output root directory + compression (Optional[str]): Compression type + concat_tokens (Optional[int]): Concatenate tokens up to this many tokens + tokenizer (Optional[str]): Tokenizer name + tokenizer_kwargs (Optional[str]): Tokenizer keyword arguments in JSON format + bos_text (Optional[str]): BOS text + eos_text (Optional[str]): EOS text + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers + + Raises: + ValueError: If the output directory already contains the requested splits + ValueError: If `concat_tokens` is set but `tokenizer` is not + """ + if tokenizer_kwargs: + parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs) + else: + parsed_tokenizer_kwargs = {} + + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(splits)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {splits}.', + ) + + # Make sure we have needed concat options + if ( + concat_tokens is not None and isinstance(concat_tokens, int) and + tokenizer is None + ): + raise ValueError( + 'When setting --concat_tokens, you must specify a --tokenizer', + ) + + # now that we have validated them, change BOS/EOS to strings and convert + convert_dataset_hf( + dataset=dataset, + data_subset=data_subset, + splits=splits, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + tokenizer_kwargs=parsed_tokenizer_kwargs, + bos_text=bos_text if bos_text else '', + eos_text=eos_text if eos_text else '', + no_wrap=no_wrap, + num_workers=num_workers, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py new file mode 100644 index 0000000000..35d7e637e6 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -0,0 +1,222 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming dataset conversion scripts for json files.""" +import os +from enum import Enum +from glob import glob +from typing import Optional + +import datasets as hf_datasets +from streaming import MDSWriter +from torch.utils.data import IterableDataset +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.data import ConcatTokensDataset, NoConcatDataset + + +class ConcatMode(Enum): + NO_CONCAT = 'NO_CONCAT' + CONCAT_TOKENS = 'CONCAT_TOKENS' + + +def build_hf_dataset( + path: str, + split: str, + mode: ConcatMode, + max_length: Optional[int] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + tokenizer: PreTrainedTokenizerBase = None, +) -> IterableDataset: + """Build an IterableDataset over the HF C4 or pile source data. + + Args: + path (str): Dataset name + split (str): Split name. 
+ mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS + max_length (int): The length of concatenated tokens + bos_text (str): text to insert at the beginning of each sequence + eos_text (str): text to insert at the end of each sequence + no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries + tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use + data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. + Typically "all" (The Pile) or "en" (c4). + + Returns: + An IterableDataset. + """ + if os.path.isdir(path): + data_files = glob(f'{path}/*') + else: + data_files = path + + hf_dataset = hf_datasets.load_dataset( + 'json', + data_files=data_files, + split=split, + ) + + if mode == ConcatMode.NO_CONCAT: + dataset = NoConcatDataset(hf_dataset) + else: + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) + if max_length is None: + raise ValueError(f'max_length must be set.') + if bos_text + eos_text == '': + test_tokens = tokenizer('test') + if test_tokens['input_ids'][ + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: + tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' + tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' + tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' + tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' + tok_error_msg += '--bos_text=<|endoftext|>.' + raise ValueError(tok_error_msg) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) + return dataset + + +def convert_dataset_json( + path: str, + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + split: str, + tokenizer: Optional[str] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + num_workers: Optional[int] = None, +) -> None: + """Create C4/pile streaming dataset. 
+ + Args: + path (str): Path to the input data file + out_root (str): Output root directory + compression (Optional[str]): Compression type, if any + concat_tokens (Optional[int]): Convert text to tokens and concatenate up to this many tokens + split (str): Dataset split to process + tokenizer (Optional[str]): Tokenizer name + bos_text (str): Text to insert at the beginning of each sequence + eos_text (str): Text to insert at the end of each sequence + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers for data loading + """ + if concat_tokens is not None: + mode = ConcatMode.CONCAT_TOKENS + built_tokenizer = AutoTokenizer.from_pretrained(tokenizer) + # we will enforce length, so suppress warnings about sequences too long for the model + built_tokenizer.model_max_length = int(1e30) + columns = {'tokens': 'ndarray:int32'} + else: + mode = ConcatMode.NO_CONCAT + built_tokenizer = None + columns = {'text': 'str'} + + # Get samples + dataset = build_hf_dataset( + path=path, + split=split, + mode=mode, + max_length=concat_tokens, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + tokenizer=built_tokenizer, + ) + + print('here') + + # Write samples + print(f'Converting to MDS format...') + print( + f'Note that the progress bar is based on the dataset length before tokenization.', + ) + print(f'It will finish at a value below 100% if tokenizing') + with MDSWriter( + columns=columns, + out=os.path.join(out_root), + compression=compression, + ) as out: + for sample in tqdm(dataset): + out.write(sample) + + +def convert_dataset_json_from_args( + path: str, + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + split: str, + tokenizer: Optional[str] = None, + bos_text: Optional[str] = None, + eos_text: Optional[str] = None, + no_wrap: bool = False, + num_workers: Optional[int] = None, +) -> None: + """A wrapper for `convert_dataset_json` that parses arguments. 
+ + Args: + path (str): Path to the input data file + out_root (str): Output root directory + compression (Optional[str]): Compression type, if any + concat_tokens (Optional[int]): Convert text to tokens and concatenate up to this many tokens + split (str): Dataset split to process + tokenizer (Optional[str]): Tokenizer name + bos_text (Optional[str]): Text to insert at the beginning of each sequence + eos_text (Optional[str]): Text to insert at the end of each sequence + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers for data loading + + Raises: + ValueError: If the out_root directory exists and contains files that overlap with the requested splits + ValueError: If concat_tokens is set and a tokenizer is not provided + """ + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(split)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {split}.', + ) + + # Make sure we have needed concat options + if ( + concat_tokens is not None and isinstance(concat_tokens, int) and + tokenizer is None + ): + ValueError( + 'When setting --concat_tokens, you must specify a --tokenizer', + ) + + # now that we have validated them, change BOS/EOS to strings + if bos_text is None: + bos_text = '' + if eos_text is None: + eos_text = '' + + convert_dataset_json( + path=path, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + split=split, + tokenizer=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py new file mode 100644 index 0000000000..635efd54d4 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -0,0 +1,762 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import re +import time +import urllib.parse +from collections import namedtuple +from concurrent.futures import ProcessPoolExecutor +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union +from uuid import uuid4 + +import google.protobuf.any_pb2 as any_pb2 +import pandas as pd +import pyarrow as pa +import requests +from composer.utils import retry +from packaging import version + +from llmfoundry.utils.exceptions import ( + ClusterDoesNotExistError, + FailedToConnectToDatabricksError, + FailedToCreateSQLConnectionError, +) + +if TYPE_CHECKING: + import pyspark.sql.connect.proto as pb2 + from databricks.sql.client import Connection as Connection + from databricks.sql.client import Cursor as Cursor + from pyspark.sql import SparkSession + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.dataframe import DataFrame as SparkDataFrame + from pyspark.sql.types import Row + +try: + from pyspark.sql.connect.client.core import SparkConnectClient + spark_connect_client_installed = True +except ImportError: + spark_connect_client_installed = False + +try: + from pyspark.sql.connect.dataframe import DataFrame + data_frame_installed = True +except ImportError: + data_frame_installed = False + +MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' +MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' + +TABLENAME_PATTERN = re.compile(r'(\S+)\.(\S+)\.(\S+)') + +log = logging.getLogger(__name__) + +Result = 
namedtuple( + 'Result', + [ + 'url', + 'row_count', + 'compressed_size', + 'uncompressed_size', + ], +) # pyright: ignore + +# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. +# It allows the client to fetch the results in different formats from the server. +# To be able to use the code make sure this module is not overriden by DB Connect classes. + + +def to_cf(self: 'SparkConnectClient', + plan: 'pb2.Plan', + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Executes the query plans and return as presigned URLS for cloud fetch. + + It can handle the current output formats that are supported by the server. + In contrast to the regular API methods of the client, this method does not + return the schema and drops all other responses. + + Args: + self (SparkConnectClient): The SparkConnectClient we are processing. + plan (pb2.Plan): The plan object to be executed by spark. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result has been truncated. + """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + import pyspark.sql.connect.proto as pb2 + import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 + + # Add the request options + if type == 'json': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == 'csv': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == 'arrow': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise ValueError( + f'Only formats json, csv, and arrow are supported. Got invalid type {type}', + ) + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=False, + ), + ) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), + ) + + # Create the iterator + from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + ) + iterator = ExecutePlanResponseReattachableIterator( + req, + self._stub, + self._retry_policy, + self._builder.metadata(), + ) + # Iterate over the response + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField('extension') and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR, + ): + batch = cloud_pb2.CloudResultBatch() + if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): + raise ValueError( + 'Response extension is not of type CloudResultBatch.', + ) + response.extension.Unpack(batch) + result += [ + Result( + b.url, + b.row_count, + b.compressed_size, + b.uncompressed_size, + ) for b in batch.results + ] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +if spark_connect_client_installed: + SparkConnectClient.to_cf = to_cf # pyright: ignore + + +def collect_as_cf(self: 'DataFrame', + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Collects DataFrame execution plan as presigned URLs. + + This method is a wrapper around the `to_cf` method of SparkConnectClient. 
It takes the + execution plan of the current DataFrame, converts it to a protocol buffer format, and then + uses the `to_cf` method to execute the plan and fetch results as presigned URLs. + + Args: + self (pd.DataFrame): The dataframe we are processing. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. + """ + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore + + +if data_frame_installed: + DataFrame.collect_cf = collect_as_cf # pyright: ignore + + +def iterative_combine_jsons(json_directory: str, output_file: str) -> None: + """Combine jsonl files in json_directory into one big jsonl file. + + This function does not work for nested subdirectories. + + Args: + json_directory(str): directory containing the JSONL files + output_file(str): path to the output combined JSONL file + """ + json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] + with open(output_file, 'w') as outfile: + for file_name in json_files: + with open(os.path.join(json_directory, file_name), 'r') as infile: + for line in infile: + outfile.write(line) + log.info('JSON files have been combined into a JSONL file.') + + +def run_query( + query: str, + method: str, + cursor: Optional['Cursor'] = None, + spark: Optional['SparkSession'] = None, + collect: bool = True, +) -> Optional[Union[List['Row'], 'DataFrame', 'SparkDataFrame']]: + """Run SQL query via databricks-connect or databricks-sql. + + Args: + query (str): sql query + method (str): select from dbsql and dbconnect + cursor (Optional[Cursor]): connection.cursor + spark (Optional[SparkSession]): spark session + collect (bool): whether to get the underlying data from spark dataframe + """ + if method == 'dbsql': + if cursor is None: + raise ValueError(f'cursor cannot be None if using method dbsql') + cursor.execute(query) + if collect: + return cursor.fetchall() + elif method == 'dbconnect': + if spark == None: + raise ValueError(f'sparkSession is required for dbconnect') + df = spark.sql(query) + if collect: + return df.collect() + return df + else: + raise ValueError(f'Unrecognized method: {method}') + + +def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: + for i, r in enumerate(signed): + yield (i, r.url, json_output_folder, columns) + + +def download( + ipart: int, + url: str, + json_output_folder: str, + columns: Optional[List] = None, + resp_format: str = 'arrow', + compressed: bool = False, +) -> None: + """Thread download presigned url and save to jsonl locally. + + Args: + ipart (int): presigned url id + url (str): presigned url + json_output_folder (str): directory to save the ipart_th segment of dataframe + columns (list): schema to save to json + resp_format (str): whether to use arrow or json when collect + compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. 
+ """ + resp = requests.get(url) + if resp.status_code == 200: + if resp_format == 'json': + data = resp.json() + pd.DataFrame(data, columns=columns).to_json( + os.path.join( + json_output_folder, + 'part_' + str(ipart) + '.jsonl', + ), + orient='records', + lines=True, + ) + return + + # When resp_format is arrow: + if compressed: + # The data is lz4 compressed arrow format. + # Decompress the data + import lz4.frame + decompressed_data = lz4.frame.decompress(resp.content) + # Convert the decompressed data into a PyArrow table + reader = pa.ipc.open_stream(decompressed_data) + else: + reader = pa.ipc.open_stream(resp.content) + table = reader.read_all() + + # Convert the PyArrow table into a pandas DataFrame + df = table.to_pandas() + df.to_json( + os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True, + force_ascii=False, + ) + + +def download_starargs(args: Tuple) -> None: + return download(*args) + + +def format_tablename(table_name: str) -> str: + """Escape catalog, schema and table names with backticks. + + This needs to be done when running SQL queries/setting spark sessions to prevent invalid identifier errors. + + Args: + table_name (str): catalog.scheme.tablename on UC + """ + match = re.match(TABLENAME_PATTERN, table_name) + + if match is None: + return table_name + + formatted_identifiers = [] + for i in range(1, 4): + identifier = f'`{match.group(i)}`' + formatted_identifiers.append(identifier) + + return '.'.join(formatted_identifiers) + + +def fetch_data( + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], + start: int, + end: int, + order_by: str, + tablename: str, + columns_str: str, + json_output_folder: str, +) -> None: + """Fetches a specified range of rows from a given table to a json file. + + This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, + from a specified table and column set. The fetched data is then exported as a JSON file. + + Args: + method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. + cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. + sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. + start (int): The starting index for row fetching. + end (int): The ending index for row fetching. + order_by (str): The column name to use for ordering the rows. + tablename (str): The name of the table from which to fetch the data. + columns_str (str): The string representation of the columns to select from the table. + json_output_folder (str): The file path where the resulting JSON file will be saved. + + Returns: + None: The function doesn't return any value, but writes the result to a JSONL file. 
+ """ + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT {columns_str} + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + + if method == 'dbconnect': + spark_df = run_query(query, method, cursor, sparkSession, collect=False) + if spark_df is None: + raise RuntimeError( + f'Expect spark dataframe with {query} but got None', + ) + pdf = spark_df.toPandas() # pyright: ignore + else: # method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + if ans is None: + raise RuntimeError(f'Got empty results with {query}') + records = [r.asDict() for r in ans] # pyright: ignore + pdf = pd.DataFrame.from_dict(records) + + pdf.to_json( + os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True, + ) + + +@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) +def get_total_rows( + tablename: str, + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], +): + ans = run_query( + f'SELECT COUNT(*) FROM {tablename}', + method, + cursor, + sparkSession, + ) + nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore + log.info(f'total_rows = {nrows}') + return nrows + + +@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) +def get_columns_info( + tablename: str, + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], +): + ans = run_query( + f'SHOW COLUMNS IN {tablename}', + method, + cursor, + sparkSession, + ) + columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore + order_by = columns[0] + columns_str = ','.join(columns) + log.info(f'order by column {order_by}') + return columns, order_by, columns_str + + +def fetch( + method: str, + tablename: str, + json_output_folder: str, + batch_size: int = 1 << 30, + processes: int = 1, + sparkSession: Optional['SparkSession'] = None, + dbsql: Optional['Connection'] = None, +) -> None: + """Fetch UC delta table with databricks-connect as JSONL. + + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + json_output_folder (str): path to write the result json file to + batch_size (int): number of rows that dbsql fetches each time to avoid OOM + processes (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session + """ + cursor = dbsql.cursor() if dbsql is not None else None + try: + nrows = get_total_rows( + tablename, + method, + cursor, + sparkSession, + ) + except Exception as e: + raise RuntimeError( + f'Error in get rows from {tablename}. Restart sparkSession and try again', + ) from e + + try: + columns, order_by, columns_str = get_columns_info( + tablename, + method, + cursor, + sparkSession, + ) + except Exception as e: + raise RuntimeError( + f'Error in get columns from {tablename}. Restart sparkSession and try again', + ) from e + + if method == 'dbconnect' and sparkSession is not None: + log.info(f'{processes=}') + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_folder, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. 
+ sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data( + method, + cursor, + sparkSession, + start, + end, + order_by, + tablename, + columns_str, + json_output_folder, + ) + + if cursor is not None: + cursor.close() + + +def validate_and_get_cluster_info( + cluster_id: Optional[str], + databricks_host: str, + databricks_token: str, + http_path: Optional[str], + use_serverless: bool = False, +) -> tuple: + """Validate and get cluster info for running the Delta to JSONL conversion. + + Args: + cluster_id (str): cluster id to validate and fetch additional info for + databricks_host (str): databricks host name + databricks_token (str): databricks auth token + http_path (Optional[str]): http path to use for sql connect + use_serverless (bool): whether to use serverless or not + """ + method = 'dbsql' + dbsql = None + sparkSession = None + + if use_serverless: + method = 'dbconnect' + else: + if not cluster_id: + raise ValueError( + 'cluster_id is not set, however use_serverless is False', + ) + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + res = w.clusters.get(cluster_id=cluster_id) + if res is None: + raise ClusterDoesNotExistError(cluster_id) + + assert res.spark_version is not None + stripped_runtime = re.sub( + r'[a-zA-Z]', + '', + res.spark_version.split('-scala') + [0].replace( # type: ignore + 'x-snapshot', '', + ), + ) + runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) + if version.parse( + runtime_version, + ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', + ) + + if http_path is None and version.parse( + runtime_version, + ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): + method = 'dbconnect' + + if method == 'dbconnect': + from databricks.connect import DatabricksSession + try: + if use_serverless: + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host( + databricks_host, + ).token( + databricks_token, + ).header('x-databricks-session-id', session_id).getOrCreate() + + else: + if not cluster_id: + raise ValueError('cluster_id is needed for dbconnect.',) + sparkSession = DatabricksSession.builder.remote( + host=databricks_host, + token=databricks_token, + cluster_id=cluster_id, + ).getOrCreate() + + except Exception as e: + raise FailedToConnectToDatabricksError() from e + else: + try: + from databricks import sql + dbsql = sql.connect( + server_hostname=re.compile(r'^https?://').sub( + '', databricks_host).strip( + ), # sqlconnect hangs if hostname starts with https + http_path=http_path, + access_token=databricks_token, + ) + except Exception as e: + raise FailedToCreateSQLConnectionError() from e + return method, dbsql, sparkSession + + +def fetch_DT( + delta_table_name: str, + json_output_folder: str, + http_path: Optional[str], + cluster_id: Optional[str], + use_serverless: bool, + DATABRICKS_HOST: str, + DATABRICKS_TOKEN: str, + batch_size: int = 1 << 30, + processes: int = os.cpu_count(), # type: ignore + json_output_filename: str = 'train-00000-of-00001.jsonl', +) -> None: + """Fetch UC Delta Table to local as jsonl.""" + log.info(f'Start .... 
Convert delta to json') + + obj = urllib.parse.urlparse(json_output_folder) + if obj.scheme != '': + raise ValueError( + 'Check the json_output_folder and verify it is a local path!', + ) + + if os.path.exists(json_output_folder): + if not os.path.isdir(json_output_folder) or os.listdir( + json_output_folder, + ): + raise RuntimeError( + f'Output folder {json_output_folder} already exists and is not empty. Please remove it and retry.', + ) + + os.makedirs(json_output_folder, exist_ok=True) + + if not json_output_filename.endswith('.jsonl'): + raise ValueError('json_output_filename needs to be a jsonl file') + + log.info(f'Directory {json_output_folder} created.') + + # validate_and_get_cluster_info allows cluster_id to be None if use_serverless is True + method, dbsql, sparkSession = validate_and_get_cluster_info( + cluster_id=cluster_id, + databricks_host=DATABRICKS_HOST, + databricks_token=DATABRICKS_TOKEN, + http_path=http_path, + use_serverless=use_serverless, + ) + + formatted_delta_table_name = format_tablename(delta_table_name) + + fetch( + method, + formatted_delta_table_name, + json_output_folder, + batch_size, + processes, + sparkSession, + dbsql, + ) + + if dbsql is not None: + dbsql.close() + + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + json_output_folder, + os.path.join(json_output_folder, json_output_filename), + ) + + +def _check_imports(): + try: + import lz4.frame + _ = lz4.frame + except ImportError as e: + raise ImportError('lz4 is not installed.') from e + + try: + from databricks.connect import DatabricksSession + _ = DatabricksSession + except ImportError as e: + raise ImportError( + 'databricks-connect is not installed or improperly configured.', + ) from e + + try: + from databricks import sql + from databricks.sdk import WorkspaceClient + from databricks.sql.client import Connection as Connection + from databricks.sql.client import Cursor as Cursor + _ = WorkspaceClient, Connection, Cursor, sql + except ImportError as e: + raise ImportError( + 'databricks-sdk is not installed or improperly configured.', + ) from e + + try: + import pyspark.sql.connect.proto as pb2 + import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 + from pyspark.sql import SparkSession + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + ) + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.dataframe import DataFrame as SparkDataFrame + from pyspark.sql.types import Row + _ = ( + pb2, + cloud_pb2, + SparkSession, + SparkConnectClient, + ExecutePlanResponseReattachableIterator, + DataFrame, + SparkDataFrame, + Row, + ) + except ImportError as e: + raise ImportError( + 'pyspark is not installed or improperly configured.', + ) from e + + +def convert_delta_to_json_from_args( + delta_table_name: str, + json_output_folder: str, + http_path: Optional[str], + cluster_id: Optional[str], + use_serverless: bool, + batch_size: int, + processes: int, + json_output_filename: str, +) -> None: + """A wrapper for `convert_dataset_json` that parses arguments. + + Args: + delta_table_name (str): UC table ..
+ json_output_folder (str): Local path to save the converted json + http_path (Optional[str]): If set, dbsql method is used + batch_size (int): Row chunks to transmit a time to avoid OOM + processes (int): Number of processes allowed to use + cluster_id (Optional[str]): Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect. + use_serverless (bool): Use serverless or not. Make sure the workspace is entitled with serverless + json_output_filename (str): The name of the combined final jsonl that combines all partitioned jsonl + """ + _check_imports() + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + DATABRICKS_HOST = w.config.host + DATABRICKS_TOKEN = w.config.token + + tik = time.time() + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + batch_size=batch_size, + processes=processes, + cluster_id=cluster_id, + use_serverless=use_serverless, + json_output_filename=json_output_filename, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + ) + log.info(f'Elapsed time {time.time() - tik}') diff --git a/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py new file mode 100644 index 0000000000..94cd79815b --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py @@ -0,0 +1,346 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import platform +import warnings +from typing import Any, Callable, Dict, Iterable, Optional, Union + +import datasets as hf_datasets +import psutil +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict +from streaming import MDSWriter +from torch.utils.data import DataLoader +from tqdm import tqdm + +from llmfoundry.data.finetuning.collator import validate_target_settings +from llmfoundry.data.finetuning.tasks import ( + _get_example_type, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example, +) +from llmfoundry.utils.builders import build_tokenizer + +HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] + + +def build_dataloader( + dataset: HFDataset, + batch_size: int, + num_workers: Optional[int] = None, +) -> DataLoader: + if num_workers is None: + # Multiple workers is only supported on linux machines + if 'linux' in platform.platform().lower(): + num_workers = max(1, psutil.cpu_count()) + else: + num_workers = 0 + + # If using multiple workers, configure each worker to prefetch as many samples as it can, up to + # the aggregate device batch size + # If not using workers, the torch DataLoader expects the default value for prefetch_factor, + # which non-intuitively must be 2. + # If on macOS, PyTorch requires prefetch_factor set to None since num_workers is always zero + if 'macos' in platform.platform().lower() and num_workers == 0: + prefetch_factor = None + else: + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 + + return DataLoader( + dataset=dataset, + sampler=None, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + +def generate_samples( + loader: DataLoader, + truncate_num_samples: Optional[int] = None, +) -> Iterable[Dict[str, bytes]]: + """Generator over samples of a dataloader. 
+ + Args: + loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} + truncate_num_samples (Optional[int]): An optional # of samples to stop at. + + Yields: + Sample dicts. + """ + n_samples = 0 + for batch in loader: + keys = list(batch.keys()) + current_bs = len(batch[keys[0]]) + for idx in range(current_bs): + if truncate_num_samples is not None and n_samples == truncate_num_samples: + return + n_samples += 1 + yield {k: v[idx] for k, v in batch.items()} + + +def get_columns_and_format( + dataset: HFDataset, + tokenizing: bool, + preprocessing_fn: Callable, +): + ex = preprocessing_fn(next(iter(dataset))) + example_type = _get_example_type(ex) + if tokenizing: + return {'turns': 'json'}, example_type + if example_type == 'chat': + # Chat format + return {'messages': 'json'}, example_type + else: + # Prompt-response format + return {'prompt': 'str', 'response': 'str'}, example_type + + +def convert_finetuning_dataset( + dataset: str, + data_subset: Optional[str], + splits: list[str], + preprocessor: Optional[str], + data_files: list[str], + skip_preprocessing: bool, + out_root: str, + local: Optional[str], + compression: Optional[str], + num_workers: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: dict[str, Any], + max_seq_len: int, + target_prompts: str, + target_responses: str, + encoder_decoder: bool, +) -> None: + """Converts Finetuning datasets to MDS format. + + Args: + dataset (str): Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`). + data_subset (Optional[str]): Subset of data to use. + splits (list[str]): Comma-separated list of dataset splits + preprocessor (Optional[str]): Name or import path of function used to preprocess (reformat) the dataset. + data_files (list[str]): Data file for each split. Comma-separated. + skip_preprocessing (bool): Whether to skip preprocessing. + out_root (str): Root path of output directory where MDS shards will be stored. Can be a remote URI. + local (Optional[str]): Root path of local directory if you want to keep a local copy when out_root is remote. + compression (Optional[str]): Name of compression algorithm to use. + num_workers (Optional[int]): Number of workers. + tokenizer (Optional[str]): Tokenizer used for processing. + tokenizer_kwargs (dict[str, Any]): Keyword arguments for tokenizer initialization. + max_seq_len (int): Maximum sequence length. + target_prompts (str): Policy for when to use prompts as training targets. + target_responses (str): Policy for which responses to treat as training targets. + encoder_decoder (bool): Set if the data are intended to be used to train an encoder-decoder model + + Raises: + ValueError: If the target settings are invalid. + """ + if skip_preprocessing: + preprocessing_fn = lambda x: x # Just an identity function + else: + preprocessor_str = preprocessor + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( + preprocessor=preprocessor_str, + dataset_name=dataset, + ) + if preprocessing_fn is None: + raise ValueError( + '`preprocessor` was not set and no preprocessing function ' +\ + 'has been registered for `dataset`. 
If this was intentional ' +\
+                '(e.g., because your dataset is already correctly formatted), ' +\
+                'include the "--skip-preprocessing" flag to avoid this error.',
+            )
+
+    # Make sure the target settings are valid
+    validate_target_settings(
+        target_prompts=target_prompts,
+        target_responses=target_responses,
+        decoder_only_format=not encoder_decoder,
+    )
+
+    tokenizer_name = tokenizer
+    tokenizer = None
+    tokenizer_kwargs.update({'model_max_length': max_seq_len})
+    if tokenizer_name:
+        tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
+
+    for i, split_name in enumerate(splits):
+        data_file = None
+        if len(data_files) > 0:
+            data_file = data_files[i]
+        loaded_dataset = hf_datasets.load_dataset(
+            path=dataset,
+            name=data_subset,
+            split=split_name,
+            data_files=data_file,
+            streaming=True,
+        )
+        # Determine the output columns
+        columns, example_type = get_columns_and_format(
+            dataset=loaded_dataset,
+            tokenizing=tokenizer is not None,
+            preprocessing_fn=preprocessing_fn,
+        )
+        # Prepare the iterables
+        if example_type == 'chat':
+            samples = iter(loaded_dataset)
+        else:
+            loader = build_dataloader(
+                dataset=loaded_dataset,
+                batch_size=512,
+                num_workers=num_workers,
+            )
+            samples = generate_samples(loader)
+
+        # Write samples
+        print(f'Converting {split_name} to MDS format...')
+        out = os.path.join(out_root, split_name)
+        if local is not None:
+            out = (os.path.join(local, split_name), out)
+            keep_local = True
+        else:
+            keep_local = False
+        with MDSWriter(
+            columns=columns,
+            out=out,
+            compression=compression,
+            keep_local=keep_local,
+        ) as out:
+            examples_removed = 0
+            for sample in tqdm(samples, desc=split_name):
+                formatted_sample = preprocessing_fn(sample)
+                assert isinstance(formatted_sample, dict)
+
+                # Use the _get_example_type utility to confirm that the formatted sample
+                # can be interpreted by the tokenization code
+                try:
+                    example_type = _get_example_type(formatted_sample)
+                except Exception as e:
+                    raise ValueError(
+                        'Encountered an error when checking example for proper formatting. 
' +\ + f'example={formatted_sample}', + ) from e + if tokenizer is not None: + sample = tokenize_formatted_example( + formatted_sample, + tokenizer=tokenizer, + ) + if not is_valid_ift_example( + max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + decoder_only_format=not encoder_decoder, + example=sample, + ): + examples_removed += 1 + continue + + sample_to_write = {'turns': []} + for turn in sample['turns']: + turn_to_write = {} + for key in ['input_ids', 'labels']: + turn_to_write[key] = list(turn[key]) + sample_to_write['turns'].append(turn_to_write) + out.write(sample_to_write) + else: + if example_type == 'prompt_response': + encoded_sample = {} + for key in ['prompt', 'response']: + value = formatted_sample[key] + assert isinstance(value, str) + encoded_sample[key] = value.encode('utf-8') + out.write(encoded_sample) + else: + out.write(formatted_sample) + + if tokenizer is not None and examples_removed > 0: + warnings.warn( + f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + + + 'the prompt or response was empty, or the response was all padding tokens.', + ) + + +def convert_finetuning_dataset_from_args( + dataset: str, + data_subset: Optional[str], + splits: list[str], + preprocessor: Optional[str], + data_files: list[str], + skip_preprocessing: bool, + out_root: str, + local: Optional[str], + compression: Optional[str], + num_workers: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: Optional[str], + max_seq_len: int, + target_prompts: str, + target_responses: str, + encoder_decoder: bool, +): + """A wrapper for `convert_finetuning_dataset` to parse arguments. + + Args: + dataset (str): Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`). + data_subset (Optional[str]): Subset of data to use. + splits (list[str]): Comma-separated list of dataset splits + preprocessor (Optional[str]): Name or import path of function used to preprocess (reformat) the dataset. + data_files (list[str]): Data file for each split. Comma-separated. + skip_preprocessing (bool): Whether to skip preprocessing. + out_root (str): Root path of output directory where MDS shards will be stored. Can be a remote URI. + local (Optional[str]): Root path of local directory if you want to keep a local copy when out_root is remote. + compression (Optional[str]): Name of compression algorithm to use. + num_workers (Optional[int]): Number of workers. + tokenizer (Optional[str]): Tokenizer used for processing. + tokenizer_kwargs (Optional[str]): Keyword arguments for tokenizer initialization in JSON format. + max_seq_len (int): Maximum sequence length. + target_prompts (str): Policy for when to use prompts as training targets. + target_responses (str): Policy for which responses to treat as training targets. + encoder_decoder (bool): Set if the data are intended to be used to train an encoder-decoder model. + + Raises: + ValueError: If the target settings are invalid. + ValueError: If the output directory already contains the requested splits. 
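A hedged example of invoking the wrapper documented above (not part of the patch): the paths, compression choice, and target policies below are placeholders rather than values mandated by the code.

```python
from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import (
    convert_finetuning_dataset_from_args,
)

# Hypothetical paths and values, purely for illustration.
convert_finetuning_dataset_from_args(
    dataset='json',                          # jsonl on disk is loaded via the 'json' builder
    data_subset=None,
    splits=['train'],
    preprocessor=None,
    data_files=['/path/to/train.jsonl'],
    skip_preprocessing=True,
    out_root='/tmp/mds-out',
    local=None,
    compression='zstd',
    num_workers=None,
    tokenizer=None,                          # keep prompt/response as raw text
    tokenizer_kwargs=None,                   # parsed from JSON when provided
    max_seq_len=2048,
    target_prompts='none',
    target_responses='last',
    encoder_decoder=False,
)
```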
+ """ + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(splits)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {splits}.', + ) + + if tokenizer_kwargs is not None: + parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs) + else: + parsed_tokenizer_kwargs = {} + + if len(data_files) > 0 and len(data_files,) != len(splits): + raise ValueError( + f'If data_files is set, data_files and splits must have the same length. Got {len(data_files)=} while {len(splits)=}', + ) + convert_finetuning_dataset( + dataset=dataset, + data_subset=data_subset, + splits=splits, + preprocessor=preprocessor, + data_files=data_files, + skip_preprocessing=skip_preprocessing, + out_root=out_root, + local=local, + compression=compression, + num_workers=num_workers, + tokenizer=tokenizer, + tokenizer_kwargs=parsed_tokenizer_kwargs, + max_seq_len=max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + encoder_decoder=encoder_decoder, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py new file mode 100644 index 0000000000..94bdc16526 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -0,0 +1,596 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import math +import os +import tempfile +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from glob import glob +from typing import Dict, Iterable, List, Optional, Tuple, cast + +import numpy as np +from composer.utils import ( + ObjectStore, + maybe_create_object_store_from_uri, + parse_uri, +) +from numpy.typing import NDArray +from streaming import MDSWriter +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.data.data import AbstractConcatTokensDataset +from llmfoundry.utils.data_prep_utils import ( + DownloadingIterable, + download_file, + merge_shard_groups, +) +from llmfoundry.utils.exceptions import ( + DatasetTooSmallError, + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) + +log = logging.getLogger(__name__) + +DONE_FILENAME = '.text_to_mds_conversion_done' + + +class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): + """An IterableDataset that returns token samples for MDSWriter from files. + + Returns dicts of {'tokens': ndarray:int32} + + Each file is considered a sequence. 
+ """ + + def __init__( + self, + files: Iterable[str], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + bos_text: str, + eos_text: str, + no_wrap: bool, + ): + self.files = files + super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) + log.info(f'Initialized ConcatTokensFromFilesDataset.') + + def __iter__(self) -> Iterable[Dict[str, NDArray]]: + log.info( + 'Starting iteration over files in ConcatTokensFromFilesDataset', + ) + buffer = [] + for file in self.files: + log.info(f'Processing file: {file}') + with open(file, 'r') as f: + buffer += self.bos_tokens + first_chunk = True + # Read the file in 1MB chunks to avoid memory issues + for chunk in iter(partial(f.read, 1000000), ''): + # Tokenize the chunk + encoded = self.tokenizer( + chunk, + truncation=False, + padding=False, + ) + iids = encoded['input_ids'] + + # If this is not the first chunk, remove the BOS token + if not first_chunk: + if iids[0] == self.tokenizer.bos_token_id: + iids = iids[1:] + + # Add the tokens to the buffer + buffer += iids + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self. + max_length:] if self.should_wrap else [] + yield { + 'tokens': np.asarray(concat_sample, dtype=np.int32), + } + + first_chunk = False + + # Add the EOS token to the buffer to separate files. + buffer += self.eos_tokens + + # Yield any remaining samples of size max_length. + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self.max_length:] if self.should_wrap else [] + yield {'tokens': np.asarray(concat_sample, dtype=np.int32)} + + log.info( + 'Finished iterating over files in ConcatTokensFromFilesDataset', + ) + + +def get_object_names(input_folder: str) -> List[str]: + """Get object names from a local or remote folder. + + Args: + input_folder (str): local or remote folder path. + """ + object_store = maybe_create_object_store_from_uri(input_folder) + if object_store is not None: + _, _, folder_prefix = parse_uri(input_folder) + names = [ + name for name in object_store.list_objects(folder_prefix) + if name.endswith('.txt') + ] + log.info(f'Found {len(names)} text files in remote storage') + else: + # input_folder is a local folder + names = [ + text_file for dirpath, _, _ in os.walk(input_folder) + for text_file in glob(os.path.join(dirpath, '*.txt')) + ] + # return names, sizes + log.info(f'Found {len(names)} text files at {input_folder}') + + return names + + +def get_task_args( + object_names: List[str], + output_root: str, + input_folder: str, + n_groups: int, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +) -> Iterable: + """Get download_and_convert arguments split across n_groups. + + Each group handles a portion of object_names. 
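The buffer management in `__iter__` above reduces to a fixed-length chunking loop: token ids accumulate until `max_length` is reached, a sample is emitted, and the leftover is either kept (wrapping) or dropped. A standalone sketch of just that logic (not part of the patch):

```python
import numpy as np


def chunk_tokens(token_ids, max_length, should_wrap=True):
    """Emit int32 samples of exactly max_length tokens, mirroring the buffer logic above."""
    buffer = []
    buffer += token_ids
    while len(buffer) >= max_length:
        concat_sample = buffer[:max_length]
        buffer = buffer[max_length:] if should_wrap else []
        yield {'tokens': np.asarray(concat_sample, dtype=np.int32)}


samples = list(chunk_tokens(list(range(10)), max_length=4))
# Two full samples are produced; the trailing two tokens are not enough for a third.
assert [s['tokens'].tolist() for s in samples] == [[0, 1, 2, 3], [4, 5, 6, 7]]
```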
+ + Args: + object_names (List[str]): Names of objects to process + output_root (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + n_groups (int): Number of groups to split the object names into + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info( + f'Preparing task arguments for {len(object_names)} objects across {n_groups} groups', + ) + num_objects = len(object_names) + objs_per_group = math.ceil(num_objects / n_groups) + for group, i in enumerate(range(0, num_objects, objs_per_group)): + output_subdir = os.path.join(output_root, str(group)) + log.info( + f'Created task for group {group} with {min(objs_per_group, num_objects - i)} objects', + ) + yield ( + object_names[i:min(i + objs_per_group, num_objects)], + output_subdir, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + +def download_and_convert_starargs(args: Tuple): + """Helper function to call download_and_convert with star args. + + This helps us use download_and_convert with multiprocessing. + """ + return download_and_convert(*args) + + +def download_and_convert( + file_names: List[str], + output_folder: str, + input_folder: str, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +): + """Downloads and converts text files to MDS format. 
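The group splitting in `get_task_args` is a ceiling division over the object list. A small sketch of the arithmetic (not part of the patch), using only the standard library:

```python
import math


def split_into_groups(object_names: list, n_groups: int) -> list:
    """Split object_names into at most n_groups contiguous chunks, as get_task_args does."""
    objs_per_group = math.ceil(len(object_names) / n_groups)
    return [
        object_names[i:i + objs_per_group]
        for i in range(0, len(object_names), objs_per_group)
    ]


assert split_into_groups(['a', 'b', 'c', 'd', 'e'], n_groups=2) == [['a', 'b', 'c'], ['d', 'e']]
```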
+ + Args: + file_names (List[str]): Files to process + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info(f'Starting download and conversion for {len(file_names)} files') + + object_store = maybe_create_object_store_from_uri(input_folder) + + # Download file_names + with tempfile.TemporaryDirectory() as tmp_dir: + log.info(f'Created temporary directory: {tmp_dir}') + downloading_iter = DownloadingIterable( + object_names=file_names, + output_folder=tmp_dir, + object_store=object_store, + ) + log.info(f'Initializing tokenizer: {tokenizer_name}') + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace + + # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up + # to the maximum sequence length + dataset = ConcatTokensFromFilesDataset( + files=downloading_iter, + max_length=concat_tokens, + tokenizer=tokenizer, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + ) + + columns = {'tokens': 'ndarray:int32'} + + log.info('Converting to MDS format...') + with MDSWriter( + out=output_folder, + columns=columns, + compression=compression, + ) as out: + for sample in tqdm(dataset): + out.write(sample) + + log.info(f'Completed download and conversion for {len(file_names)} files') + + +def is_remote_path(path: str) -> bool: + """Checks whether a path is a remote path. + + Args: + path (str): path to check + """ + backend, _, _ = parse_uri(path) + return backend != '' + + +def is_already_processed( + output_root: str, + args_str: str, + object_names: List[str], +) -> bool: + """Determines whether a group of text files has already been processed. + + Checks the done fie at output root to determine this. 
+ + Args: + output_root (str): Output folder where a done file may exist + args_str (str): String representation of the arguments + object_names (List[str]): Names of objects to convert to MDS format + """ + log.info( + f'Checking if {len(object_names)} objects have already been processed in {output_root}', + ) + + # Retrieve the done file contents + output_object_store = maybe_create_object_store_from_uri(output_root) + if output_object_store is not None: + # Download and read the done file from the remote object store + _, _, output_folder_prefix = parse_uri(output_root) + try: + with tempfile.TemporaryDirectory() as tmp_dir: + done_file = os.path.join(tmp_dir, DONE_FILENAME) + download_file( + object_store=output_object_store, + object_name=os.path.join( + output_folder_prefix, + DONE_FILENAME, + ), + output_filename=done_file, + ) + with open(done_file) as df: + done_file_contents = df.read().splitlines() + log.info(f'Retrieved done file contents from remote storage') + except FileNotFoundError: + log.info('Done file not found in remote storage') + return False + else: + # Read the local done file + done_file = os.path.join(output_root, DONE_FILENAME) + if not os.path.isfile(done_file): + log.info('Done file not found in local storage') + return False + with open(done_file) as df: + done_file_contents = df.read().splitlines() + log.info(f'Retrieved done file contents from local storage') + + # Compare the arguments + prev_args_str = done_file_contents[0] + if prev_args_str != args_str: + log.info('Arguments have changed, reprocessing required') + return False + + # Compare file names + prev_names = done_file_contents[1:] + if len(prev_names) != len(object_names): + log.info('Number of files has changed, reprocessing required') + return False + for idx, prev_name in enumerate(prev_names): + if object_names[idx] != prev_name: + log.info('File names have changed, reprocessing required') + return False + + log.info('All files have already been processed') + return True + + +def write_done_file(folder: str, args_str: str, object_names: List[str]): + """Write a file to signify completion. + + This the done file includes the arguments to processing and + a list of objects that were processed. + + Args: + folder (str): Folder to write the done file to + args_str (str): String representation of arguments + object_names (List[str]): List of objects to convert to MDS format + """ + with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file: + log.info(f'Writing done file.') + done_file.write('\n'.join([args_str] + object_names) + '\n') + log.info(f'Done file written successfully') + + +def convert_text_to_mds( + tokenizer_name: str, + output_folder: str, + input_folder: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + processes: int, + args_str: str, + reprocess: bool, + trust_remote_code: bool, +): + """Convert a folder of text files to MDS format. + + Args: + tokenizer_name (str): Name of tokenizer to use + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + processes (int): The number of processes to use. 
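For reference, the done file is simply the argument string followed by one object name per line. A minimal sketch (not part of the patch) of the local-path round trip that `write_done_file` and `is_already_processed` perform:

```python
import os
import tempfile

DONE_FILENAME = '.text_to_mds_conversion_done'


def matches_done_file(folder: str, args_str: str, object_names: list) -> bool:
    """Return True if the done file in `folder` records the same args and object names."""
    done_path = os.path.join(folder, DONE_FILENAME)
    if not os.path.isfile(done_path):
        return False
    with open(done_path) as f:
        contents = f.read().splitlines()
    return contents[0] == args_str and contents[1:] == object_names


with tempfile.TemporaryDirectory() as tmp:
    # Write a done file in the same layout as write_done_file: args first, then object names.
    with open(os.path.join(tmp, DONE_FILENAME), 'w') as f:
        f.write('\n'.join(['--concat_tokens 2048', 'a.txt', 'b.txt']) + '\n')
    assert matches_done_file(tmp, '--concat_tokens 2048', ['a.txt', 'b.txt'])
    assert not matches_done_file(tmp, '--concat_tokens 4096', ['a.txt', 'b.txt'])
```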
+ args_str (str): String representation of the arguments + reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + # Load the tokenizer once on the main process so that the files are cached to avoid race conditions + # in the Hugging Face load code + AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + + is_remote_output = is_remote_path(output_folder) + log.info(f'Output is remote: {is_remote_output}') + + object_names = get_object_names(input_folder) + if len(object_names) == 0: + log.error(f'No text files found in input folder: {input_folder}') + raise InputFolderMissingDataError(input_folder) + + # Check if the text files in the bucket have already been processed. + if not reprocess and is_already_processed( + output_folder, + args_str, + object_names, + ): + log.info( + f'Input folder {input_folder} is already processed at {output_folder} and ' + + + 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', + ) + return + + # Use a temporary local directory if the output is remote and there are more than 1 processes + local_output_folder = tempfile.TemporaryDirectory( + ).name if is_remote_output else output_folder + log.info(f'Using local output folder: {local_output_folder}') + + if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: + log.error(f'Output folder is not empty: {output_folder}') + raise OutputFolderNotEmptyError(output_folder) + + if processes > 1: + log.info(f'Using multiprocessing with {processes} processes') + # Download and convert the text files in parallel + args = get_task_args( + object_names, + local_output_folder, + input_folder, + processes, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_and_convert_starargs, args)) + + log.info('Merging MDS shards from each process') + # Merge the mds shards from each of the processes into a single folder + merge_shard_groups(local_output_folder) + else: + log.info('Using single process for download and conversion') + download_and_convert( + object_names, + local_output_folder, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + index_path = os.path.join(local_output_folder, 'index.json') + with open(index_path, 'r') as index_file: + if not json.load(index_file)['shards']: + raise DatasetTooSmallError() + + # Write a done file with the args and object names + write_done_file(local_output_folder, args_str, object_names) + + if is_remote_output: + # Upload the local output to the remote location + output_object_store = cast( + ObjectStore, + maybe_create_object_store_from_uri(output_folder), + ) + _, _, output_folder_prefix = parse_uri(output_folder) + files_to_upload = os.listdir(local_output_folder) + + for file in files_to_upload: + assert not os.path.isdir(file) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object( + remote_path, + os.path.join(local_output_folder, file), + ) + + +def _configure_logging(logging_level: str): + """Configure logging. + + Args: + logging_level (str): Logging level. 
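A hedged example call to `convert_text_to_mds` as defined above (not part of the patch): the tokenizer name, folders, and EOS text are illustrative placeholders, and `args_str` would normally be assembled by the wrapper that follows.

```python
from llmfoundry.command_utils.data_prep.convert_text_to_mds import convert_text_to_mds

convert_text_to_mds(
    tokenizer_name='EleutherAI/gpt-neox-20b',   # illustrative tokenizer choice
    output_folder='/tmp/mds-text',              # hypothetical output location
    input_folder='/path/to/txt-files',          # hypothetical folder of .txt files
    concat_tokens=2048,
    eos_text='<|endoftext|>',
    bos_text='',
    no_wrap=False,
    compression='zstd',
    processes=1,                                # single-process path, no shard merging
    args_str='<string describing these arguments>',
    reprocess=False,
    trust_remote_code=False,
)
```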
+    """
+    logging.basicConfig(
+        format=
+        f'%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
+    )
+    logging_level = logging_level.upper()
+    logging.getLogger('llmfoundry').setLevel(logging_level)
+    logging.getLogger(__name__).setLevel(logging_level)
+    log.info(f'Logging level set to {logging_level}')
+
+
+def convert_text_to_mds_from_args(
+    output_folder: str,
+    input_folder: str,
+    compression: str,
+    concat_tokens: int,
+    tokenizer_name: str,
+    bos_text: Optional[str],
+    eos_text: Optional[str],
+    use_tokenizer_eos: bool,
+    no_wrap: bool,
+    processes: int,
+    reprocess: bool,
+    trust_remote_code: bool,
+    logging_level: str,
+) -> None:
+    """A wrapper for `convert_text_to_mds` to parse arguments.
+
+    Args:
+        output_folder (str): Folder to write MDS shards to
+        input_folder (str): Folder of text files to process
+        compression (str): The compression algorithm to use for MDS writing
+        concat_tokens (int): Concatenate up to this many tokens
+        tokenizer_name (str): The name of the tokenizer to use
+        bos_text (Optional[str]): The text to prepend to each example to separate concatenated examples
+        eos_text (Optional[str]): The text to append to each example to separate concatenated examples
+        use_tokenizer_eos (bool): Use the EOS text from the tokenizer
+        no_wrap (bool): Whether to let text examples wrap across multiple training examples
+        processes (int): The number of processes to use to download and convert the dataset
+        reprocess (bool): If true, reprocess the input_folder to MDS format. Otherwise, only reprocess upon changes to the input folder or dataset creation parameters.
+        trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
+        logging_level (str): Logging level for the script. Default is INFO.
+
+    Raises:
+        ValueError: If `use_tokenizer_eos` is True and `eos_text` is not None
+    """
+    if use_tokenizer_eos:
+        # Ensure that eos text is not specified twice.
+        if eos_text is not None:
+            raise ValueError(
+                'Cannot set --eos_text with --use_tokenizer_eos. 
Please specify one.', + ) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + eos_text = tokenizer.eos_token + + # now that we have validated them, change BOS/EOS to strings + if bos_text is None: + bos_text = '' + if eos_text is None: + eos_text = '' + _configure_logging(logging_level) + + # Define args for _args_str + args = { + 'tokenizer': tokenizer_name, + 'output_folder': output_folder, + 'input_folder': input_folder, + 'compression': compression, + 'concat_tokens': concat_tokens, + 'eos_text': eos_text, + 'bos_text': bos_text, + 'no_wrap': no_wrap, + 'processes': processes, + 'reprocess': reprocess, + 'trust_remote_code': trust_remote_code, + } + convert_text_to_mds( + tokenizer_name=tokenizer_name, + output_folder=output_folder, + input_folder=input_folder, + concat_tokens=concat_tokens, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + compression=compression, + processes=processes, + reprocess=reprocess, + trust_remote_code=trust_remote_code, + args_str=str(args), + ) diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py new file mode 100644 index 0000000000..bddd592dba --- /dev/null +++ b/llmfoundry/command_utils/eval.py @@ -0,0 +1,496 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import time +from typing import Any, Dict, Optional, Tuple, Union + +import pandas as pd +import torch +from composer.core import Callback +from composer.loggers.logger_destination import LoggerDestination +from composer.trainer import Trainer +from composer.utils import dist, get_device, reproducibility +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from llmfoundry.utils import ( + find_mosaicml_logger, + log_eval_analytics, + maybe_create_mosaicml_logger, +) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_callback, + build_composer_model, + build_evaluators, + build_logger, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + EVAL_CONFIG_KEYS, + EvalConfig, + log_config, + make_dataclass_and_log_config, + process_init_device, +) +from llmfoundry.utils.registry_utils import import_file + +log = logging.getLogger(__name__) + + +def evaluate_model( + tokenizer: Dict[str, Any], + model_name: str, + model: Dict[str, Any], + dist_timeout: Union[float, int], + run_name: str, + seed: int, + icl_tasks: Union[str, list[Dict[str, Any]]], + max_seq_len: int, + device_eval_batch_size: Union[int, float], + eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], + eval_loader_config: Optional[Union[Dict[str, Any], list[Dict[str, Any]]]], + fsdp_config: Optional[Dict[str, Any]], + loggers: list[LoggerDestination], + python_log_level: Optional[str], + precision: str, + eval_gauntlet_df: Optional[pd.DataFrame], + eval_subset_num_batches: int, + icl_subset_num_batches: Optional[int], + callback_configs: Optional[Dict[str, Any]], + metadata: Optional[Dict[str, str]], + logged_config: Dict[str, Any], + should_log_config: bool = True, + load_path: Optional[str] = None, +): + log.info(f'Evaluating model: {model_name}') + # Build tokenizer and model + tokenizer_cfg = tokenizer + tokenizer_name = tokenizer_cfg['name'] + tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) + tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + + evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks, + eval_gauntlet_config, + 
tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=max_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) + + # Callbacks + callbacks: list[Callback] = [ + build_callback(name=str(name), kwargs=callback_cfg) + for name, callback_cfg in callback_configs.items() + ] if callback_configs else [] + + if eval_gauntlet_callback is not None: + callbacks.append(eval_gauntlet_callback) + + if metadata is not None: + # Find the MosaicMLLogger + mosaicml_logger = find_mosaicml_logger(loggers) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(metadata) + mosaicml_logger._flush_metadata(force_flush=True) + + if fsdp_config and model.get('load_in_8bit', False): + raise ValueError( + 'The FSDP config block is not supported when loading ' + + 'Hugging Face models in 8bit.', + ) + + init_context = process_init_device(model, fsdp_config) + + name = model.pop('name') + composer_model = build_composer_model( + name=name, + tokenizer=tokenizer, + init_context=init_context, + cfg=model, + ) + + # Now add the eval metrics + if eval_loader_config is not None: + train_metrics = composer_model.get_metrics(is_train=True) + evaluators = add_metrics_to_eval_loaders( + evaluators, + list(train_metrics.keys()), + ) + + if eval_gauntlet_df is None and eval_gauntlet_callback is not None: + eval_gauntlet_df = pd.DataFrame( + columns=['model_name'] + list(eval_gauntlet_callback.averages) + + [t['name'] for t in eval_gauntlet_callback.categories], + ) + + if name == 'mpt_causal_lm' and load_path is None: + raise ValueError( + 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' + + + ' Please check your yaml and the model_cfg to ensure that load_path is set.', + ) + + assert composer_model is not None + + log.info(f'Building trainer for {model_name}...') + trainer = Trainer( + run_name=run_name, + seed=seed, + model=composer_model, + callbacks=callbacks, + loggers=loggers, + precision=precision, + fsdp_config=fsdp_config, + load_path=load_path, + load_weights_only=True, + progress_bar=False, + log_to_console=True, + dist_timeout=dist_timeout, + python_log_level=python_log_level, + ) + + if should_log_config: + log.info('Evaluation config:') + log_config(logged_config) + + log.info(f'Starting eval for {model_name}...') + if torch.cuda.is_available(): + torch.cuda.synchronize() + a = time.time() + trainer.eval( + eval_dataloader=evaluators, + subset_num_batches=eval_subset_num_batches, + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + b = time.time() + + log.info(f'Ran {model_name} eval in: {b-a} seconds') + return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) + + +def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]: + """Transform the config to allow top-level keys for model configuration. + + This function allows users to use the 'train.py' syntax in 'eval.py'. + It converts a config with top-level 'model', 'tokenizer', and (optionally) 'load_path' keys + into the nested 'models' list format required by 'eval.py'. 
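In plain dictionaries, the transformation this helper performs looks as follows. This is an illustrative sketch (the model and tokenizer names are made up), not part of the patch:

```python
from llmfoundry.command_utils.eval import allow_toplevel_keys

# A train.py-style config with top-level model/tokenizer/load_path keys.
train_style = {
    'model': {'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'my/model'},
    'tokenizer': {'name': 'my/model'},
    'load_path': '/path/to/checkpoint',
}

eval_style = allow_toplevel_keys(dict(train_style))

# The top-level keys are folded into a single-entry 'models' list.
assert eval_style['models'][0]['model_name'] == 'hf_causal_lm'
assert eval_style['models'][0]['load_path'] == '/path/to/checkpoint'
assert 'model' not in eval_style
```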
+ + Input config format (train.py style): + ```yaml + model: + + load_path: /path/to/checkpoint + tokenizer: + + ``` + + Output config format (eval.py style): + ```yaml + models: + - model: + + tokenizer: + + load_path: /path/to/checkpoint + ``` + """ + if 'model' in cfg: + if 'models' in cfg: + raise ValueError( + 'Please specify either model or models in the config, not both', + ) + default_name = cfg.get('model').get('name') # type: ignore + model_cfg = { + 'model': cfg.pop('model'), + 'tokenizer': cfg.pop('tokenizer', None), + 'model_name': cfg.pop('model_name', default_name), + } + if 'tokenizer' not in model_cfg or model_cfg['tokenizer'] is None: + raise ValueError( + 'When specifying model, "tokenizer" must be provided in the config', + ) + if 'load_path' in cfg: + model_cfg['load_path'] = cfg.pop('load_path') + cfg['models'] = [model_cfg] + + return cfg + + +def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]: + # Run user provided code if specified + for code_path in cfg.get('code_paths', []): + import_file(code_path) + + logged_cfg, eval_config = make_dataclass_and_log_config( + cfg, + EvalConfig, + EVAL_CONFIG_KEYS, + transforms=[allow_toplevel_keys], + icl_tasks_required=True, + ) + + model_configs = eval_config.models + eval_gauntlet_config = eval_config.eval_gauntlet or eval_config.eval_gauntlet_str + + fsdp_config = eval_config.fsdp_config + + # Mandatory Evaluation Parameters + icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str + if icl_tasks is None: + raise ValueError('icl_tasks must be specified in the config') + + # Optional Evaluation Parameters with default values + eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders + default_run_name: str = os.environ.get('RUN_NAME', 'llm') + run_name = eval_config.run_name if eval_config.run_name else default_run_name + + reproducibility.seed_all(eval_config.seed) + dist.initialize_dist(get_device(None), timeout=eval_config.dist_timeout) + + if eval_config.python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging.getLogger('llmfoundry').setLevel( + eval_config.python_log_level.upper(), + ) + + # Default argument values for evaluate_model + eval_gauntlet_df = None + models_df = None + composite_scores = None + trainers = [] + + # Build loggers + loggers: list[LoggerDestination] = [ + build_logger(name, logger_cfg) + for name, logger_cfg in (eval_config.loggers or {}).items() + ] + + mosaicml_logger = find_mosaicml_logger(loggers) + if mosaicml_logger is None: + mosaicml_logger = maybe_create_mosaicml_logger() + # mosaicml_logger will be None if run isn't on MosaicML platform + if mosaicml_logger is not None: + loggers.append(mosaicml_logger) + + # mosaicml_logger will be None if the run isn't from the MosaicML platform + if mosaicml_logger is not None: + log_eval_analytics( + mosaicml_logger, + model_configs, + icl_tasks, + eval_gauntlet_config, + ) + + for model_cfg in model_configs: + + attn_config = model_cfg['model'].get('attn_config', None) + if attn_config is not None: + seq_parallel_world_size = attn_config.get( + 'seq_parallel_world_size', + None, + ) + if seq_parallel_world_size is not None and seq_parallel_world_size != 1: + raise ValueError( + 'Offline eval does not support sequence parallelism.', + ) + + (trainer, logger_keys, 
eval_gauntlet_callback, + eval_gauntlet_df) = evaluate_model( + dist_timeout=eval_config.dist_timeout, + run_name=run_name, + seed=eval_config.seed, + icl_tasks=icl_tasks, + max_seq_len=eval_config.max_seq_len, + device_eval_batch_size=eval_config.device_eval_batch_size, + eval_gauntlet_config=eval_gauntlet_config, + eval_loader_config=eval_loader_config, + fsdp_config=fsdp_config, + loggers=loggers, + python_log_level=eval_config.python_log_level, + precision=eval_config.precision, + eval_gauntlet_df=eval_gauntlet_df, + callback_configs=eval_config.callbacks, + eval_subset_num_batches=eval_config.eval_subset_num_batches, + icl_subset_num_batches=eval_config.icl_subset_num_batches, + metadata=eval_config.metadata, + logged_config=logged_cfg, + should_log_config=eval_config.log_config, + **model_cfg, + ) + trainers.append(trainer) + + if eval_gauntlet_callback is not None: + composite_scores = eval_gauntlet_callback.eval_after_all( + trainer.state, + trainer.logger, + ) + + benchmark_to_taxonomy = {} + if eval_gauntlet_callback is not None: + for t in eval_gauntlet_callback.categories: + for b in t['benchmarks']: + benchmark_to_taxonomy[b['name']] = t['name'] + + assert 'model_name' in model_cfg, 'model_name must be specified in model config' + model_results = calculate_markdown_results( + logger_keys, + trainer, + benchmark_to_taxonomy, + model_cfg['model_name'], + ) + + if models_df is None: + models_df = model_results + else: + models_df = pd.concat([models_df, model_results], ignore_index=True) + + if eval_gauntlet_df is not None and eval_gauntlet_callback is not None: + assert composite_scores is not None + row = {'model_name': model_cfg['model_name']} + row.update({ + k.split('/')[-1]: v for k, v in composite_scores.items() + }) + eval_gauntlet_df = pd.concat([ + eval_gauntlet_df, + pd.DataFrame([row]), + ], + ignore_index=True) + + print(f'Printing gauntlet results for all models') + + print( + eval_gauntlet_df.sort_values( + list(eval_gauntlet_callback.averages.keys())[0], + ascending=False, + ).to_markdown(index=False), + ) + print(f'Printing complete results for all models') + assert models_df is not None + print(models_df.to_markdown(index=False)) + + trainer.close() + + return trainers, eval_gauntlet_df + + +def calculate_markdown_results( + logger_keys: list[str], + trainer: Trainer, + benchmark_to_taxonomy: Dict[str, str], + model_name: str, +): + results = {} + + for key in logger_keys: + # dl_name is either 2-tuple (benchmark_name, num_fewshot) + # or 3-tuple (benchmark_name, num_fewshot, subcategory) + dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1] + if 'Accuracy' not in metric_name: + continue + + metric = trainer.state.eval_metrics.get('/'.join(dl_name), + {}).get(metric_name, None) + + if metric is None: + continue + if dl_name[1] not in results: + results[dl_name[1]] = {} + + if dl_name[0] not in results[dl_name[1]]: + results[dl_name[1]][dl_name[0]] = {} + + if metric_name not in results[dl_name[1]][dl_name[0]]: + results[dl_name[1]][dl_name[0]][metric_name] = [] + + results[dl_name[1]][dl_name[0]][metric_name].append({ + 'val': metric.compute(), + 'subcat': dl_name[-1] if len(dl_name) == 3 else 'no_subcat', + }) + + df = pd.DataFrame( + columns=[ + 'Category', + 'Benchmark', + 'Subtask', + 'Accuracy', + 'Number few shot', + 'Model', + ], + ) + + for num_shot in results: + for benchmark in results[num_shot]: + for metric in results[num_shot][benchmark]: + subscores = results[num_shot][benchmark][metric] + if len(subscores) == 1: + row = { + 'Category': 
benchmark_to_taxonomy.get(benchmark, ''), + 'Benchmark': benchmark, + 'Subtask': None, + 'Accuracy': subscores[0]['val'], + 'Number few shot': num_shot, + 'Model': model_name, + } + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) + else: + row = { + 'Category': + benchmark_to_taxonomy.get(benchmark, ''), + 'Benchmark': + benchmark, + 'Subtask': + 'Average', + 'Accuracy': + sum(s['val'] for s in subscores) / len(subscores), + 'Number few shot': + num_shot, + 'Model': + model_name, + } + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) + for sub in subscores: + row = { + 'Category': + benchmark_to_taxonomy.get(benchmark, ''), + 'Benchmark': + None, + 'Subtask': + sub['subcat'], + 'Accuracy': + sub['val'], + 'Number few shot': + num_shot, + 'Model': + model_name, + } + df = pd.concat([df, pd.DataFrame([row])], + ignore_index=True) + return df + + +def eval_from_yaml( + yaml_path: str, + args_list: Optional[list[str]], +) -> Tuple[list[Trainer], pd.DataFrame]: + """Run the evaluation with optional overrides from CLI.""" + # Load yaml and CLI arguments. + om.clear_resolver('oc.env') + with open(yaml_path) as f: + yaml_cfg = om.load(f) + if args_list: + cli_cfg = om.from_cli(args_list) + yaml_cfg = om.merge(yaml_cfg, cli_cfg) + assert isinstance(yaml_cfg, DictConfig) + return evaluate(yaml_cfg) diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py new file mode 100644 index 0000000000..c925e6e586 --- /dev/null +++ b/llmfoundry/command_utils/train.py @@ -0,0 +1,605 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import gc +import logging +import os +import time +import warnings +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed +from composer import ComposerModel, Trainer +from composer.core.callback import Callback +from composer.profiler import ( + JSONTraceHandler, + Profiler, + TraceHandler, + cyclic_schedule, +) +from composer.utils import dist, get_device, reproducibility +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from llmfoundry.callbacks import AsyncEval, HuggingFaceCheckpointer +from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.eval.metrics.nlp import InContextLearningMetric +from llmfoundry.layers_registry import ffns_with_megablocks +from llmfoundry.utils import ( + find_mosaicml_logger, + log_train_analytics, + maybe_create_mosaicml_logger, +) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_algorithm, + build_callback, + build_composer_model, + build_evaluators, + build_load_planner, + build_logger, + build_optimizer, + build_save_planner, + build_scheduler, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + TRAIN_CONFIG_KEYS, + TrainConfig, + log_config, + log_dataset_uri, + make_dataclass_and_log_config, + pop_config, + process_init_device, +) +from llmfoundry.utils.exceptions import ( + BaseContextualError, + EvalDataLoaderLocation, + TrainDataLoaderLocation, +) +from llmfoundry.utils.registry_utils import import_file + +log = logging.getLogger(__name__) + + +def validate_config(train_config: TrainConfig): + """Validates compatible model and dataloader selection.""" + # Validate the rest of the config + loaders = [train_config.train_loader] + if train_config.eval_loaders is not None: + for loader in (train_config.eval_loaders or []): # pyright + if 'label' not in loader or loader['label'] is None: + raise ValueError( + 'When specifying 
multiple evaluation datasets, each one must include the \ + `label` attribute.', + ) + loaders.append(loader) + if train_config.eval_loader is not None: + loaders.append(train_config.eval_loader) + for loader in loaders: + if loader['name'] == 'text': + if train_config.model['name'] == 'hf_t5': + raise ValueError( + f'Model type "{train_config.model["name"]}" is not supported when using the "text " ' +\ + f'dataloader. Only finetuning is supported.') + + if train_config.icl_tasks is not None or train_config.icl_tasks_str is not None: + if train_config.model['name'] == 'hf_t5': + raise ValueError( + 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".', + ) + + if ( + train_config.model.get('fc_type', 'torch') != 'te' and + 'te' not in train_config.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') and + 'fp8' in train_config.precision + ): + warnings.warn( + "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " + + + "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision.", + ) + + if ( + train_config.model.get('fc_type', 'torch') == 'te' or 'te' + in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') + ): + fsdp_config = train_config.fsdp_config + act_ckpt = fsdp_config.get( + 'activation_checkpointing', + False, + ) if fsdp_config else False + act_ckpt_reentrant = fsdp_config.get( + 'activation_checkpointing_reentrant', + False, + ) if fsdp_config else False + if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: + warnings.warn( + '`te.Linear` layers do not support activation_checkpointing with ' + + '`activation_checkpointing_reentrant = True`. ' + + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.', + ) + assert train_config.fsdp_config is not None # pyright (this is known because fsdp_config is not None) + train_config.fsdp_config['activation_checkpointing_reentrant' + ] = False + + if train_config.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': + warnings.warn( + '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.', + ) + torch._dynamo.config.suppress_errors = True # type: ignore (third-party) + + if train_config.model.get('load_in_8bit', False): + raise ValueError( + '`load_in_8bit` is only supported for evaluation rather than training.', + ) + + if train_config.model.get('ffn_config', {}).get( + 'ffn_type', + 'mptmlp', + ) in ffns_with_megablocks: + moe_world_size = train_config.model.get('ffn_config', + {}).get('moe_world_size', 1) + use_orig_params = train_config.fsdp_config.get( + 'use_orig_params', + True, + ) if train_config.fsdp_config is not None else True + if moe_world_size > 1 and not use_orig_params: + raise ValueError( + f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.', + ) + + +def _log_num_params(model: ComposerModel, logged_cfg: Dict[str, Any]): + # Log number of parameters + if hasattr(model, 'n_total_params'): + n_params = model.n_total_params + n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. 
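The fallback branch that follows is just two sums over `model.parameters()`. A standalone sketch of that counting (not part of the patch):

```python
import torch.nn as nn


def count_params(model: nn.Module):
    """Total and trainable parameter counts, as logged for plain modules."""
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return n_params, n_trainable


model = nn.Linear(4, 2)            # 4*2 weights + 2 biases = 10 parameters
model.bias.requires_grad = False   # freeze the bias
assert count_params(model) == (10, 8)
```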
+ else: + n_params = sum(p.numel() for p in model.parameters()) + n_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + if hasattr(model, 'n_active_params'): + n_active_params = model.n_active_params + else: + n_active_params = n_params + logged_cfg.update({ + 'n_params': n_params, + 'n_active_params': n_active_params, + 'n_trainable_params': n_trainable_params, + }) + + +def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): + """Initialize distributed and test setup with a barrier. + + Args: + dist_timeout (Union[int, float]): Timeout for initializing the process group + """ + log.debug('Initializing dist with device...') + dist.initialize_dist(get_device(None), timeout=dist_timeout) + log.debug('Testing barrier with device...') + dist.barrier() + log.debug('Barrier test passed with device.') + + +def train(cfg: DictConfig) -> Trainer: + code_paths = cfg.get('code_paths', []) + # Import any user provided code + for code_path in code_paths: + import_file(code_path) + + logged_cfg, train_cfg = make_dataclass_and_log_config( + cfg, + TrainConfig, + TRAIN_CONFIG_KEYS, + transforms='all', + ) + + # Set logging level + if train_cfg.python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging.getLogger('llmfoundry').setLevel( + train_cfg.python_log_level.upper(), + ) # Foundry module + logging.getLogger(__name__).setLevel( + train_cfg.python_log_level.upper(), + ) # Train script + + _initialize_dist_with_barrier(dist_timeout=train_cfg.dist_timeout) + + # Filter deprecation warning from torch internal usage + warnings.filterwarnings( + action='ignore', + category=UserWarning, + message= + 'torch.distributed.*_base is a private function and will be deprecated.*', + ) + + # Check for incompatibilities between the model and data loaders + validate_config(train_cfg) + + cuda_alloc_conf = [] + # Get max split size mb + max_split_size_mb: Optional[int] = train_cfg.max_split_size_mb + if max_split_size_mb is not None: + cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') + + # Expandable segments + if train_cfg.expandable_segments: + cuda_alloc_conf.append('expandable_segments:True') + + if len(cuda_alloc_conf) > 0: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf) + + # Set CUDA lazy loading + # This can save a bit of memory if not all modules are needed + cuda_load_lazy: bool = train_cfg.cuda_load_lazy + if cuda_load_lazy: + os.environ['CUDA_MODULE_LOADING'] = 'LAZY' + + # Set seed first + seed: int = train_cfg.seed + reproducibility.seed_all(seed) + + # Mandatory model training configs + model_config = train_cfg.model + train_loader_config = train_cfg.train_loader + + # Optional fsdp data, fine-tuning, and eval configs + fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config + + if fsdp_config is not None: + if 'load_planner' in fsdp_config: + load_planners = list(fsdp_config['load_planner'].items()) + if len(load_planners) > 1: + raise ValueError( + 'Only one load planner can be specified in the config.', + ) + load_planner_name, load_planner_config = load_planners[0] + fsdp_config['load_planner'] = build_load_planner( + load_planner_name, + **load_planner_config, + ) + + if 'save_planner' in fsdp_config: + save_planners = list(fsdp_config['save_planner'].items()) + if 
len(save_planners) > 1: + raise ValueError( + 'Only one save planner can be specified in the config.', + ) + save_planner_name, save_planner_config = save_planners[0] + fsdp_config['save_planner'] = build_save_planner( + save_planner_name, + **save_planner_config, + ) + + eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders + icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str + eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str + + # Optional parameters will be set to default values if not specified. + default_run_name: str = os.environ.get('RUN_NAME', 'llm') + run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name + is_state_dict_sharded: bool = ( + fsdp_config.get('state_dict_type', 'full') == 'sharded' + ) if fsdp_config else False + save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt' + save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' + + # Enable autoresume from model checkpoints if possible + autoresume_default: bool = False + if logged_cfg.get('run_name', None) is not None \ + and train_cfg.save_folder is not None \ + and not train_cfg.save_overwrite \ + and not train_cfg.save_weights_only: + autoresume_default = True + + if not train_cfg.autoresume and autoresume_default: + log.info( + 'As run_name, save_folder, and save_latest_filename are set, \ + changing autoresume default to True...', + ) + + # Warn if fsdp is enabled but user only has 1 GPU + if dist.get_world_size() == 1 and fsdp_config is not None: + warnings.warn( + 'FSDP is not applicable for single-GPU training. Reverting to DDP.', + ) + fsdp_config = None + + # Initialize context + init_context = process_init_device(model_config, fsdp_config) + logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) + + # Build tokenizer + log.info('Building tokenizer...') + tokenizer_name = train_cfg.tokenizer['name'] + tokenizer_kwargs = train_cfg.tokenizer.get('kwargs', {}) + tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + + # Scheduler + scheduler_name: str = train_cfg.scheduler.pop('name') + scheduler = build_scheduler(scheduler_name, train_cfg.scheduler) + + # Loggers + loggers = [ + build_logger(str(name), logger_cfg) + for name, logger_cfg in train_cfg.loggers.items() + ] if train_cfg.loggers else [] + + mosaicml_logger = find_mosaicml_logger(loggers) + if mosaicml_logger is None: + mosaicml_logger = maybe_create_mosaicml_logger() + if mosaicml_logger is not None: + # mosaicml_logger will be None if run isn't on MosaicML platform + loggers.append(mosaicml_logger) + + if train_cfg.metadata is not None: + # Optionally flatten the metadata for logging + if train_cfg.flatten_metadata: + logged_cfg.pop('metadata', None) + common_keys = set( + logged_cfg.keys(), + ) & set(train_cfg.metadata.keys()) + if len(common_keys) > 0: + raise ValueError( + f'Keys {common_keys} are already present in the config. 
Please rename them in metadata ' + + + 'or set flatten_metadata=False to avoid flattening the metadata in the logged config.', + ) + + logged_cfg.update(train_cfg.metadata, merge=True) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(train_cfg.metadata) + mosaicml_logger._flush_metadata(force_flush=True) + + # Profiling + profiler: Optional[Profiler] = None + profiler_cfg = train_cfg.profiler + if profiler_cfg: + profiler_schedule_cfg: Dict = pop_config( + profiler_cfg, + 'schedule', + must_exist=True, + ) + profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) + # Only support json trace handler + profiler_trace_handlers: List[TraceHandler] = [] + profiler_trace_cfg: Optional[Dict] = pop_config( + profiler_cfg, + 'json_trace_handler', + must_exist=False, + default_value=None, + ) + if profiler_trace_cfg: + profiler_trace_handlers.append( + JSONTraceHandler(**profiler_trace_cfg), + ) + profiler = Profiler( + **profiler_cfg, + trace_handlers=profiler_trace_handlers, + schedule=profiler_schedule, + ) + + callback_configs = train_cfg.callbacks or {} + + # Callbacks + callbacks: List[Callback] = [ + build_callback( + name=str(name), + kwargs=callback_cfg, + train_config=logged_cfg, + ) for name, callback_cfg in callback_configs.items() + ] + + use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) + + algorithm_configs = train_cfg.algorithms or {} + + # Algorithms + algorithms = [ + build_algorithm(str(name), algorithm_cfg) + for name, algorithm_cfg in algorithm_configs.items() + ] + + # Dataloaders + log.info('Building train loader...') + try: + train_loader = build_dataloader( + train_loader_config, + tokenizer, + train_cfg.device_train_batch_size, + ) + except BaseContextualError as e: + e.location = TrainDataLoaderLocation + raise e + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics({'data_validated': time.time()}) + + ## Evaluation + if use_async_eval: + evaluators = [] + if train_cfg.eval_first: + warnings.warn( + 'AsyncEval callback does not support eval_first=True. 
Ignoring.', + ) + train_cfg.eval_first = False + + else: + try: + log.info('Building eval loader...') + eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=train_cfg.device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=train_cfg.icl_subset_num_batches, + ) + if eval_gauntlet_callback is not None: + callbacks.append(eval_gauntlet_callback) + except BaseContextualError as e: + e.location = EvalDataLoaderLocation + raise e + + if mosaicml_logger is not None: + log_train_analytics( + mosaicml_logger, + model_config, + train_loader_config, + eval_loader_config, + train_cfg.callbacks, + tokenizer_name, + train_cfg.load_path, + icl_tasks_config, + eval_gauntlet_config, + ) + # Build Model + log.info('Initializing model...') + name = model_config.pop('name') + assert isinstance(name, str) + assert isinstance(model_config, dict) + model = build_composer_model( + name=name, + tokenizer=tokenizer, + init_context=init_context, + master_weights_dtype=model_config.get('master_weights_dtype', None), + cfg=model_config, + ) + + _log_num_params(model, logged_cfg) + + # Optimizer + optimizer_name: str = train_cfg.optimizer.pop('name') + optimizer_cfg = train_cfg.optimizer + optimizer = build_optimizer(model, optimizer_name, optimizer_cfg) + + # Now add the eval metrics + try: + if eval_loader_config is not None and not use_async_eval: + eval_metrics = model.get_metrics(is_train=False) + non_icl_metrics = [ + metric_name for metric_name, metric in eval_metrics.items() + if not isinstance(metric, InContextLearningMetric) + ] + evaluators = add_metrics_to_eval_loaders( + evaluators, + non_icl_metrics, + ) + except BaseContextualError as e: + e.location = EvalDataLoaderLocation + raise e + + compile_config = train_cfg.compile_config + # Build the Trainer + log.info('Building trainer...') + trainer = Trainer( + run_name=run_name, + seed=seed, + model=model, + train_dataloader=train_loader, + eval_dataloader=evaluators, + optimizers=optimizer, + schedulers=scheduler, + max_duration=train_cfg.max_duration, + eval_interval=train_cfg.eval_interval, + eval_subset_num_batches=train_cfg.eval_subset_num_batches, + progress_bar=train_cfg.progress_bar, + log_to_console=train_cfg.log_to_console, + console_log_interval=train_cfg.console_log_interval, + loggers=loggers, + callbacks=callbacks, + precision=train_cfg.precision, + algorithms=algorithms, + device_train_microbatch_size=train_cfg.device_train_microbatch_size, + parallelism_config={'fsdp': fsdp_config}, + save_folder=train_cfg.save_folder, + save_filename=save_filename, + save_latest_filename=save_latest_filename, + save_interval=train_cfg.save_interval, + save_num_checkpoints_to_keep=train_cfg.save_num_checkpoints_to_keep, + save_overwrite=train_cfg.save_overwrite, + save_weights_only=train_cfg.save_weights_only, + load_path=train_cfg.load_path, + load_weights_only=train_cfg.load_weights_only, + load_strict_model_weights=train_cfg.load_strict_model_weights, + load_ignore_keys=train_cfg.load_ignore_keys, + save_ignore_keys=train_cfg.save_ignore_keys, + autoresume=train_cfg.autoresume, + python_log_level=train_cfg.python_log_level, + dist_timeout=train_cfg.dist_timeout, + profiler=profiler, + compile_config=compile_config, + spin_dataloaders=train_cfg.spin_dataloaders, + ) + + # Optionally just save an HF checkpoint + if 
train_cfg.only_hf_checkpoint: + hf_checkpointer_callbacks = [ + c for c in callbacks if isinstance(c, HuggingFaceCheckpointer) + ] + if len(hf_checkpointer_callbacks) == 0: + raise ValueError( + 'No HuggingFaceCheckpointer callback found, but only_hf_checkpoint was set to True. Please add a HuggingFaceCheckpointer.', + ) + if len(hf_checkpointer_callbacks) > 1: + raise ValueError( + 'Multiple HuggingFaceCheckpointer callbacks found, but only_hf_checkpoint was set to True. Please remove all but one HuggingFaceCheckpointer.', + ) + + hf_checkpointer_callback = hf_checkpointer_callbacks[0] + hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger) + return trainer + + if train_cfg.only_composer_checkpoint: + log.info('Not training. Only saving composer checkpoint.') + trainer.save_checkpoint_to_save_folder() + log.info('Done saving checkpoint.') + return trainer + + if train_cfg.log_config: + log.info('Logging config') + log_config(logged_cfg) + log_dataset_uri(logged_cfg) + torch.cuda.empty_cache() + gc.collect() + + # Eval first if requested + if train_cfg.eval_first and trainer.state.timestamp.batch.value == 0: + trainer.eval() + + log.info('Starting training...') + trainer.fit() + + log.info('Done.') + return trainer + + +def train_from_yaml( + yaml_path: str, + args_list: Optional[List[str]] = None, +) -> Trainer: + """Run the training with optional overrides from CLI.""" + # Load yaml and CLI arguments. + om.clear_resolver('oc.env') + with open(yaml_path) as f: + yaml_cfg = om.load(f) + if args_list: + cli_cfg = om.from_cli(args_list) + yaml_cfg = om.merge(yaml_cfg, cli_cfg) + assert isinstance(yaml_cfg, DictConfig) + return train(yaml_cfg) diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py index 966ca90c86..5710be0c55 100644 --- a/llmfoundry/data/__init__.py +++ b/llmfoundry/data/__init__.py @@ -1,7 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset +from llmfoundry.data.data import ( + SUPPORTED_MDS_ENCODING_TYPES, + ConcatTokensDataset, + NoConcatDataset, + stream_remote_local_validate, +) from llmfoundry.data.dataloader import build_dataloader from llmfoundry.data.finetuning import ( Seq2SeqFinetuningCollator, @@ -55,4 +60,6 @@ 'auto_packing_ratio', 'profile_packing', 'ConcatenatedSequenceCollatorWrapper', + 'stream_remote_local_validate', + 'SUPPORTED_MDS_ENCODING_TYPES', ] diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index 04eb6d345d..bde68a6998 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -5,16 +5,31 @@ import os import warnings from abc import ABC, abstractmethod -from typing import Dict, Iterable, Union +from typing import Dict, Iterable, Optional, Union import datasets as hf_datasets import numpy as np +from numpy.typing import NDArray from torch.utils.data import IterableDataset from transformers import PreTrainedTokenizerBase __all__ = [ + 'AbstractConcatTokensDataset', 'ConcatTokensDataset', 'NoConcatDataset', + 'stream_remote_local_validate', + 'SUPPORTED_MDS_ENCODING_TYPES', +] + +SUPPORTED_MDS_ENCODING_TYPES = [ + 'int8', + 'int16', + 'int32', + 'int64', + 'uint8', + 'uint16', + 'uint32', + 'uint64', ] @@ -97,14 +112,14 @@ def __init__( ) @abstractmethod - def __iter__(self) -> Iterable[Dict[str, bytes]]: + def __iter__(self) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]: pass class ConcatTokensDataset(AbstractConcatTokensDataset): """An IterableDataset that returns token 
samples for MDSWriter. - Returns dicts of {'tokens': bytes} + Returns dicts of {'tokens': ndarray:int32} To use data created by this class and written to MDS format: @@ -119,7 +134,7 @@ class ConcatTokensDataset(AbstractConcatTokensDataset): # note, you need to copy the numpy array because the original is non-writeable # and torch does not support non-writeable tensors, so you get a scary warning and # if you do try to write to the tensor you get undefined behavior - tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy()) + tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int32).copy()) print(tokenizer.decode(tokens)) ``` """ @@ -136,7 +151,7 @@ def __init__( self.hf_dataset = hf_dataset super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) - def __iter__(self) -> Iterable[Dict[str, bytes]]: + def __iter__(self) -> Iterable[Dict[str, NDArray]]: buffer = [] for sample in self.hf_dataset: encoded = self.tokenizer( @@ -150,6 +165,27 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: concat_sample = buffer[:self.max_length] buffer = buffer[self.max_length:] if self.should_wrap else [] yield { - # convert to bytes to store in MDS binary format - 'tokens': np.asarray(concat_sample).tobytes(), + # convert to ndarray to store in MDS format + 'tokens': np.asarray(concat_sample, dtype=np.int32), } + + +def stream_remote_local_validate( + remote: Optional[str], + local: Optional[str], + split: Optional[str], +): + """Check that, if needed, the local/split directory exists. + + Args: + remote (Optional[str]): Remote path to the dataset. + local (Optional[str]): Local path to the dataset. + split (Optional[str]): Subdirectory specifying which dataset split to use, if any. + """ + if remote is None or (local == remote): + if local is not None and os.path.isdir(local): + contents = set(os.listdir(local)) + if split is not None and split not in contents: + raise ValueError( + f'Local directory {local} does not contain split {split}', + ) diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index 83a9a7d8ea..e7521bc343 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -3,7 +3,7 @@ """Dataloader builder utilities.""" -from typing import Any, Dict +from typing import Any, Dict, Union from composer import DataSpec from transformers import PreTrainedTokenizerBase @@ -19,7 +19,7 @@ def build_dataloader( cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, - device_batch_size: int, + device_batch_size: Union[int, float], ) -> DataSpec: """Builds a dataloader from a config. 
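Aside (not part of the patch): a minimal sketch of the new ndarray token flow, assuming the `streaming` package's `MDSWriter`/`StreamingDataset` with an `ndarray:int32` column; the output directory and the hand-built samples are placeholders.

```
import numpy as np
import torch
from streaming import MDSWriter, StreamingDataset

out_root = '/tmp/mds_int32_demo'  # placeholder output location

# 'ndarray:int32' matches the dtype ConcatTokensDataset now yields; legacy
# datasets were written as a 'bytes' column and need token_encoding_type on read.
with MDSWriter(columns={'tokens': 'ndarray:int32'}, out=out_root) as writer:
    for sample in (
        {'tokens': np.asarray([1, 2, 3, 4], dtype=np.int32)},
        {'tokens': np.asarray([5, 6, 7, 8], dtype=np.int32)},
    ):
        writer.write(sample)

ds = StreamingDataset(local=out_root, shuffle=False)
# No np.frombuffer needed: the sample already comes back as an int32 ndarray.
# Copy it (the underlying buffer is non-writeable), then upcast for modeling code.
tokens = torch.from_numpy(ds[0]['tokens'].copy()).to(torch.int64)
print(tokens)
```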
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 639beba6f0..771033a703 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -1,5 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import inspect import logging import os from typing import Any, Dict, Optional, Tuple, Union @@ -17,6 +18,8 @@ validate_target_settings, ) from llmfoundry.data.finetuning.tasks import ( + DEFAULT_TARGET_PROMPTS, + DEFAULT_TARGET_RESPONSES, DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, dataset_constructor, @@ -39,14 +42,20 @@ # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 -# Default settings to use for target responses and target prompts -_DEFAULT_TARGET_RESPONSES = 'last' -_DEFAULT_TARGET_PROMPTS = 'none' +# Extra keys present in the dataset config dictionary beyond the constructor keys +_ALLOWED_DATASET_KEYS = { + 'shuffle', + 'packing_ratio', + 'allow_pad_trimming', + 'seq_parallel_replication', + 'auto_packing_replication', + 'max_leftover_bins_to_keep', +} def build_finetuning_dataloader( tokenizer: PreTrainedTokenizerBase, - device_batch_size: int, + device_batch_size: Union[int, float], dataset: Dict[str, Any], num_workers: int, drop_last: bool = False, @@ -64,9 +73,12 @@ def build_finetuning_dataloader( on which you intend to use, as explained below. Args: - name (str): The type of dataloader to build. Must = "finetuning". - --- - *** HuggingFace dataset config fields *** + tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to + prepare the data from raw text. Any missing sentinel tokens will + be added by the collator. + device_batch_size (int, float): The size of the batches (number of examples) + that the dataloader will produce. + dataset (Dict[str, Any]): A HuggingFace dataset config which contains the following fields: dataset.hf_name (str, optional): The name of the HuggingFace dataset to use. Can also be a remote http(s) directory or object store bucket containing the file {split}.jsonl in the format (prompt, response), @@ -130,16 +142,32 @@ def build_finetuning_dataloader( The script `scripts/misc/profile_packing.py` can help you choose the best packing_ratio. dataset.shuffle (bool): Whether to shuffle the dataset. - ___ See :class:`StreamingFinetuningDataset` for info on other standard config options within `dataset` that will be passed as kwargs if using the streaming codepath. - --- - tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to - prepare the data from raw text. Any missing sentinel tokens will - be added by the collator. - device_batch_size (int, float): The size of the batches (number of examples) - that the dataloader will produce. + num_workers (int, optional): How many subprocesses to use for data loading. + 0 means that the data will be loaded in the main process. The default is 0. + This argument is passed directly to the pytorch :class:`DataLoader`. + drop_last (bool, optional): If true, drop the last incomplete batch, if the dataset + size is not divisible by the batch size. If False and the size of dataset is + not divisible by the batch size, then the last batch will be smaller. The + default is False. This argument is passed directly to the pytorch :class:`DataLoader`. + pin_memory (bool, optional): If True, the data loader will copy Tensors into device/CUDA + pinned memory before returning them. 
If your data elements are a custom type, or your + `collate_fn` returns a batch that is a custom type. This argument is passed directly to + the pytorch :class:`DataLoader`. + prefetch_factor (int, optional): Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + (default value depends on the set value for num_workers. If value of num_workers=0 default + is None. Otherwise, if value of num_workers > 0 default is 2). This argument is passed + directly to the pytorch :class:`DataLoader`. + persistent_workers (bool, optional): If True, the data loader will not shut down the worker + processes after a dataset has been consumed once. This allows to maintain the workers + Dataset instances alive. The default is False. This argument is passed directly to the + pytorch :class:`DataLoader`. + timeout (int, optional): If positive, the timeout value for collecting a batch from workers. + Should always be non-negative. The default is 0. This argument is passed directly to the + pytorch :class:`DataLoader`. See :class:`DataLoader` for standard argument options to the pytorch dataloader, such as `drop_last`, `num_workers`, etc. @@ -152,7 +180,26 @@ def build_finetuning_dataloader( given a starting workload YAML. """ dataset_cfg = dataset - _validate_config(**dataset_cfg) + is_streaming = ( + dataset_cfg.get('remote') is not None or + dataset_cfg.get('streams') is not None + ) + if is_streaming: + dataset_constructor_keys = inspect.signature( + dataset_constructor.streaming_dataset_class, + ).parameters.keys() + else: + dataset_constructor_keys = inspect.signature( + dataset_constructor.build_from_hf, + ).parameters.keys() + + allowed_dataset_config_keys = set( + dataset_constructor_keys, + ).union(_ALLOWED_DATASET_KEYS) + _validate_config( + **dataset_cfg, + allowed_dataset_keys=allowed_dataset_config_keys, + ) # Use EOS as the pad token if none exists if tokenizer.pad_token is None: # type: ignore (sometimes it's none and that's ok) @@ -194,9 +241,7 @@ def build_finetuning_dataloader( streaming_dataset = None # for pyright sampler = None - if dataset_cfg.get( - 'remote', - ) is not None or dataset_cfg.get('streams') is not None: + if is_streaming: # Build streaming dataloader streams_cfg = dataset_cfg.get('streams', None) streams_cfg = to_dict_container( @@ -206,33 +251,20 @@ def build_finetuning_dataloader( streams_cfg, ) if streams_cfg is not None else None - # note: we don't need to use ** here because we're setting default values for almost all arguments + # Take the constructor args from above, minus args that have been created separately + dataset_constructor_args = { + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys and + k not in {'streams', 'packing_ratio'} + } streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, - local=dataset_cfg.get('local', None), - remote=dataset_cfg.get('remote', None), - split=dataset_cfg.get('split', None), - download_retry=dataset_cfg.get('download_retry', 2), - download_timeout=dataset_cfg.get('download_timeout', 60), - validate_hash=dataset_cfg.get('validate_hash', None), - keep_zip=dataset_cfg.get('keep_zip', False), - epoch_size=dataset_cfg.get('epoch_size', None), - predownload=dataset_cfg.get('predownload', None), - cache_limit=dataset_cfg.get('cache_limit', None), - partition_algo=dataset_cfg.get('partition_algo', 'relaxed'), - num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None), - 
batch_size=dataset_batch_size, - shuffle=dataset_cfg.get('shuffle', False), - shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'), - shuffle_seed=dataset_cfg.get('shuffle_seed', 9176), - shuffle_block_size=dataset_cfg.get('shuffle_block_size', None), - sampling_method=dataset_cfg.get('sampling_method', 'balanced'), - sampling_granularity=dataset_cfg.get('sampling_granularity', 1), - batching_method=dataset_cfg.get('batching_method', 'random'), - max_seq_len=dataset_cfg['max_seq_len'], - allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False), + batch_size=dataloader_batch_size, replication=replication_factor, + packing_ratio=dataloader_batch_size / dataset_batch_size, + **dataset_constructor_args, ) else: @@ -263,24 +295,19 @@ def build_finetuning_dataloader( dataset_name_or_path, ) - # Build dataset from HF. + # Take the constructor args from above, minus args that have been created separately + dataset_constructor_args = { + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys and + k not in {'split', 'preprocessing_fn'} + } streaming_dataset = dataset_constructor.build_from_hf( dataset_name=dataset_name_or_path, split=split, - safe_load=dataset_cfg.get('safe_load', False), - max_seq_len=dataset_cfg['max_seq_len'], preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - target_prompts=dataset_cfg.get( - 'target_prompts', - _DEFAULT_TARGET_PROMPTS, - ), - target_responses=dataset_cfg.get( - 'target_responses', - _DEFAULT_TARGET_RESPONSES, - ), - decoder_only_format=dataset_cfg['decoder_only_format'], - hf_kwargs=dataset_cfg.get('hf_kwargs', {}), + **dataset_constructor_args, ) # Ensure dataset is large enough. @@ -347,6 +374,7 @@ def _validate_config( streams: Optional[Dict[str, Any]] = None, target_prompts: Optional[str] = None, target_responses: Optional[str] = None, + allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS, **kwargs: Dict[str, Any], ) -> None: """Validates the dataset configuration. @@ -356,45 +384,59 @@ def _validate_config( the other. Args: - dataset_cfg (DictConfig): The dataset configuration to be validated. + max_seq_len (int): The maximum length of sequences + in the batch. See :class:`Seq2SeqFinetuningCollator` docstring + for details. + decoder_only_format (bool): Whether to format the + examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` + docstring for details. + hf_name (str, optional): The name of the HuggingFace dataset + to use. Can also be a remote http(s) directory or object store bucket + containing the file {split}.jsonl in the format (prompt, response), + in which case the builder will create a HuggingFace dataset. + local (str, optional): Local path where remote data + will be streamed to. Only valid if `cfg.dataset.remote` has + also been set. + remote (str, optional): Location of a MDS-formatted + streaming dataset to use. Setting this will tell the builder + to create a streaming dataset rather than a HuggingFace dataset. + hf_kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. + preprocessing_fn (str, optional): The name/import path of + the preprocessing function to use for formatting the data examples. + If ``None`` (default), the builder will use the preprocessing function + registered under `hf_name` (see `tasks.py`), if one exists, + otherwise it will skip preprocessing. + If `preprocessing_fn` corresponds to a registered preprocessing + function in `tasks.py`, the builder will use that. 
+ Otherwise, it will interpret `preprocessing_fn` as a + "import.path:function_name" import path; e.g., it will call + `from import.path import function_name` and use the imported + function as the preprocessing function. + safe_load (bool, optional): Whether to enforce safe loading of the dataset. + If `None`, will default to not applying any safe loading. + streams (Dict[str, Any], optional): A dictionary with multiple data streams. + If `None`, will assume no streams. + target_prompts (str): Which prompts are used as training targets. + Defaults to "none", meaning prompts are never used as training targets. + See :class:`Seq2SeqFinetuningCollator` docstring for details. + target_responses (str): Which responses are used as training targets. + Defaults to "last", meaning only the final response in multi-turn examples + will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for + details. + allowed_dataset_keys (set[str], optional): The set of allowed keys for the dataset config. + kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. Raises: ValueError: If the dataset configuration does not meet the requirements. """ - # Check for extraneous keys in the dataset config - allowed_additional_kwargs = { - 'local', - 'remote', - 'split', - 'download_retry', - 'download_timeout', - 'validate_hash', - 'keep_zip', - 'epoch_size', - 'predownload', - 'cache_limit', - 'partition_algo', - 'num_canonical_nodes', - 'batch_size', - 'shuffle', - 'shuffle_algo', - 'shuffle_seed', - 'shuffle_block_size', - 'sampling_method', - 'sampling_granularity', - 'batching_method', - 'max_seq_len', - 'allow_unsafe_types', - 'replication', - 'packing_ratio', - 'allow_pad_trimming', - 'seq_parallel_replication', - 'auto_packing_replication', - } - if not set(kwargs.keys()).issubset(allowed_additional_kwargs): + if not set(kwargs.keys()).issubset(allowed_dataset_keys): raise ValueError( 'The dataset config contains the following extraneous keys: ' +\ - ', '.join(set(kwargs.keys()) - allowed_additional_kwargs), + ', '.join(set(kwargs.keys()) - allowed_dataset_keys), ) if hf_name is not None: @@ -478,9 +520,9 @@ def _validate_config( # Raise an error if the target_prompts + target_responses + decoder_only_format settings # are invalid if target_prompts is None: - target_prompts = _DEFAULT_TARGET_PROMPTS + target_prompts = DEFAULT_TARGET_PROMPTS if target_responses is None: - target_responses = _DEFAULT_TARGET_RESPONSES + target_responses = DEFAULT_TARGET_RESPONSES target_prompts, target_responses = target_prompts.lower( ), target_responses.lower() validate_target_settings( @@ -502,7 +544,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: completed, the function removes the signal file. Args: - hf_name (str): The path of the HuggingFace dataset to download. + remote_path (str): The path of the HuggingFace dataset to download. split (str): The dataset split to download (e.g., 'train', 'validation', 'test'). 
Returns: @@ -582,9 +624,9 @@ def build_collate_fn( dataset_cfg = dataloader_cfg['dataset'] target_responses = dataset_cfg.get( 'target_responses', - _DEFAULT_TARGET_RESPONSES, + DEFAULT_TARGET_RESPONSES, ) - target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS) + target_prompts = dataset_cfg.get('target_prompts', DEFAULT_TARGET_PROMPTS) max_seq_len = dataset_cfg['max_seq_len'] decoder_only_format = dataset_cfg['decoder_only_format'] allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 96ca17f5f4..d32c248416 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -47,6 +47,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -59,6 +60,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from streaming import Stream, StreamingDataset from transformers import PreTrainedTokenizerBase +from llmfoundry.data import ( + SUPPORTED_MDS_ENCODING_TYPES, + stream_remote_local_validate, +) from llmfoundry.data.finetuning.collator import ( _HF_IGNORE_INDEX, stitch_turns_decoder_only, @@ -110,6 +115,9 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ), ) SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet'] +HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata'] +DEFAULT_TARGET_RESPONSES = 'last' +DEFAULT_TARGET_PROMPTS = 'none' PromptResponseDict = Mapping[str, str] ChatFormattedDict = Mapping[str, List[Dict[str, str]]] @@ -157,7 +165,7 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: Args: dirpath (str): Directory path to check. - Returns + Returns: True if directory is empty or non-existent. False otherwise. """ return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 @@ -494,26 +502,15 @@ def is_valid_ift_example( return True -def _stream_remote_local_validate( - remote: Optional[str], - local: Optional[str], - split: Optional[str], -): - if remote is None or (local == remote): - if local is not None and os.path.isdir(local): - contents = set(os.listdir(local)) - if split is not None and split not in contents: - raise ValueError( - f'Local directory {local} does not contain split {split}', - ) - - class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. Args: tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to tokenize samples. + token_encoding_type (str): The encoding type of the tokenized samples. This is only used + for legacy datasets that have been written directly as 'bytes' instead of numpy + arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'. streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. 
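Aside (not part of the patch): the `token_encoding_type` knob only matters for legacy shards whose token columns were serialized as raw bytes. The helper below is an illustrative stand-in (names are hypothetical) for the decode logic the dataset's `__getitem__` applies in the hunk that follows.

```
import numpy as np


def decode_token_column(
    value,
    token_encoding_type: str = 'int64',
    max_seq_len: int = 2048,
) -> list:
    """Loose, illustrative stand-in for the dataset's token decoding.

    Legacy shards store raw bytes with no dtype metadata, so the reader must be
    told how the tokens were encoded; modern shards store ndarrays that carry
    their own dtype.
    """
    if isinstance(value, bytes):
        return np.frombuffer(
            value,
            dtype=getattr(np, token_encoding_type),
        )[:max_seq_len].tolist()
    if isinstance(value, np.ndarray):
        return value[:max_seq_len].tolist()
    raise TypeError(f'Unsupported token column type: {type(value)}')


# A legacy int32 column round-trips correctly once the dtype is known.
legacy_bytes = np.asarray([10, 11, 12], dtype=np.int32).tobytes()
assert decode_token_column(legacy_bytes, token_encoding_type='int32') == [10, 11, 12]
```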
@@ -574,6 +571,7 @@ class StreamingFinetuningDataset(StreamingDataset): def __init__( self, tokenizer: PreTrainedTokenizerBase, + token_encoding_type: str = 'int64', streams: Optional[Sequence[Stream]] = None, local: Optional[str] = None, remote: Optional[str] = None, @@ -598,6 +596,7 @@ def __init__( max_seq_len: int = 2048, allow_unsafe_types: bool = False, replication: Optional[int] = None, + packing_ratio: Optional[float] = None, **kwargs: Any, ): @@ -606,11 +605,17 @@ def __init__( f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}', ) + if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: + raise ValueError( + f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', + ) + self.token_encoding_type = token_encoding_type + if streams is None: - _stream_remote_local_validate(remote, local, split) + stream_remote_local_validate(remote, local, split) else: for stream in streams: - _stream_remote_local_validate( + stream_remote_local_validate( stream.remote, stream.local, split, @@ -644,6 +649,7 @@ def __init__( self.tokenizer = tokenizer self.max_seq_len = max_seq_len + self.packing_ratio = packing_ratio # How to process a sample def __getitem__(self, idx: int) -> Dict[str, Any]: @@ -656,11 +662,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: if isinstance(sample['input_ids'], bytes): sample['input_ids'] = np.frombuffer( sample['input_ids'], - dtype=np.int64, + dtype=getattr(np, self.token_encoding_type), )[:self.max_seq_len].tolist().copy() sample['labels'] = np.frombuffer( sample['labels'], - dtype=np.int64, + dtype=getattr(np, self.token_encoding_type), )[:self.max_seq_len].tolist().copy() elif isinstance(sample['input_ids'], np.ndarray): sample['input_ids'] = sample['input_ids'][:self.max_seq_len @@ -675,6 +681,16 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: return {'turns': [sample]} return tokenize_formatted_example(sample, tokenizer=self.tokenizer) + def state_dict(self, num_samples: int, + from_beginning: bool) -> Dict[str, Any]: + if self.packing_ratio is not None: + num_samples = int(self.packing_ratio * num_samples) + + return super().state_dict( + num_samples=num_samples, + from_beginning=from_beginning, + ) + class DatasetConstructor: @@ -792,14 +808,14 @@ def build_from_hf( self, dataset_name: str, split: str, - safe_load: bool, - max_seq_len: int, - preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]], - tokenizer: PreTrainedTokenizerBase, - target_prompts: str, - target_responses: str, - decoder_only_format: bool, - hf_kwargs: Dict[str, Any], + safe_load: bool = False, + max_seq_len: int = 2048, + preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + target_prompts: str = DEFAULT_TARGET_PROMPTS, + target_responses: str = DEFAULT_TARGET_RESPONSES, + decoder_only_format: bool = True, + hf_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: """Load a HuggingFace Datasets, preprocess, and tokenize. @@ -807,13 +823,45 @@ def build_from_hf( Note: This function will drop examples where the prompt is longer than the max_seq_len Args: - cfg (DictConfig): The dataset configuration. - max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped. - tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset. 
+ dataset_name (str): The name of the HuggingFace dataset + to use. Can also be a remote http(s) directory or object store bucket + containing the file {split}.jsonl in the format (prompt, response), + in which case the builder will create a HuggingFace dataset. + split (str): The split of the HuggingFace dataset. + safe_load (bool, optional): Whether to enforce safe loading of the dataset. + If `None`, will default to not applying any safe loading. + max_seq_len (int): The maximum length of sequences + in the batch. See :class:`Seq2SeqFinetuningCollator` docstring + for details. + preprocessing_fn (Callable, optional): The preprocessing function to use for + formatting the data examples. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenizing + the HuggingFace dataset. + target_prompts (str): Which prompts are used as training targets. + Defaults to "none", meaning prompts are never used as training targets. + See :class:`Seq2SeqFinetuningCollator` docstring for details. + target_responses (str): Which responses are used as training targets. + Defaults to "last", meaning only the final response in multi-turn examples + will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for + details. + decoder_only_format (bool): Whether to format the + examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` + docstring for details. + hf_kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. Returns: Dataset: The tokenized dataset. """ + if hf_kwargs is None: + hf_kwargs = {} + + # None is checked in the function, because argument defaults were added after the function was written and we want + # to preserve the ordering of the arguments for backwards compatibility. + if tokenizer is None: + raise ValueError('A tokenizer must be provided.') + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. @@ -874,7 +922,8 @@ def build_from_hf( f for _, _, files in os.walk(dataset_name) for f in files ] if not all( - Path(f).suffix in SUPPORTED_EXTENSIONS + Path(f).suffix in SUPPORTED_EXTENSIONS + + HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore' for f in dataset_files ): raise InvalidFileExtensionError( @@ -899,6 +948,8 @@ def dataset_mapper(example: Dict): detected_cpu_count = os.cpu_count() or 1 detected_cpus_with_margin = detected_cpu_count - 8 num_cpus_to_use = max(1, detected_cpus_with_margin) + if len(dataset) < num_cpus_to_use: + num_cpus_to_use = 1 columns_to_remove = list(dataset[0].keys()) tokenized_dataset = dataset.map( @@ -959,12 +1010,16 @@ def dataset_mapper(example: Dict): assert filtered_dataset is not None return filtered_dataset + @property + def streaming_dataset_class(self) -> Type[StreamingFinetuningDataset]: + return StreamingFinetuningDataset + def build_from_streaming( self, *args: Any, **kwargs: Any, ) -> StreamingFinetuningDataset: - return StreamingFinetuningDataset(*args, **kwargs) + return self.streaming_dataset_class(*args, **kwargs) dataset_constructor = DatasetConstructor() diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index a6fdf34953..5579066f89 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -337,7 +337,7 @@ def auto_packing_ratio( dataloader_cfg (DictConfig): The dataloader configuration for profiling. 
tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. device_batch_size (int): The size of the batches (number of examples) per device. - num_packing_ratio (int): The number of packing ratios to try. + num_packing_ratios (int): The number of packing ratios to try. Returns: A packing ratio that minimizes padding while maintaining zero waste. diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 60b81cd145..4bbfc29e7d 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -4,7 +4,6 @@ """Build a StreamingTextDataset dataset and dataloader for training.""" import inspect -import os from itertools import islice from typing import ( Any, @@ -25,6 +24,10 @@ from transformers import PreTrainedTokenizerBase from llmfoundry import registry +from llmfoundry.data import ( + SUPPORTED_MDS_ENCODING_TYPES, + stream_remote_local_validate, +) from llmfoundry.utils.registry_utils import construct_from_registry __all__ = [ @@ -41,6 +44,9 @@ class StreamingTextDataset(StreamingDataset): tokenizer (Tokenizer): HuggingFace tokenizer to tokenize samples. max_seq_len (int): The max sequence length of each sample. + token_encoding_type (str): The encoding type of the tokenized samples. This is only used + for legacy datasets that have been written directly as 'bytes' instead of numpy + arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'. streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. @@ -106,6 +112,7 @@ def __init__( self, tokenizer: PreTrainedTokenizerBase, max_seq_len: int, + token_encoding_type: str = 'int64', streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, local: Optional[str] = None, @@ -137,13 +144,21 @@ def __init__( f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}', ) - if local is not None and (remote is None or (local == remote)): - if os.path.isdir(local): - contents = set(os.listdir(local)) - if split not in contents: - raise ValueError( - f'local directory {local} does not contain split {split}', - ) + if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: + raise ValueError( + f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', + ) + self.token_encoding_type = token_encoding_type + + if streams is None: + stream_remote_local_validate(remote, local, split) + else: + for stream in streams: + stream_remote_local_validate( + stream.remote, + stream.local, + split, + ) # TODO: discover where yamls are being converted incorrect, but temporary workaround if isinstance(shuffle_block_size, float): @@ -197,10 +212,18 @@ def _read_binary_tokenized_sample( self, sample: Dict[str, Any], ) -> torch.Tensor: - return torch.from_numpy( - np.frombuffer(sample['tokens'], - dtype=np.int64)[:self.max_seq_len].copy(), - ) + # Modeling code still expects int64 tensors. 
+ if isinstance(sample['tokens'], np.ndarray): + return torch.from_numpy( + sample['tokens'][:self.max_seq_len].copy(), + ).to(torch.int64) + else: + return torch.from_numpy( + np.frombuffer( + sample['tokens'], + dtype=getattr(np, self.token_encoding_type), + )[:self.max_seq_len].copy(), + ).to(torch.int64) # How to process a sample def __getitem__(self, @@ -277,7 +300,7 @@ def build_streams(streams: Optional[Dict[str, Any]] = None,): def build_text_dataloader( tokenizer: PreTrainedTokenizerBase, - device_batch_size: int, + device_batch_size: Union[int, float], dataset: Dict[str, Any], drop_last: bool, num_workers: int, diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index a5fe3a1022..206e884f70 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -26,14 +26,6 @@ def _validate_cfg( eos_token_id = dataset_cfg.get('eos_token_id', None) bos_token_id = dataset_cfg.get('bos_token_id', None) - if eos_token_id is None and bos_token_id is None and ( - hasattr(tokenizer, 'eos_token_id') or - hasattr(tokenizer, 'bos_token_id') - ): - log.warning( - 'The user has not provided an eos_token_id or bos_token_id, but the tokenizer has an eos_token_id or a bos_token_id.', - ) - tokenizer_eos_token_id = getattr(tokenizer, 'eos_token_id', None) if eos_token_id is not None and eos_token_id != tokenizer_eos_token_id: eos_mismatch_str = f'Provided {eos_token_id=} does not match the eos_token_id of the tokenizer={tokenizer_eos_token_id}.' diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py index 02a2b88b21..a3a36053da 100644 --- a/llmfoundry/eval/datasets/__init__.py +++ b/llmfoundry/eval/datasets/__init__.py @@ -22,6 +22,18 @@ tokenizer_needs_prefix_space, trim_context, ) +from llmfoundry.registry import icl_datasets + +icl_datasets.register( + 'multiple_choice', + func=InContextLearningMultipleChoiceTaskDataset, +) +icl_datasets.register('schema', func=InContextLearningSchemaTaskDataset) +icl_datasets.register('language_modeling', func=InContextLearningLMTaskDataset) +icl_datasets.register( + 'generation_task_with_answers', + func=InContextLearningGenerationTaskWithAnswersDataset, +) __all__ = [ 'InContextLearningDataset', diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index debb0dbc6f..4e49be3fba 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -19,6 +19,7 @@ from datasets import IterableDataset, load_dataset from torch.utils.data import DataLoader, Dataset +from llmfoundry import registry from llmfoundry.eval.datasets.utils import ( convert_tokens_to_tensors, get_continuation_span, @@ -29,6 +30,7 @@ tokenizer_needs_prefix_space, trim_context, ) +from llmfoundry.utils.registry_utils import construct_from_registry log = logging.getLogger(__name__) @@ -114,11 +116,11 @@ def __init__( max_seq_len: int, pad_tok_id: int, num_fewshot: int, - fewshot_random_seed: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', prelimiter: str = '', context_key: str = 'context', answer_key: str = 'answer', @@ -170,17 +172,26 @@ def __init__( self.dataset = self.dataset.map(strip_data) fewshot_rng = random.Random(fewshot_random_seed) + self._prepared = False + self.num_fewshot = num_fewshot + 
self.prompt_string = prompt_string + self.fewshot_rng = fewshot_rng + + def _prepare_dataset(self): self.dataset: HFDataset = self.dataset.map( self._prep_example, with_indices=True, fn_kwargs={ - 'num_fewshot': num_fewshot, - 'prompt_string': prompt_string, - 'fewshot_rng': fewshot_rng, + 'num_fewshot': self.num_fewshot, + 'prompt_string': self.prompt_string, + 'fewshot_rng': self.fewshot_rng, }, ) + self._prepared = True def __getitem__(self, index: int) -> Dict: + if not self._prepared: + self._prepare_dataset() return self.dataset[index] def __len__(self) -> int: @@ -189,6 +200,20 @@ def __len__(self) -> int: def get_num_samples_in_batch(self, batch: Dict) -> int: return batch['input_ids'].shape[0] + def get_effective_batch_size(self, batch_size: int) -> int: + r"""Returns effective batch size computed for given ICL task. + + The effective batch size may not be equal to the configured evaluation + batch size because for certain ICL tasks, >1 prompts can get created + for every input query depending on the number of choices/continuations. + This requires the effective batch size to be reduced to prevent larger batches than expected during eval. For example, + check InContextLearningMultipleChoiceTaskDataset. + + Args: + batch_size (int): Original batch size configured for ICL evaluations + """ + return batch_size + def update_generation_kwargs(self, generation_kwargs: Dict) -> None: r"""Updates self.base_batch with the passed in generation_kwargs. @@ -226,8 +251,9 @@ def read_dataset( """ from datasets import \ Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] - from datasets import \ - load_dataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import ( # pyright: ignore[reportGeneralTypeIssues] + load_dataset, + ) if 'hf://' in dataset_uri: dataset_uri = dataset_uri.replace('hf://', '') if hf_loading_vars is None: @@ -338,6 +364,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The answer in the example @@ -519,46 +546,12 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch - def split_batch(self, batch: Any, - microbatch_size: Union[int, float]) -> Sequence[Any]: - """Handling for certain specialty columns that must be split into. - - batches in different formats. 
- - Args: - batch (Dict): Batch of data - microbatch_size (int | float): Size of microbatches - - Returns: - List: List of chunked batches - """ - # Don't split kwargs that don't change - # Normally split torch tensors - # List split lists of strings - if isinstance(microbatch_size, float): - raise ValueError( - 'split_batch does not support floating point microbatch_size.', - ) - chunked = {} - for k, v in batch.items(): - if k in self.static_keys: - # Defer broadcasting until we know num_chunks - pass - elif k in self.list_keys: - chunked[k] = _split_list(v, microbatch_size) - elif k in self.tensor_keys: - chunked[k] = _default_split_batch(v, microbatch_size) - else: - raise ValueError(f'Unexpected key {k} in batch splitting') - num_chunks = len(chunked['input_ids']) - for k, v in batch.items(): - if k in self.static_keys: - chunked[k] = [v] * num_chunks - - batched_list = [{k: v[idx] - for k, v in chunked.items()} - for idx in range(num_chunks)] - return batched_list + def split_batch( + self, + batch: Any, + microbatch_size: Union[int, float], + ) -> Sequence[Any]: + return _default_split_batch(batch, microbatch_size) class InContextLearningGenerationTaskWithAnswersDataset( @@ -584,13 +577,31 @@ class InContextLearningGenerationTaskWithAnswersDataset( def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'context', + answer_key: str = 'answer', + strip_dataset: bool = True, + padding_size: Optional[int] = None, + base_batch: Optional[Dict] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, cot_delimiter: str = '', early_stopping_criteria: Optional[List[str]] = None, do_normalization: bool = True, - *args: Any, - **kwargs: Any, ): - if kwargs['tokenizer'].eos_token_id is None: + if tokenizer.eos_token_id is None: raise ValueError( '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`', ) @@ -607,13 +618,32 @@ def __init__( tensor_keys = ['input_ids', 'attention_mask'] list_keys = ['labels'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + context_key=context_key, + answer_key=answer_key, + strip_dataset=strip_dataset, + padding_size=padding_size, + base_batch=base_batch, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + # specific to ICL dataset padding_side='left', tokenize_labels=False, static_keys=static_keys, list_keys=list_keys, tensor_keys=tensor_keys, - *args, - **kwargs, ) # NOTE: set these after init call because they take class vars self.early_stopping_criteria = early_stopping_criteria @@ -635,8 +665,8 @@ def __init__( 'input_ids': self.context_key, 'labels': 'aliases', } - if 'generation_kwargs' in kwargs: - self.update_generation_kwargs(kwargs['generation_kwargs']) + if generation_kwargs: + 
self.update_generation_kwargs(generation_kwargs) def read_dataset( self, @@ -684,6 +714,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The answer in from the example with chain of thought and delimiter if needed @@ -703,7 +734,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + ctxt (str): The specific example's derived context example (Dict): The example as a dictionary. Returns: @@ -765,6 +796,45 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['generation_kwargs']['stopping_criteria'] = stopping_criteria return batch + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Split batch handling for special columns. + + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + List: List of chunked batches + """ + # Don't split kwargs that don't change + # Normally split torch tensors + # List split lists of strings + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.', + ) + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting until we know num_chunks + pass + elif k in self.list_keys: + chunked[k] = _split_list(v, microbatch_size) + elif k in self.tensor_keys: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + batched_list = [{k: v[idx] + for k, v in chunked.items()} + for idx in range(num_chunks)] + return batched_list + class InContextLearningLMTaskDataset(InContextLearningDataset): """A dataset that constructs batches for in-context learning language. @@ -779,8 +849,50 @@ class InContextLearningLMTaskDataset(InContextLearningDataset): See InContextLearningDataset for more details. 
""" - def __init__(self, *args: Any, **kwargs: Any): + def __init__( + self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'context', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + static_keys: Optional[List] = None, + list_keys: Optional[List] = None, + ): super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + context_key=context_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset answer_key='continuation', static_keys=['mode'], tensor_keys=[ @@ -800,8 +912,6 @@ def __init__(self, *args: Any, **kwargs: Any): 'labels': 'context', }, padding_side='right', - *args, - **kwargs, ) @@ -833,13 +943,33 @@ class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'query', + tensor_keys: Optional[List] = None, + answer_key: str = 'answer', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + list_keys: Optional[List] = None, choices_key: str = 'choices', static_keys: Optional[List] = None, list_of_tensors_keys: Optional[List] = None, list_of_tuples_keys: Optional[List] = None, list_of_primitives: Optional[List] = None, - *args: Any, - **kwargs: Any, ): self.choices_key = choices_key base_batch = { @@ -850,25 +980,42 @@ def __init__( 'gold_indices': [], 'choice_groupings': [], } - context_key = kwargs.pop('context_key', 'query') - static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) - tensor_keys = kwargs.pop( - 'tensor_keys', - ['input_ids', 'labels', 'attention_mask'], - ) + if not static_keys: + static_keys = ['mode', 'generation_kwargs'] + if not tensor_keys: + tensor_keys = ['input_ids', 'labels', 'attention_mask'] self.list_of_tensors_keys = list_of_tensors_keys or [ 'continuation_indices', ] self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] self.list_of_primitives = list_of_primitives or ['gold_indices'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + 
prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + answer_key=answer_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset context_key=context_key, base_batch=base_batch, static_keys=static_keys, tensor_keys=tensor_keys, padding_side='right', - *args, - **kwargs, ) self.num_choices = len(self.dataset[0][self.choices_key]) self.batch_mapping_per_choice = { @@ -877,6 +1024,11 @@ def __init__( } self.batch_map_per_example = {'gold_indices': 'gold'} + def get_effective_batch_size(self, batch_size: int) -> int: + batch_size = max(self.num_choices, batch_size) + effective_batchsize = batch_size // self.num_choices + return effective_batchsize + def get_answer_from_example( self, example: Dict, @@ -886,6 +1038,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The full string of the correct answer based on the 'gold' key @@ -904,7 +1057,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + ctxt (str): The specific example's derived context example (Dict): The example as a dictionary. Returns: @@ -980,6 +1133,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: since the batch may consist of multiple questions, the choice_groupings indicates which contiguous sequences of elements in the batch correspond to which question gold_indices indicates which of the [0, N-1] choices is the correct one for each question. + Args: data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) @@ -1019,6 +1173,7 @@ def split_batch(self, batch: Any, and real example, which refers to one possible continuation. As example count and microbatch_size are tracked in logical example, we split logical attributes by microbatch_size and real attributes by microbatch_size * num_choices. 
+ Args: batch (Dict): Batch of data microbatch_size (int | float): Size of microbatches @@ -1095,21 +1250,58 @@ class InContextLearningSchemaTaskDataset( def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + answer_key: str = 'answer', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + list_keys: Optional[List] = None, choices_key: str = 'context_options', - *args: Any, - **kwargs: Any, ): static_keys = ['mode'] tensor_keys = ['input_ids', 'labels', 'attention_mask'] list_of_tensors_keys = ['continuation_indices'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + answer_key=answer_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset choices_key=choices_key, context_key=choices_key, static_keys=static_keys, tensor_keys=tensor_keys, list_of_tensors_keys=list_of_tensors_keys, - *args, - **kwargs, ) self.base_batch = { 'input_ids': [], @@ -1120,6 +1312,11 @@ def __init__( 'choice_groupings': [], } + def get_effective_batch_size(self, batch_size: int) -> int: + batch_size = max(self.num_choices, batch_size) + effective_batchsize = batch_size // self.num_choices + return effective_batchsize + def construct_context( self, example: Dict[str, Any], @@ -1228,7 +1425,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + context_options (str): A list of contexts for this specific example. example (Dict): The example as a dictionary. Returns: @@ -1294,23 +1491,10 @@ def build_icl_dataloader( dataset_uri: str, tokenizer: transformers.PreTrainedTokenizerBase, batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str, # e.g. '' hf_loading_vars: Dict, hf_parsing_map: Dict, - destination_path: str, - prelimiter: str, # e.g. 'Question: ' - cot_delimiter: str, # e.g. ' ### ' - fewshot_random_seed: int, - pass_at_k: int, - generations_per_sample: int, - generation_kwargs: Dict, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, + destination_path: str = '', + kwargs: Optional[Dict[str, Any]] = None, ) -> DataSpec: """Factory method that builds the specific dataset for the specified. @@ -1323,108 +1507,36 @@ def build_icl_dataloader( this might be different) 3. 
set the `split_batch` function if necessary """ - if icl_task_type == 'multiple_choice': - dataset = InContextLearningMultipleChoiceTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - batch_size = max(dataset.num_choices, batch_size) - effective_batchsize = batch_size // dataset.num_choices - elif icl_task_type == 'schema': - dataset = InContextLearningSchemaTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - batch_size = max(dataset.num_choices, batch_size) - effective_batchsize = batch_size // dataset.num_choices - elif icl_task_type == 'language_modeling': - dataset = InContextLearningLMTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - effective_batchsize = batch_size - elif icl_task_type == 'generation_task_with_answers': - dataset = InContextLearningGenerationTaskWithAnswersDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - cot_delimiter=cot_delimiter, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, - generation_kwargs=generation_kwargs, - ) - effective_batchsize = batch_size - else: - raise Exception(f'Unrecognized ICL task type: {icl_task_type}') - + # Add named parameters to kwargs + if kwargs is None: + kwargs = {} + kwargs.update({ + 'dataset_uri': dataset_uri, + 'tokenizer': tokenizer, + 'hf_loading_vars': hf_loading_vars, + 'hf_parsing_map': hf_parsing_map, + 'destination_path': destination_path, + }) + dataset = construct_from_registry( + name=icl_task_type, + registry=registry.icl_datasets, + partial_function=False, + pre_validation_function=None, + post_validation_function=None, + kwargs=kwargs, + ) sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False) - split_batch = None - if isinstance( - dataset, - ( - InContextLearningMultipleChoiceTaskDataset, - InContextLearningGenerationTaskWithAnswersDataset, - ), - ): - split_batch = dataset.split_batch - return DataSpec( DataLoader( dataset, - 
batch_size=effective_batchsize, + batch_size=dataset.get_effective_batch_size(batch_size), sampler=sampler, collate_fn=dataset.collate_fn, ), device_transforms=None, get_num_samples_in_batch=dataset.get_num_samples_in_batch, - split_batch=split_batch, + split_batch=dataset.split_batch, ) @@ -1442,6 +1554,10 @@ def partition_dataset_by_category( Args: dataset_uri (str): Location of dataset. destination_path (str): Base destination path, we will write a separate partition off this URI for each category. + hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. + Raises: MissingConditionalImportError: If datasets not installed raise exception. @@ -1514,24 +1630,11 @@ def get_icl_task_dataloader( tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str = '', - destination_path: str = '', - question_prelimiter: str = '', # e.g. 'Question: ' - fewshot_random_seed: int = 1234, - pass_at_k: int = 1, - generations_per_sample: int = 1, - cot_delimiter: str = '', has_categories: bool = False, hf_loading_vars: Optional[Dict] = None, hf_parsing_map: Optional[Dict] = None, - generation_kwargs: Optional[Dict] = None, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, + destination_path: str = '', + kwargs: Optional[Dict[str, Any]] = None, ) -> Union[DataSpec, Dict[str, DataSpec]]: r"""Constructs a dataloader (or dataloaders if has_categories is True) @@ -1550,8 +1653,7 @@ def get_icl_task_dataloader( # At this point, hf_model is randomly initialized composer_model = HuggingFaceModel(hf_model, hf_tokenizer) - Example: - + Example: .. testcode:: @@ -1588,28 +1690,12 @@ def get_icl_task_dataloader( The default keys expected are "context" and "answer". tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. batch_size (int): Size of a batch used for eval - max_seq_len (int): The maximum sequence length supported by the model. - pad_tok_id (int): The special token used for padding batches. - num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. - prompt_string (str, default = ''): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). - example_delimiter (str, default = '\\n'): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. - continuation_delimiter: (str, default = ' '): Separator inserted between context and answer in each example (e.g. '\\nA: '). - destination_path: (str, default = ''): This is the local file where remote datasets will be saved. - question_prelimiter: (str, default = ''): Text to be prepended before each context, including few shot examples (e.g. "Question: "). 
- fewshot_random_seed (int, default = 1234): Random seed to use for fewshot sampling - pass_at_k (int): k for how many chances the model gets to write passing code. - generations_per_sample (int): How many outputs to generate per prompt. Passed in generation_kwargs under "num_return_sequences" and overwritten by generation_kwargs dict. - cot_delimiter (str): Delimiter to place between chain of thoughts and continuations. has_categories: (bool): If ``True``, we will search the dataset file for a category key, and partition the dataset into a separate dataloader for each category occurring in the data. hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. - generation_kwargs (Dict, default = None): A dictionary containing keyword arguments to be passed along to the model's generate function. Overwrites any previously specified generation - keyword args in this function (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig - for more details) - early_stopping (List, default = None): A list of strings that, when found in a model's output, will be treated as a stopping criteria at metric computation time. - Used in generation tasks with CoT - do_normalization (bool, default = True): Whether or not to normalize the outputs and labels in InContextLearningGenerationTaskWithAnswersDataset. Only used in generation tasks. + destination_path: Where the dataloader will be saved. + kwargs (Dict[str, Any], default=None): Dictionary containing a mapping from ICL dataset constructor's parameter names and their desired values. Returns: DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. 
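As an illustration of the consolidated interface described above, a minimal sketch of a call site follows. The dataset-specific options that were previously separate named arguments (the names below are taken from the constructor calls removed earlier in this diff) now travel in the single `kwargs` dict and are forwarded to the registry-built dataset class. The import path, tokenizer name, and dataset path are placeholders assumed for the sketch, not values taken from this PR.

from transformers import AutoTokenizer

# Import path assumed from the module this diff edits (llmfoundry.eval.datasets).
from llmfoundry.eval.datasets import get_icl_task_dataloader

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')  # placeholder tokenizer

dl = get_icl_task_dataloader(
    icl_task_type='language_modeling',
    dataset_uri='eval/local_data/lambada_openai.jsonl',  # placeholder path
    tokenizer=tokenizer,
    batch_size=8,
    destination_path='/tmp/lambada.jsonl',
    # Options formerly passed as explicit arguments are now forwarded to the
    # registered ICL dataset constructor through this single dict.
    kwargs={
        'max_seq_len': 2048,
        'pad_tok_id': tokenizer.eos_token_id,
        'num_fewshot': 0,
        'prompt_string': '',
        'example_delimiter': '\n',
        'continuation_delimiter': ' ',
    },
)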
@@ -1618,11 +1704,6 @@ def get_icl_task_dataloader( hf_loading_vars = {} if hf_parsing_map is None: hf_parsing_map = {} - if generation_kwargs is None: - generation_kwargs = {} - if early_stopping_criteria is None: - early_stopping_criteria = [] - if has_categories: result_dls = {} output_files = partition_dataset_by_category( @@ -1639,23 +1720,10 @@ def get_icl_task_dataloader( dataset_uri=partition_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, destination_path=partition_uri + '_tmp', - prelimiter=question_prelimiter, - cot_delimiter=cot_delimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, + kwargs=kwargs, ) return result_dls else: @@ -1664,21 +1732,8 @@ def get_icl_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=continuation_delimiter, destination_path=destination_path, - prelimiter=question_prelimiter, - cot_delimiter=cot_delimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample, - generation_kwargs=generation_kwargs, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, + kwargs=kwargs, ) diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py index 1ce249437d..c19ae15dd9 100644 --- a/llmfoundry/eval/datasets/utils.py +++ b/llmfoundry/eval/datasets/utils.py @@ -130,7 +130,7 @@ def make_padded_input( Args: context_enc (List): The encoded input to the model continuation_enc (List): The encoded desired output for the example - max_seq_list (int): Maximum length sequences can be + max_seq_len (int): Maximum length sequences can be pad_tok_id (int): The token id we pad with padding_side (str): Which side to pad the context on. Can be 'right' or 'left diff --git a/llmfoundry/eval/metrics/nlp.py b/llmfoundry/eval/metrics/nlp.py index 3ee30ebf5e..f0fbba3ece 100644 --- a/llmfoundry/eval/metrics/nlp.py +++ b/llmfoundry/eval/metrics/nlp.py @@ -80,7 +80,7 @@ def update( Args: batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed to compute the metric. - output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids` + outputs (torch.Tensor): The model outputs evaluated on the batch `input_ids`. labels (torch.Tensor): The correct outputs. Raises: diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index e618d03dc8..50a4906ec1 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -7,32 +7,65 @@ from llmfoundry.utils.registry_utils import create_registry -_norm_description = ( - 'The norms registry is used to register classes that implement normalization layers.' +_norms_description = ( + """The norms registry is used to register classes that implement normalization layers. + + One example of this is torch.nn.LayerNorm. 
See norm.py for examples. + + Args: + normalized_shape Union[int, List[int], torch.Size]: The shape of the input tensor. + device: Optional[torch.device]: The device to use for the normalization layer. + + Returns: + torch.nn.Module: The normalization layer. + """ ) norms = create_registry( 'llmfoundry', 'norms', generic_type=Type[torch.nn.Module], entry_points=True, - description=_norm_description, + description=_norms_description, ) -_fc_description = ( - 'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).' - + - 'These classes should take in_features and out_features in as args, at a minimum.' + +_fcs_description = ( + """The fcs registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear). + + See fc.py for examples. + + Args: + in_features: int: The number of input features. + out_features: int: The number of output features. + kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer. + + Returns: + torch.nn.Module: The fully connected layer. + """ ) fcs = create_registry( 'llmfoundry', 'fcs', generic_type=Type[torch.nn.Module], entry_points=True, - description=_fc_description, + description=_fcs_description, ) _ffns_description = ( - 'The ffns registry is used to register functions that build ffn layers.' + - 'See ffn.py for examples.' + """The ffns registry is used to register functions that build FFN layers. + + These layers are generally composed of fc layers and activation functions. + One example is MPTMLP. See ffn.py for examples. + + Args: + d_model: int: The size of the input and output tensors. + expansion_ratio: float: The expansion ratio for the hidden layer. + device: Optional[str]: The device to use for the layer. + bias: bool: Whether or not to include a bias term. + kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer. + + Returns: + torch.nn.Module: The FFN layer. + """ ) ffns = create_registry( 'llmfoundry', @@ -43,8 +76,21 @@ ) _ffns_with_norm_description = ( - 'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.' - + 'See ffn.py for examples.' + """The ffns_with_norm registry is used to register functions that build FFN layers with normalization. + + The resulting layer will have ._has_norm set on it. + One example is te.LayerNormMLP. See ffn.py for examples. + + Args: + d_model: int: The size of the input and output tensors. + expansion_ratio: float: The expansion ratio for the hidden layer. + device: Optional[str]: The device to use for the layer. + bias: bool: Whether or not to include a bias term. + kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer. + + Returns: + torch.nn.Module: The FFN layer. + """ ) ffns_with_norm = create_registry( 'llmfoundry', @@ -58,6 +104,16 @@ 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' + 'See ffn.py for examples.' ) +_ffns_with_megablocks_description = ( + """The ffns_with_megablocks registry is used to register functions that build FFN layers using MegaBlocks. + + The resulting layer will have ._uses_megablocks set on it. + One example is megablocks.layers.dmoe.dMoE. See ffn.py for examples. + + Returns: + torch.nn.Module: The FFN layer. 
+ """ +) ffns_with_megablocks = create_registry( 'llmfoundry', 'ffns_with_megablocks', @@ -67,8 +123,17 @@ ) _attention_classes_description = ( - 'The attention_classes registry is used to register classes that implement attention layers. See ' - + 'attention.py for expected constructor signature.' + """The attention_classes registry is used to register classes that implement attention layers. + + The kwargs are passed directly to the constructor of the class. + One example is GroupedQueryAttention. See attention.py for examples. + + Args: + kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer. + + Returns: + torch.nn.Module: The attention layer. + """ ) attention_classes = create_registry( 'llmfoundry', @@ -79,8 +144,29 @@ ) _attention_implementations_description = ( - 'The attention_implementations registry is used to register functions that implement the attention operation.' - + 'See attention.py for expected function signature.' + """The attention_implementations registry is used to register functions that implement the attention operation. + + One example is 'flash'. See attention.py for examples. + + Args: + query (torch.Tensor): The query tensor. + key (torch.Tensor): The key tensor. + value (torch.Tensor): The value tensor. + n_heads (int): The number of attention heads. + kv_n_heads (int): The number of attention heads for the key and value tensors. + past_key_value (Optional[tuple[torch.Tensor, torch.Tensor]]): The past key and value tensors. + softmax_scale (Optional[float]) = None + attn_bias (Optional[torch.Tensor]) = None + is_causal (bool) = False + dropout_p (float) = 0.0 + training (bool) = True + needs_weights (bool) = False + kwargs: Dict[str, Any]: Additional keyword arguments the implementation accepts. + + Returns: + tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + The output tensor, the attention weights, and the past key and value tensors. + """ ) attention_implementations = create_registry( 'llmfoundry', @@ -91,9 +177,17 @@ ) _param_init_fns_description = ( - 'The param_init_fns registry is used to register functions that initialize parameters.' - + - 'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.' + """The param_init_fns registry is used to register functions that initialize parameters. + + These functions should take in a torch.nn.Module, additional kwargs, and initialize the parameters of the module. + Generally they can call generic_param_init_fn_ with an appropriate partial function. See param_init_fns.py for examples. + + Note: These functions should take in arbitrary kwargs, and discard any they don't need. + + Args: + module: torch.nn.Module: The module to initialize. + kwargs: Dict[str, Any]: Additional keyword arguments to use for initialization. + """ ) param_init_fns = create_registry( 'llmfoundry', @@ -103,9 +197,23 @@ description=_param_init_fns_description, ) -_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules. -These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents. -They should take in the module, init_div_is_residual, and div_is_residual arguments.""" +_module_init_fns_description = ( + """The module_init_fns registry is used to register functions that initialize specific modules. 
+ + These functions should return True if they initialize the module, and False otherwise. + This allows them to be called without knowing their contents. They should take in the module and additional kwargs. + If multiple functions can initialize the module, the one that is registered first will be used, so it is recommended to + override an existing function if you want to change existing initialization behavior, and add new functions if you have new + layer types. See param_init_fns.py for details. + + Args: + module: torch.nn.Module: The module to initialize. + kwargs: Dict[str, Any]: Additional keyword arguments to use for initialization. + + Returns: + bool: Whether or not the module was initialized. + """ +) module_init_fns = create_registry( 'llmfoundry', 'module_init_fns', diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 5f3a53ed18..34ce22d694 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -11,7 +11,6 @@ Any, Dict, List, - Mapping, Optional, Tuple, Union, @@ -23,7 +22,6 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, - PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, ) @@ -36,7 +34,7 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import get_hf_config_value +from llmfoundry.utils.config_utils import set_config_overrides if TYPE_CHECKING: from peft import PeftConfig, PeftModel @@ -90,6 +88,7 @@ def __init__( use_train_metrics: bool = True, additional_train_metrics: Optional[List] = None, additional_eval_metrics: Optional[List] = None, + should_save_peft_only: bool = True, ): config_overrides = config_overrides or {} @@ -104,9 +103,13 @@ def __init__( config_overrides=config_overrides, load_in_8bit=load_in_8bit, pretrained=pretrained, - prepare_for_fsdp=True, + prepare_for_fsdp=False, ) + model = self.transform_model(model) + + ComposerHFCausalLM.prepare_inner_model(model, init_device) + train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( use_train_metrics=use_train_metrics, additional_train_metrics=additional_train_metrics, @@ -120,7 +123,7 @@ def __init__( peft_config_object = None if peft_config is not None: - peft_config_object = self._get_peft_config(peft_config) + peft_config_object = self.get_peft_config(peft_config) # Set up config args for the model construction and base classes super().__init__( @@ -131,8 +134,20 @@ def __init__( eval_metrics=eval_metrics, init_device=init_device, peft_config=peft_config_object, + should_save_peft_only=should_save_peft_only, ) + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + """Transforms the model after initialization. + + Args: + model (PreTrainedModel): The model to transform. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + @staticmethod def build_metrics( use_train_metrics: bool, @@ -190,6 +205,7 @@ def build_inner_model( use_auth_token (bool): Whether to use an authentication token. config_overrides (Dict[str, Any]): The configuration overrides. load_in_8bit (bool): Whether to load in 8-bit. + pretrained (bool): Whether the model is pretrained. prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False. 
Returns: @@ -214,6 +230,22 @@ def build_inner_model( + 'Please `pip install llm-foundry[gpu]`.', ) + # Hugging Face copies the modules into the + # transformers modules cache. On particular systems, this operation seems to cause contention between + # the different processes. To avoid this contention, we first create the config on local rank + # zero. This will set up the transformers module cache and avoid the future contention. + if dist.get_local_rank() == 0: + AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation=requested_attention_implementation, + use_cache= + False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 + ) + + dist.barrier() + # Construct the Hugging Face config to use config = AutoConfig.from_pretrained( pretrained_model_name_or_path, @@ -241,70 +273,33 @@ def _autoset_attn_implementation_monkeypatch( _autoset_attn_implementation_monkeypatch, ) - # set config overrides - for k, v in config_overrides.items(): - if not hasattr(config, k): - raise ValueError( - f'config does not have attribute "{k}" to override ({k}: {v}).', - ) - - attr = getattr(config, k) - # attempt to disallow typos in nested configs - if isinstance(attr, Mapping): - extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] - if extra_keys: - raise ValueError( - f'Config dict override got unknown keys. ' + - f'Extra keys: {extra_keys}. ' + - f'Expected (a subset of) keys: {list(attr.keys())}.', - ) - getattr(config, k).update(v) - # necessary case to allow for rope_scaling to be overriden in llama config - elif attr is None and isinstance(v, Mapping): - setattr(config, k, {}) - getattr(config, k).update(v) - elif isinstance(attr, PretrainedConfig): - if not isinstance(v, Mapping): - raise ValueError( - f'Expected a dictionary for config override {k}, but got {v}.', - ) - - for _k, _v in v.items(): - if not hasattr(attr, _k): - raise ValueError( - f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', - ) - setattr(attr, _k, _v) - else: - setattr(config, k, v) - - if hasattr(config, 'attn_config') and get_hf_config_value( - config.attn_config, - 'seq_parallel_world_size', - ) is not None: - raise NotImplementedError( - 'Sequence Parallelism is not supported for HuggingFace models.', - ) + set_config_overrides(config, config_overrides) # We need to have all non-zero local ranks be not-pretrained # Rank 0 will still be pretrained, and distribute the weights appropriately if dist.get_local_rank() != 0 and init_device == 'mixed': pretrained = False - # If the HuggingFace model is coming from a local folder, Hugging Face copies the modules into the + # Hugging Face copies the modules into the # transformers modules cache. On particular systems, this operation seems to cause contention between # the different processes. To avoid this contention, we first create the model (on meta device) on local rank # zero. This will set up the transformers model cache and avoid the future contention. 
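An editorial aside on the pattern used in this hunk (once for the config above, and again for the model below): local rank zero touches the Hugging Face transformers cache first, and every other rank waits at a barrier before doing the same work, so the ranks never race on the cache directory. A hypothetical standalone sketch of that pattern, using the same composer dist utilities this file already imports:

from composer.utils import dist
from transformers import AutoConfig


def warm_transformers_cache(name_or_path: str, trust_remote_code: bool = False) -> None:
    """Hypothetical helper: populate the transformers module cache from one rank.

    Local rank zero resolves the config (copying any remote code into the cache),
    and the barrier keeps the remaining ranks from racing on that directory.
    """
    if dist.get_local_rank() == 0:
        AutoConfig.from_pretrained(
            name_or_path,
            trust_remote_code=trust_remote_code,
        )
    dist.barrier()  # ranks > 0 proceed only after the cache is populated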
- if dist.get_local_rank( - ) == 0 and os.path.isdir(pretrained_model_name_or_path): - with init_empty_weights(include_buffers=False): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', UserWarning) - AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path, + if dist.get_local_rank() == 0: + if pretrained and os.path.isdir(pretrained_model_name_or_path): + with init_empty_weights(include_buffers=False): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + config=config, + ) + else: + with init_empty_weights(include_buffers=False): + AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - config=config, ) dist.barrier() @@ -371,10 +366,10 @@ def _autoset_attn_implementation_monkeypatch( if prepare_for_fsdp: ComposerHFCausalLM.prepare_inner_model(model, init_device) + return model - @staticmethod - def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig': + def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig': if peft_installed: from peft import LoraConfig peft_type = peft_config_dict.get('peft_type', '') diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index c667c6026a..7051986df8 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -40,6 +40,7 @@ def __init__( shift_labels: bool = False, init_device: Optional[str] = None, peft_config: Optional['PeftConfig'] = None, + should_save_peft_only: bool = True, ): super().__init__( model, @@ -49,7 +50,7 @@ def __init__( eval_metrics=eval_metrics, shift_labels=shift_labels, peft_config=peft_config, - should_save_peft_only=True, + should_save_peft_only=should_save_peft_only, ) self.prepare_inner_model(self.model, init_device) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9b34190edf..3e365edc47 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -212,7 +212,7 @@ def check_valid_inputs( valid_dtypes: Optional[list[torch.dtype]] = None, ): if valid_dtypes is None: - valid_dtypes = [torch.float16, torch.bfloat16] + valid_dtypes = [torch.float32, torch.float16, torch.bfloat16] for tensor in tensors: if tensor.dtype not in valid_dtypes: raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.') @@ -266,11 +266,13 @@ def flash_attn_fn( batch_size, seqlen = query.shape[:2] - indices_q = flash_attn_padding_info['indices_q'] - indices_k = flash_attn_padding_info['indices_k'] - indices_v = flash_attn_padding_info['indices_v'] - cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'] - cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'] + # In the following lines we move the tensors to the same devices as query, key, and value respectively. These operations should be no-ops during training. + # This is done to fix pipeline parallel generation using hf.generate. 
Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204 + indices_q = flash_attn_padding_info['indices_q'].to(query.device) + indices_k = flash_attn_padding_info['indices_k'].to(key.device) + indices_v = flash_attn_padding_info['indices_v'].to(value.device) + cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'].to(query.device) + cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'].to(key.device) max_seqlen_q = flash_attn_padding_info['max_seqlen_q'] max_seqlen_k = flash_attn_padding_info['max_seqlen_k'] @@ -409,13 +411,16 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__() @@ -423,11 +428,13 @@ def __init__( self.clip_qkv = clip_qkv self.qk_ln = qk_ln self.qk_gn = qk_gn + self.fused_qkv = fused_qkv self.d_model = d_model self.n_heads = n_heads self.kv_n_heads = kv_n_heads self.sliding_window_size = sliding_window_size + self.reuse_kv_layer_idx = reuse_kv_layer_idx self.head_dim = d_model // n_heads @@ -458,33 +465,74 @@ def __init__( self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop - self.Wqkv = build_fc( - name=fc_type_name, - in_features=self.d_model, - out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, - fc_kwargs=fc_type, - ) - # for param init fn; enables shape based init of fused layers - fuse_splits = [ - i * self.head_dim - for i in range(1, self.n_heads + 2 * self.kv_n_heads) - ] - self.Wqkv._fused = (0, fuse_splits) + if self.reuse_kv_layer_idx is not None: + self.Wq = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + self.Wq._fused = (0, fuse_splits) + elif self.fused_qkv: + self.Wqkv = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [ + i * self.head_dim + for i in range(1, self.n_heads + 2 * self.kv_n_heads) + ] + self.Wqkv._fused = (0, fuse_splits) + else: + self.Wq = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_type, + ) + self.Wk = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + self.Wv = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + kv_fuse_splits = [ + i * self.head_dim for i in range(1, self.kv_n_heads) + ] + self.Wq._fused = (0, q_fuse_splits) + self.Wk._fused = (0, kv_fuse_splits) + self.Wv._fused = (0, kv_fuse_splits) if self.qk_ln or self.qk_gn: norm_size = self.head_dim if qk_gn else d_model self.q_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) - if qk_ln: - norm_size = self.head_dim * 
kv_n_heads - self.k_ln = build_norm( - name=norm_type.lower(), - normalized_shape=norm_size, - device=device, - ) + if self.reuse_kv_layer_idx is None: + if qk_ln: + norm_size = self.head_dim * kv_n_heads + self.k_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + eps=norm_eps, + device=device, + ) self.attn_fn = attention_implementations.get(self.attn_impl) @@ -507,9 +555,14 @@ def forward( needs_weights: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: - query, key, value = self.get_qkv(x) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + query, key, value = self.get_qkv(x, **extra_kwargs) if rotary_emb_w_meta_info is not None: query, key, value = self._apply_rotary_embeddings( @@ -546,30 +599,64 @@ def forward( def get_qkv( self, x: torch.Tensor, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Computes and returns the query, key, and value tensors. Args: x (torch.Tensor): The input tensor. + prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer. Returns: query (torch.Tensor): The query tensor. key (torch.Tensor): The key tensor. value (torch.Tensor): The value tensor. """ - qkv = self.Wqkv(x) - - if self.clip_qkv: - qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) - - query, key, value = qkv.split( - [ - self.d_model, - self.kv_n_heads * self.head_dim, - self.kv_n_heads * self.head_dim, - ], - dim=2, - ) + if self.reuse_kv_layer_idx is not None: + if prev_layer_key_value is None: + raise ValueError( + 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', + ) + key, value = prev_layer_key_value + + query = self.Wq(x) + if self.clip_qkv: + query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + if self.qk_ln or self.qk_gn: + # Applying layernorm to qk + q_shape = query.shape + if self.qk_gn: + b, s = query.shape[:2] + query = query.view(b, s, self.n_heads, -1) + dtype = query.dtype + query = self.q_ln(query).to(dtype).view(q_shape) + return query, key, value + + if self.fused_qkv: + qkv = self.Wqkv(x) + + if self.clip_qkv: + qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query, key, value = qkv.split( + [ + self.d_model, + self.kv_n_heads * self.head_dim, + self.kv_n_heads * self.head_dim, + ], + dim=2, + ) + else: + query = self.Wq(x) + key = self.Wk(x) + value = self.Wv(x) + + if self.clip_qkv: + query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv) + key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv) + value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv) if self.qk_ln or self.qk_gn: # Applying layernorm to qk @@ -591,6 +678,10 @@ def _apply_rotary_embeddings( key: torch.Tensor, value: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.reuse_kv_layer_idx is not None: + orig_key, orig_value = key, value + key, value = torch.empty_like(key), torch.empty_like(value) + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] offset_info = rotary_emb_w_meta_info['offset_info'] @@ -602,6 +693,7 @@ def _apply_rotary_embeddings( value = value.view(bsz, seqlen, -1, self.head_dim) kv = torch.stack([key, value], 
dim=2) + # Note: Rotates in place (https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/flash_attn/layers/rotary.py#L429) query, kv = rotary_emb( query, kv, @@ -620,6 +712,10 @@ def _apply_rotary_embeddings( else: (cos, sin) = rotary_emb(x=value, seq_len=seq_len) if is_transformers_version_gte('4.38'): + # In the following lines we move the cos and sin tensors to the same devices as query. These operations should be no-ops during training. + # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204 + cos = cos.to(query.device) + sin = sin.to(query.device) query, key = apply_rotary_pos_emb( q=query, k=key, @@ -652,6 +748,8 @@ def _apply_rotary_embeddings( query = query.view(bsz, seqlen, -1) key = key.view(bsz, seqlen, -1) + if self.reuse_kv_layer_idx is not None: + return query, orig_key, orig_value # type: ignore return query, key, value def get_implementation_specific_args( @@ -698,13 +796,16 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__( d_model=d_model, @@ -714,13 +815,16 @@ def __init__( clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, + fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size, + reuse_kv_layer_idx=reuse_kv_layer_idx, ) @@ -739,13 +843,16 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, sliding_window_size: int = -1, + reuse_kv_layer_idx: Optional[int] = None, ): super().__init__( d_model=d_model, @@ -755,13 +862,16 @@ def __init__( clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, + fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size, + reuse_kv_layer_idx=reuse_kv_layer_idx, ) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index b4b64b9a0a..92735cc489 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,6 +42,7 @@ def __init__( ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, no_bias: bool = False, @@ -84,6 +85,7 @@ def __init__( fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, + norm_eps=norm_eps, device=device, no_bias=no_bias, ) @@ -99,6 +101,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -117,6 +120,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + 
eps=norm_eps, device=device, ) @@ -158,8 +162,13 @@ def forward( output_attentions: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value if self.fuse_norm_attn_norm: x, m, attn_weights, past_key_value = self.norm_attn_norm( x, @@ -171,6 +180,7 @@ def forward( output_attentions=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) else: a = self.norm_1(x) @@ -184,6 +194,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) x = x + self.resid_attn_dropout(b) m = x @@ -191,7 +202,9 @@ def forward( m = self.norm_2(x) n = self.apply_ffn(attention_mask, m) - x = x + self.resid_ffn_dropout(n) + # In the following line we move the `x` tensor to the same devices as the output of ffn layer. This operation should be a no-op during training. + # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204 + x = x.to(device=n.device) + self.resid_ffn_dropout(n) return x, attn_weights, past_key_value def apply_ffn( @@ -212,6 +225,7 @@ def apply_ffn( indices = None if not self.use_pad_tok_in_ffn and attention_mask is not None: assert unpad_input is not None + attention_mask = self.slice_attention_mask(attention_mask, seq_len) m, indices, _, _ = unpad_input(m, attention_mask) n = self.ffn(m) if not self.use_pad_tok_in_ffn and attention_mask is not None: @@ -219,6 +233,24 @@ def apply_ffn( n = pad_input(n, indices, batch_size, seq_len) return n + def slice_attention_mask( + self, + attention_mask: torch.ByteTensor, + seq_len: int, + ) -> torch.ByteTensor: + """Slice attention mask to the correct size. + + Can be overridden by subclasses to apply different slicing logic. + + Args: + attention_mask (torch.ByteTensor): The attention mask. + seq_len (int): The sequence length. + + Returns: + torch.ByteTensor: The sliced attention mask. 
+ """ + return attention_mask + class FusedNormAttentionNorm(nn.Module): @@ -232,6 +264,7 @@ def __init__( fc_type: Optional[dict[str, Any]] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, device: Optional[str] = None, no_bias: bool = False, **kwargs: Any, @@ -255,6 +288,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -274,6 +308,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) @@ -289,9 +324,14 @@ def forward( output_attentions: bool = False, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + prev_layer_key_value: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value b, attn_weights, past_key_value = self.attn( a, past_key_value=past_key_value, @@ -302,6 +342,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py index 6190dbc6ea..59508e0a50 100644 --- a/llmfoundry/models/layers/dmoe.py +++ b/llmfoundry/models/layers/dmoe.py @@ -280,9 +280,13 @@ def forward( expert_tokens = x[None, token_list].reshape(-1, hidden_size) mlp_output = self.mlp(expert_tokens, expert_idx) + # In the following lines we move tensors to the same devices as the output of mlp. These operations should be no-ops during training. + # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204 + expert_weights = expert_weights.to(mlp_output.device) expert_out = mlp_output * expert_weights[token_list, topk_list, None] - + out = out.to(mlp_output.device) + token_idx = token_idx.to(mlp_output.device) out.index_add_(0, token_idx, expert_out) out = out.view(in_shape) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index a28725ee0f..f5d6d67040 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -53,6 +53,19 @@ } +def quickgelu_activation(input: torch.Tensor) -> torch.Tensor: + """Applies GELU approximation that is fast but somewhat inaccurate. 
+ + Args: + input (torch.Tensor): Input tensor of shape(*), where * means any + number of dimensions + + Returns: + torch.Tensor: Tensor with same shape as input tensor + """ + return input * torch.sigmoid(1.702 * input) + + def resolve_ffn_act_fn( config: Optional[dict] = None, ) -> Callable[[torch.Tensor], torch.Tensor]: @@ -70,10 +83,13 @@ def resolve_ffn_act_fn( config = _FFN_ACT_FN_DEFAULT config = deepcopy(config) name = config.pop('name') - if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognized activation function name ({name}).') - act = getattr(torch.nn.functional, name) - return partial(act, **config) + if name == 'quick_gelu': + return quickgelu_activation + else: + if not hasattr(torch.nn.functional, name): + raise ValueError(f'Unrecognized activation function name ({name}).') + act = getattr(torch.nn.functional, name) + return partial(act, **config) _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT) @@ -413,6 +429,7 @@ def set_ffn_device_mesh( ffn (nn.Module): The FFN module. moe_world_size (int): The MoE world size. device_mesh (DeviceMesh): The full device mesh. + get_fsdp_submesh (Callable[[DeviceMesh], DeviceMesh]): A function to get the fsdp submesh. Raises: RuntimeError: If the device mesh is 3D. diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 69d2059bad..d5fd1d37d4 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -26,10 +26,12 @@ def build_norm( name: str, normalized_shape: Union[int, List[int], torch.Size], + eps: Optional[float] = 1e-5, device: Optional[str] = None, ): kwargs = { 'normalized_shape': normalized_shape, + 'eps': eps, 'device': device, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 5f56fef56f..9671eb6ed5 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -14,22 +14,13 @@ check_alibi_support, is_flash_v2_installed, ) - -# NOTE: All utils are imported directly even if unused so that -# HuggingFace can detect all the needed files to copy into its modules folder. -# Otherwise, certain modules are missing. -# isort: off -from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) -from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note) -from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note) -from llmfoundry.layers_registry import norms # type: ignore (see note) -from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) from llmfoundry.models.utils.config_defaults import ( attn_config_defaults, + fc_type_defaults, ffn_config_defaults, init_config_defaults, - fc_type_defaults, -) # type: ignore (see note) +) +from llmfoundry.utils.warnings import ExperimentalWarning class MPTConfig(PretrainedConfig): @@ -53,11 +44,13 @@ def __init__( no_bias: bool = False, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, use_cache: bool = False, init_config: Optional[Dict] = None, fc_type: Union[str, Dict] = 'torch', tie_word_embeddings: bool = True, use_pad_tok_in_ffn: bool = True, + block_overrides: Optional[Dict[str, Any]] = None, **kwargs: Any, ): """The MPT configuration class. @@ -78,6 +71,8 @@ def __init__( attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'. 
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer. + fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single + Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True. clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to this value. softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, @@ -107,6 +102,7 @@ def __init__( no_bias (bool): Whether to use bias in all layers. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. norm_type (str): choose type of norm to use + norm_eps (float): epsilon value for norm layer use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', @@ -127,6 +123,31 @@ def __init__( also be a dictionary that specifies the fc layer name and any kwargs for the fc layer. tie_word_embeddings (bool): Whether to tie the input embedding and output layers. use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. + block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and `order`. `order` is a nested list which describes the order of the layers. For each kind of layer, specify the `overrides` in the overrides config (default refers to a layer that does not apply any overrides). + To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed: + block_overrides: + order: + - name: default + - repeat: 2 + order: + - name: sliding_window_layer + - name: sliding_window_layer_reuse + - name: sliding_window_layer + - repeat: 2 + name: sliding_window_layer_reuse + - name: reuse_kv_layer + overrides: + sliding_window_layer: + attn_config: + sliding_window_size: 1024 + sliding_window_layer_reuse: + attn_config: + sliding_window_size: 1024 + reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse + reuse_kv_layer: + attn_config: + reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse + kwargs (Any): Other relevant keyword arguments. 
""" self.d_model = d_model self.n_heads = n_heads @@ -150,11 +171,21 @@ def __init__( self.no_bias = no_bias self.embedding_fraction = embedding_fraction self.norm_type = norm_type + self.norm_eps = norm_eps self.use_cache = use_cache self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, ) + if 'reuse_kv_layer_idx' in self.attn_config and self.attn_config[ + 'attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) + if block_overrides is not None: + self._validate_block_overrides(block_overrides) + self.block_overrides = block_overrides + if isinstance(fc_type, str): fc_type = {'name': fc_type} self.fc_type = fc_type @@ -179,6 +210,23 @@ def __init__( self._validate_config() + def _validate_block_overrides(self, block_overrides: Dict[str, Any]): + warnings.warn(ExperimentalWarning('block_overrides')) + if 'order' not in block_overrides: + raise ValueError('`order` should be defined in block_overrides',) + if 'overrides' not in block_overrides: + raise ValueError( + '`overrides` should be defined in block_overrides', + ) + for name, override in block_overrides['overrides'].items(): + if name == 'default': + raise ValueError('block overrides cannot be named "default".',) + if 'attn_config' in override and 'reuse_kv_layer_idx' in override[ + 'attn_config'] and self.attn_config['attn_impl'] == 'torch': + raise NotImplementedError( + 'reusing kv cache from a previous layer is not implemented for torch attention.', + ) + def _set_config_defaults( self, config: Dict[str, Any], @@ -196,6 +244,13 @@ def _set_config_defaults( ) return config + def validate_attention_config(self) -> None: + if 'seq_parallel_world_size' in self.attn_config and self.attn_config[ + 'seq_parallel_world_size'] is None: + del self.attn_config['seq_parallel_world_size'] + if self.attn_config.get('seq_parallel_world_size', 1) > 1: + raise NotImplementedError('Sequence Parallelism is not supported.') + def _validate_config(self) -> None: # set config defaults self.attn_config = self._set_config_defaults( @@ -255,6 +310,7 @@ def _validate_config(self) -> None: 'no_scaling', 'linear', 'dynamic', + 'llama3', ]: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".', @@ -336,5 +392,14 @@ def _validate_config(self) -> None: raise ImportError( 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6', ) - if (self.attn_config.get('seq_parallel_world_size', 1) or 1) > 1: - raise NotImplementedError('Sequence Parallelism is not supported.') + + self.validate_attention_config() + + @property + def allowed_block_overrides(self): + return { + 'attn_config': { + 'sliding_window_size': None, + 'reuse_kv_layer_idx': None, + }, + } diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9d18799e93..6f9b6bf806 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy import math import warnings from functools import cached_property @@ -28,6 +29,7 @@ import torch.nn.functional as F from composer.models import HuggingFaceModel from composer.utils import dist +from tabulate import tabulate from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.attention import is_flash_v2_installed @@ -47,12 +49,10 @@ BaseModelOutputWithPast, 
CausalLMOutputWithPast, ) -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import ( + LlamaConfig, + LlamaRotaryEmbedding, +) from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import ( @@ -80,19 +80,68 @@ # isort: off from llmfoundry.models.layers.fc import fcs # type: ignore from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore +from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore # isort: on log = logging.getLogger(__name__) +class InvalidConfigAccessError(KeyError): + pass + + +_ALLOWED_LLAMA_CONFIG_KEYS = { + # These are the only config keys that are set and are safe to read from + 'rope_scaling', + 'rope_theta', + 'max_position_embeddings', + 'hidden_size', + 'num_attention_heads', + + # Not set but llama modeling code tries to read this attribute + 'partial_rotary_factor', + + # Benign transformers attributes needed for __init__ + '_get_generation_defaults', + 'label2id', + 'id2label', + 'torch_dtype', + 'problem_type', + '__class__', +} + + +class PartialLlamaConfig(LlamaConfig): + """Holds the rope config for Llama models and throws. + + an `InvalidConfigAccessError` if any other config elements are read. This + class is necessary because the `LlamaRotaryEmbedding` class takes a full + `LlamaConfig` now instead of the old keyword arguments. + """ + + def __getattribute__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getattribute__(key) + + def __getitem__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getitem__(key) + + def gen_rotary_embedding( - rope_head_dim: int, rope_impl: str, rope_theta: int, rope_dail_config: dict, rope_hf_config: dict, max_seq_len: int, + d_model: int, + n_heads: int, ): + rope_head_dim = d_model // n_heads if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -105,32 +154,21 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_impl == 'hf': + llama_rope_config = {**rope_hf_config} + llama_rope_config['rope_type'] = llama_rope_config.pop('type') + if llama_rope_config['rope_type'] == 'no_scaling': + llama_rope_config['rope_type'] = 'default' + partial_llama_config = PartialLlamaConfig( + rope_scaling=llama_rope_config, + rope_theta=rope_theta, + max_position_embeddings=max_seq_len, + hidden_size=d_model, + num_attention_heads=n_heads, + ) if rope_hf_config['type'] == 'no_scaling': - return HFRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'dynamic': - return 
HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) + return LlamaRotaryEmbeddingFoundry(config=partial_llama_config) + elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: + return LlamaRotaryEmbedding(config=partial_llama_config) raise ValueError('rope_impl needs to be either dail or hf') @@ -303,6 +341,20 @@ def apply_sequence_id( return attn_bias +class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding): + + @torch.no_grad() + def forward( + self, + x: torch.Tensor, + position_ids: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # In this subclass, we move `inv_freq` to same device as position_ids. This operation should be a no-op during training. + # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1334#issue-2387337525 + self.inv_freq = self.inv_freq.to(position_ids.device) + return super().forward(x=x, position_ids=position_ids) + + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -374,6 +426,7 @@ def __init__(self, config: MPTConfig): self.norm_f = build_norm( name=config.norm_type.lower(), normalized_shape=config.d_model, + eps=config.norm_eps, device=config.init_device, ) @@ -382,12 +435,13 @@ def __init__(self, config: MPTConfig): if self.rope: self.rope_impl = config.attn_config['rope_impl'] self.rotary_embedding = gen_rotary_embedding( - rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len, + d_model=config.d_model, + n_heads=config.n_heads, ) if config.init_device != 'meta': @@ -425,6 +479,10 @@ def __init__(self, config: MPTConfig): log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') + @property + def block_class(self) -> Type[MPTBlock]: + return MPTBlock + def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: """Construct the nn.ModuleList with the Transformer blocks. @@ -435,14 +493,181 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList: nn.ModuleList: The list of Transformer blocks. 
""" block_args = self.extract_block_args(config.to_dict()) + self.kv_cache_layers = set() + self.blocks_fuse_norm_attn_norm = block_args.get( + 'fuse_norm_attn_norm', + False, + ) + + if config.block_overrides is not None: + block_args_list = self._get_override_block_args_list( + config, + block_args, + ) + else: + block_args_list = [block_args for _ in range(config.n_layers)] return nn.ModuleList([ - MPTBlock( + self.block_class( device=config.init_device, - **block_args, - ) for _ in range(config.n_layers) + **block_args_i, + ) for block_args_i in block_args_list ]) + def _get_override_block_args_list( + self, + config: MPTConfig, + block_args: Dict[str, Any], + ) -> List[Dict[str, Any]]: + if config.block_overrides is None: + raise ValueError( + 'config.block_overrides should not be None when calling _get_override_block_args_list.', + ) + repeat = config.block_overrides.get('repeat', 1) + model_modules_order_expanded = MPTModel._get_modules_order_expanded( + config.block_overrides['order'], + ) * repeat + if len(model_modules_order_expanded) != config.n_layers: + raise ValueError( + f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.', + ) + + new_block_args_list = [] + layer_description_list = [] + + reuse_kv_layer_idx_dict = {} + for b_idx in range(config.n_layers): + module_name = model_modules_order_expanded[b_idx] + override_config = {} + if module_name != 'default': + override_config = copy.deepcopy( + config.block_overrides['overrides'][module_name], + ) + if 'reuse_kv_layer_idx' in override_config.get( + 'attn_config', + {}, + ): + reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_idx( + overrides_definition=config. + block_overrides['overrides'], + model_modules_order_expanded= + model_modules_order_expanded, + b_idx=b_idx, + override_config=override_config, + reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, + ) + override_config['attn_config']['reuse_kv_layer_idx' + ] = reuse_kv_layer_idx + self.kv_cache_layers.add(reuse_kv_layer_idx) + layer_description_list.append([ + b_idx, + module_name, + override_config, + ],) + new_block_args_list.append( + MPTModel._override_block_args( + block_args, + override_config, + config.allowed_block_overrides, + ), + ) + log.info( + 'The following is a summary of overrides per layer.\n' + tabulate( + layer_description_list, + headers=['idx', 'name', 'overrides'], + ), + ) + return new_block_args_list + + @staticmethod + def _resolve_reuse_kv_layer_idx( + overrides_definition: Dict[str, Any], + model_modules_order_expanded: List[str], + b_idx: int, + override_config: Dict[str, Any], + reuse_kv_layer_idx_dict: Dict[int, int], + ) -> int: + override_attn_config = override_config['attn_config'] + if override_attn_config['reuse_kv_layer_idx'] >= 0: + raise ValueError( + f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', + ) + reuse_kv_layer_idx = b_idx + override_attn_config['reuse_kv_layer_idx'] + if reuse_kv_layer_idx < 0: + raise ValueError( + f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', + ) + if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: + reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx] + reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx + + parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx] + parent_config = {} if parent_layer_name == 'default' else copy.deepcopy( + overrides_definition[parent_layer_name], + ) + if 
'attn_config' not in parent_config: + parent_config['attn_config'] = {} + parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[ + 'attn_config']['reuse_kv_layer_idx'] + + if override_config != parent_config and not ( + 'allow_mismatch' in override_config and + override_config['allow_mismatch'] + ): + raise ValueError( + 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.', + ) + + return reuse_kv_layer_idx + + @staticmethod + def _get_modules_order_expanded(order: List[Dict[str, Any]]) -> List[str]: + model_modules_order_expanded = [] + for item in order: + repeat = item['repeat'] if 'repeat' in item else 1 + if ('name' in item) == ('order' in item): + raise ValueError( + 'Exactly one of `order` or `name` must be specified for each block override.', + ) + + if 'name' in item: + model_modules_order_expanded.extend([item['name']] * repeat) + else: + model_modules_order_expanded.extend( + MPTModel._get_modules_order_expanded(item['order']) * + repeat, + ) + + return model_modules_order_expanded + + @staticmethod + def _override_block_args( + block_args: Dict[str, Any], + override_config: Dict[str, Any], + allowed_block_overrides: Dict[str, Any], + ) -> Dict[str, Any]: + unpermitted_keys = override_config.keys( + ) - allowed_block_overrides.keys() + if len(unpermitted_keys): + raise KeyError(f'Overriding {unpermitted_keys} is not supported.') + + new_block_args = override_config | block_args + common_keys = override_config.keys() & block_args.keys() + for k in common_keys: + if type(override_config[k]) != type(block_args[k]): + raise ValueError( + f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.', + ) + if isinstance(override_config[k], dict): + new_block_args[k] = MPTModel._override_block_args( + block_args[k], + override_config[k], + allowed_block_overrides[k], + ) + else: + new_block_args[k] = override_config[k] + return new_block_args + def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: """Sets the block args.""" if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: @@ -580,8 +805,9 @@ def forward( 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.', ) - elif (self.attn_uses_sequence_id is - False) and (sequence_id is not None): + elif ( + self.attn_uses_sequence_id is False and sequence_id is not None + ): warnings.warn( 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. 
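For reference, a hedged illustration of the `block_overrides` structure that `_get_modules_order_expanded` and `_resolve_reuse_kv_layer_idx` consume. The override name and the attention setting are examples only; the keys actually permitted are whatever `config.allowed_block_overrides` allows.

```python
# Expands to ['default', 'default', 'kv_reuse', 'default', 'kv_reuse'],
# so n_layers must be 5 (the top-level 'repeat' defaults to 1).
block_overrides = {
    'order': [
        {'name': 'default'},
        {
            'order': [
                {'name': 'default'},
                {'name': 'kv_reuse'},
            ],
            'repeat': 2,  # nested orders are expanded recursively
        },
    ],
    'overrides': {
        'kv_reuse': {
            # Must be negative: "reuse the kv cache of the layer 1 before me".
            # It is resolved to an absolute layer index at construction time.
            'attn_config': {'reuse_kv_layer_idx': -1},
        },
    },
}
```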
' + @@ -701,7 +927,9 @@ def forward( # initialize the past key values cache if it should be used presents = () if use_cache else None - if use_cache and past_key_values is None: + if ( + use_cache or len(self.kv_cache_layers) > 0 + ) and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) ] # type: ignore @@ -718,13 +946,27 @@ def forward( attention_mask, ) + layer_kv_cache_dict = {} for b_idx, block in enumerate(self.blocks): + attn_block = block.norm_attn_norm.attn if self.blocks_fuse_norm_attn_norm else block.attn + if attn_block.reuse_kv_layer_idx is not None: + if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict: + raise KeyError( + f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.', + ) + prev_layer_key_value = layer_kv_cache_dict[ + attn_block.reuse_kv_layer_idx] + else: + prev_layer_key_value = None if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) past_key_value = ( past_key_values[b_idx] if past_key_values is not None else None ) + extra_kwargs = {} + if prev_layer_key_value is not None: + extra_kwargs['prev_layer_key_value'] = prev_layer_key_value x, attn_weights, present = block( x, past_key_value=past_key_value, @@ -735,9 +977,15 @@ def forward( output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) if presents is not None: presents += (present,) + if b_idx in self.kv_cache_layers: + layer_kv_cache_dict[b_idx] = [ + present[0][:, past_position:], + present[1][:, past_position:], + ] if output_attentions: assert all_self_attns is not None # pyright @@ -1074,6 +1322,40 @@ def _reorder_cache( return reordered_past +def get_targets(labels: torch.Tensor) -> torch.Tensor: + targets = torch.roll(labels, shifts=-1) + targets[:, -1] = -100 + return targets + + +def compute_loss_from_logits( + outputs: CausalLMOutputWithPast, + shift_labels: bool, + labels: torch.Tensor, + loss_fn: nn.Module, + sample_weighing_factor: Optional[torch.Tensor] = None, +) -> torch.Tensor: + targets = get_targets(labels) if shift_labels else labels + + losses = loss_fn( + outputs.logits.view(-1, outputs.logits.size(-1)), + targets.view(-1), + ) + + if torch.all(targets == loss_fn.ignore_index): + loss = losses.sum() + else: + loss = losses.sum() / (targets != loss_fn.ignore_index).sum() + if sample_weighing_factor is not None: + if sample_weighing_factor.shape[0] > 1: + raise ValueError( + 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', + ) + loss = loss * sample_weighing_factor[0].item() + + return loss + + class ComposerMPTCausalLM(HuggingFaceModel): def __init__( @@ -1092,7 +1374,7 @@ def __init__( additional_train_metrics = additional_train_metrics or [] - model = self.model_class(self.config_class(**kwargs),) + model = self.model_class(self.config_class(**kwargs)) use_train_metrics = use_train_metrics train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + additional_train_metrics @@ -1151,9 +1433,7 @@ def config_class(self) -> Type[MPTConfig]: return MPTConfig def get_targets(self, batch: Mapping) -> torch.Tensor: - targets = torch.roll(batch['labels'], shifts=-1) - targets[:, -1] = -100 - return targets + return get_targets(batch['labels']) def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: @@ -1174,27 +1454,14 @@ def forward(self, batch: MutableMapping) -> 
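The loss path now delegates to the module-level helpers added above. A self-contained sketch of their behavior (shapes and vocab size are arbitrary, and the `loss_fn` here is only similar to the `reduction='none'` cross-entropy the model actually builds):

```python
import torch
from torch import nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from llmfoundry.models.mpt.modeling_mpt import (
    compute_loss_from_logits,
    get_targets,
)

labels = torch.tensor([[10, 11, 12, 13]])
# Labels are shifted left by one and the final position is masked with -100.
print(get_targets(labels))  # -> [[11, 12, 13, -100]]

vocab_size = 128
outputs = CausalLMOutputWithPast(logits=torch.randn(1, 4, vocab_size))
loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
loss = compute_loss_from_logits(
    outputs,
    shift_labels=True,
    labels=labels,
    loss_fn=loss_fn,
)  # sum of per-token losses divided by the number of non-ignored targets
```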
CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - if self.shift_labels: - targets = self.get_targets(batch) - else: - targets = batch['labels'] - - losses = self.loss_fn( - outputs.logits.view(-1, outputs.logits.size(-1)), - targets.view(-1), + loss = compute_loss_from_logits( + outputs, + self.shift_labels, + batch['labels'], + self.loss_fn, + batch.get('sample_weighing_factor', None), ) - if torch.all(targets == self.loss_fn.ignore_index): - loss = losses.sum() - else: - loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if 'sample_weighing_factor' in batch: - if batch['sample_weighing_factor'].shape[0] > 1: - raise ValueError( - 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', - ) - loss = loss * batch['sample_weighing_factor'][0].item() - if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors @@ -1209,7 +1476,6 @@ def loss(self, outputs: CausalLMOutputWithPast, 'loss': loss, 'lbl': lbl, } - return loss @cached_property @@ -1227,6 +1493,12 @@ def flops_per_batch(self, batch: Mapping): # that the dataset has been constructed without padding. Additionally, we # assume the backward pass is approximately 2x the forward pass + if self.model.config.block_overrides is not None: + warnings.warn( + 'Warning, flop computation is not supported when using block overrides. Returning 0 flops per batch.', + ) + return 0 + bs, msl = batch['input_ids'].shape[0:2] params = self.n_active_params params_flops_per_token = 2 * params diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 2b6fc2f7c7..c272a52dd4 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -15,6 +15,7 @@ 'attn_impl': 'flash', 'qk_ln': False, 'qk_gn': False, + 'fused_qkv': True, 'clip_qkv': None, 'softmax_scale': None, 'attn_uses_sequence_id': False, diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 963c596e76..40342f2ddb 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -177,9 +177,9 @@ def config_megablocks_moe_args( lbl_process_group = create_set_process_group(lbl_process_group) else: lbl_process_group = None - elif lbl_process_group is not None: + elif not isinstance(lbl_process_group, distributed.ProcessGroup): raise ValueError( - f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | .', + f'Unknown {lbl_process_group=}. 
Options are: none | a process group | ``expert_group`` | ``global_group`` | .', ) ffn_config['lbl_process_group'] = lbl_process_group diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 0c8e64b759..3f0163ff01 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -6,8 +6,10 @@ from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim import ComposerScheduler +from torch.distributed.checkpoint import LoadPlanner, SavePlanner from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader +from torch.utils.data import Dataset from torchmetrics import Metric from transformers import PreTrainedTokenizerBase @@ -26,11 +28,17 @@ from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( - 'The loggers registry is used to register classes that implement the LoggerDestination interface. ' - + - 'These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers ' - + - 'will be constructed by directly passing along the specified kwargs to the constructor.' + """The loggers registry is used to register classes that implement the LoggerDestination interface. + + These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers + will be constructed by directly passing along the specified kwargs to the constructor. See loggers/ for examples. + + Args: + kwargs (Dict[str, Any]): The kwargs to pass to the LoggerDestination constructor. + + Returns: + LoggerDestination: The logger destination. + """ ) loggers = create_registry( 'llmfoundry', @@ -41,11 +49,17 @@ ) _callbacks_description = ( - 'The callbacks registry is used to register classes that implement the Callback interface. ' - + - 'These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. ' - + - 'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.' + """The callbacks registry is used to register classes that implement the Callback interface. + + These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. + The callbacks will be constructed by directly passing along the specified kwargs to the constructor. See callbacks/ for examples. + + Args: + kwargs (Dict[str, Any]): The kwargs to pass to the Callback constructor. + + Returns: + Callback: The callback. + """ ) callbacks = create_registry( 'llmfoundry', @@ -56,22 +70,40 @@ ) _callbacks_with_config_description = ( - 'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. ' - + - 'These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.' + """The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. + + These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor. + See callbacks/ for examples. + + Args: + config (DictConfig): The training config. + kwargs (Dict[str, Any]): The kwargs to pass to the Callback constructor. + + Returns: + Callback: The callback. 
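As a concrete illustration of the registration pattern these expanded docstrings describe, a toy callback could be registered as follows. The class and the `log_epoch` name are hypothetical and not part of this change:

```python
from composer.core import Callback, State
from composer.loggers import Logger

from llmfoundry import registry


class LogEpochCallback(Callback):
    """Toy callback that prints the epoch index when an epoch ends."""

    def epoch_end(self, state: State, logger: Logger) -> None:
        print(f'Finished epoch {state.timestamp.epoch.value}')


# Constructed later by passing the YAML kwargs straight to the constructor.
registry.callbacks.register('log_epoch', func=LogEpochCallback)
```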
+ """ ) callbacks_with_config = create_registry( - 'llm_foundry.callbacks_with_config', + 'llmfoundry', + 'callbacks_with_config', generic_type=Type[CallbackWithConfig], entry_points=True, description=_callbacks_with_config_description, ) _optimizers_description = ( - 'The optimizers registry is used to register classes that implement the Optimizer interface. ' - + - 'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the ' - + 'specified kwargs to the constructor, along with the model parameters.' + """The optimizers registry is used to register classes that implement the Optimizer interface. + + The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the + specified kwargs to the constructor, along with the model parameters. See optim/ for examples. + + Args: + params (Iterable[torch.nn.Parameter]): The model parameters. + kwargs (Dict[str, Any]): The kwargs to pass to the Optimizer constructor. + + Returns: + Optimizer: The optimizer. + """ ) optimizers = create_registry( 'llmfoundry', @@ -82,10 +114,17 @@ ) _algorithms_description = ( - 'The algorithms registry is used to register classes that implement the Algorithm interface. ' - + - 'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the ' - + 'specified kwargs to the constructor.' + """The algorithms registry is used to register classes that implement the Algorithm interface. + + The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the + specified kwargs to the constructor. See algorithms/ for examples. + + Args: + kwargs (Dict[str, Any]): The kwargs to pass to the Algorithm constructor. + + Returns: + Algorithm: The algorithm. + """ ) algorithms = create_registry( 'llmfoundry', @@ -96,10 +135,17 @@ ) _schedulers_description = ( - 'The schedulers registry is used to register classes that implement the ComposerScheduler interface. ' - + - 'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the ' - + 'specified kwargs to the constructor.' + """The schedulers registry is used to register classes that implement the ComposerScheduler interface. + + The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the + specified kwargs to the constructor. See optim/ for examples. + + Args: + kwargs (Dict[str, Any]): The kwargs to pass to the ComposerScheduler constructor. + + Returns: + ComposerScheduler: The scheduler. + """ ) schedulers = create_registry( 'llmfoundry', @@ -109,12 +155,32 @@ description=_schedulers_description, ) -_models_description = ( - 'The models registry is used to register classes that implement the ComposerModel interface. ' - + - 'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. ' +_tokenizers_description = ( + 'The tokenizers registry is used to register tokenizers that implement the transformers.PreTrainedTokenizerBase interface. ' + - 'Note: This will soon be updated to take in named kwargs instead of a config directly.' + 'The tokenizer will be passed to the build_dataloader() and build_composer_model() methods in train.py.' 
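A hedged sketch of what the new tokenizers registry enables; the class and name below are hypothetical. Once registered, `build_tokenizer` resolves the name through the registry rather than falling back to `AutoTokenizer`:

```python
from transformers import PreTrainedTokenizerFast

from llmfoundry import registry


class MyCustomTokenizer(PreTrainedTokenizerFast):
    """Hypothetical wrapper; any PreTrainedTokenizerBase subclass works."""


registry.tokenizers.register('my-custom-tokenizer', func=MyCustomTokenizer)
```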
+) +tokenizers = create_registry( + 'llmfoundry', + 'tokenizers', + generic_type=Type[PreTrainedTokenizerBase], + entry_points=True, + description=_tokenizers_description, +) + +_models_description = ( + """The models registry is used to register classes that implement the ComposerModel interface. + + The model constructor should accept a PreTrainedTokenizerBase named `tokenizer`, and the rest of its constructor kwargs. + See models/ for examples. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer. + kwargs (Dict[str, Any]): The kwargs to pass to the Composer + + Returns: + ComposerModel: The model. + """ ) models = create_registry( 'llmfoundry', @@ -125,9 +191,19 @@ ) _dataloaders_description = ( - 'The dataloaders registry is used to register functions that create a DataSpec. The function should take ' - + - 'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.' + """The dataloaders registry is used to register functions that create a DataSpec given a config. + + The function should take a PreTrainedTokenizerBase, a device batch size, and the rest of its constructor kwargs, + and return a DataSpec. See data/ for examples. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer + device_batch_size (Union[int, float]): The device batch size. + kwargs (Dict[str, Any]): The kwargs to pass to the builder function. + + Returns: + DataSpec: The dataspec. + """ ) dataloaders = create_registry( 'llmfoundry', @@ -140,14 +216,19 @@ ) _dataset_replication_validators_description = ( - """Validates the dataset replication args. + """The dataset_replication_validators registry is used to register functions that validate replication factor. + + The function should return the replication factor and the dataset device batch size. See data/ for examples. + Args: cfg (DictConfig): The dataloader config. tokenizer (PreTrainedTokenizerBase): The tokenizer device_batch_size (Union[int, float]): The device batch size. + Returns: replication_factor (int): The replication factor for dataset. - dataset_batch_size (int): The dataset device batch size.""" + dataset_batch_size (int): The dataset device batch size. + """ ) dataset_replication_validators = create_registry( 'llmfoundry', @@ -160,14 +241,19 @@ ) _collators_description = ( - """Returns the data collator. + """The collators registry is used to register functions that create the collate function for the DataLoader. + + See data/ for examples. + Args: cfg (DictConfig): The dataloader config. tokenizer (PreTrainedTokenizerBase): The tokenizer dataset_batch_size (Union[int, float]): The dataset device batch size. + Returns: collate_fn (Any): The collate function. - dataloader_batch_size (int): The batch size for dataloader. In case of packing, this might be the packing ratio times the dataset device batch size.""" + dataloader_batch_size (int): The batch size for dataloader. In case of packing, this might be the packing ratio times the dataset device batch size. + """ ) collators = create_registry( 'llmfoundry', @@ -179,12 +265,17 @@ ) _data_specs_description = ( - """Returns the get_data_spec function. + """The data_specs registry is used to register functions that create a DataSpec given a dataloader. + + See data/ for examples. + Args: dl (Union[Iterable, TorchDataloader): The dataloader. dataset_cfg (DictConfig): The dataset config. + Returns: - dataspec (DataSpec): The dataspec.""" + dataspec (DataSpec): The dataspec. 
+ """ ) data_specs = create_registry( 'llmfoundry', @@ -196,7 +287,17 @@ ) _metrics_description = ( - 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.' + """The metrics registry is used to register classes that implement the torchmetrics.Metric interface. + + The metric will be passed to the metrics arg of the Trainer. The metric will be constructed by directly passing along the + specified kwargs to the constructor. See metrics/ for examples. + + Args: + kwargs (Dict[str, Any]): The kwargs to pass to the Metric constructor. + + Returns: + Metric: The metric. + """ ) metrics = create_registry( 'llmfoundry', @@ -206,6 +307,88 @@ description=_metrics_description, ) +_icl_datasets_description = ( + """The ICL datasets registry is used to register classes that implement the InContextLearningDataset interface. + + The dataset will be constructed along with an Evaluator. The dataset will be constructed by directly passing along the + specified kwargs to the constructor. See eval/ for examples. + + Args: + kwargs (Dict[str, Any]): The kwargs to pass to the Dataset constructor. + + Returns: + InContextLearningDataset: The dataset. + """ +) +icl_datasets = create_registry( + 'llmfoundry', + 'icl_datasets', + # TODO: Change type from Dataset to + # llmfoundry.eval.InContextLearningDataset. + # Using ICL dataset here introduces a circular import dependency between + # the registry and eval packages right now, thus needs some refactoring. + generic_type=Type[Dataset], + entry_points=True, + description=_icl_datasets_description, +) + +_config_transforms_description = ( + """The config_transforms registry is used to register functions that transform the training config + + The config will be transformed before it is used anywhere else. Note: By default ALL registered transforms will be applied to the train config + and NONE to the eval config. Each transform should return the modified config. See utils/config_utils.py for examples. + + Args: + cfg (Dict[str, Any]): The training config. + + Returns: + cfg (Dict[str, Any]): The modified training config. + """ +) +config_transforms = create_registry( + 'llmfoundry', + 'config_transforms', + generic_type=Callable[[Dict[str, Any]], Dict[str, Any]], + entry_points=True, + description=_config_transforms_description, +) + +_load_planners_description = ( + """The load_planners registry is used to register classes that implement the LoadPlanner interface. + + The LoadPlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to load distributed checkpoints. + + Returns: + LoadPlanner: The load planner. + """ +) + +load_planners = create_registry( + 'llmfoundry', + 'load_planners', + generic_type=Type[LoadPlanner], + entry_points=True, + description=_load_planners_description, +) + +_save_planners_description = ( + """The save_planners registry is used to register classes that implement the SavePlanner interface. + + The savePlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to save distributed checkpoints. + + Returns: + SavePlanner: The save planner. 
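To make the `config_transforms` contract above concrete, here is a hedged sketch of a registered transform; the registry name and the default it injects are illustrative only (the `'amp_bf16'` value mirrors the TrainConfig default):

```python
from typing import Any, Dict

from llmfoundry import registry


def set_default_precision(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Example transform: fill in a precision if the YAML omitted it."""
    cfg.setdefault('precision', 'amp_bf16')
    return cfg


registry.config_transforms.register(
    'set_default_precision',
    func=set_default_precision,
)
```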
+ """ +) + +save_planners = create_registry( + 'llmfoundry', + 'save_planners', + generic_type=Type[SavePlanner], + entry_points=True, + description=_save_planners_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -213,6 +396,7 @@ 'optimizers', 'algorithms', 'schedulers', + 'tokenizers', 'models', 'dataset_replication_validators', 'collators', @@ -228,4 +412,8 @@ 'attention_classes', 'attention_implementations', 'fcs', + 'icl_datasets', + 'config_transforms', + 'load_planners', + 'save_planners', ] diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py index 1703ed8862..d37c12a555 100644 --- a/llmfoundry/tokenizers/__init__.py +++ b/llmfoundry/tokenizers/__init__.py @@ -1,8 +1,11 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.registry import tokenizers from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper) + __all__ = [ 'TiktokenTokenizerWrapper', ] diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index f087664344..fd0fc5948a 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -90,6 +90,7 @@ def __init__( errors (str, optional): Paradigm to follow when decoding bytes to UTF-8. See [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. Defaults to `"replace"`. + kwargs (Any): Other relevant keyword arguments. """ try: import tiktoken diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index dd43efcdd7..87a08a999d 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -1,10 +1,13 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.registry import config_transforms from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, + build_eval_loaders, build_evaluators, build_icl_data_and_gauntlet, build_icl_evaluators, @@ -59,9 +62,16 @@ experimental_function, ) +config_transforms.register( + 'update_batch_size_info', + func=update_batch_size_info, +) + __all__ = [ + 'add_metrics_to_eval_loaders', 'build_algorithm', 'build_callback', + 'build_eval_loaders', 'build_evaluators', 'build_icl_data_and_gauntlet', 'build_icl_evaluators', diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 73eb026d98..a1d84601b3 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import functools import logging import os @@ -26,6 +27,7 @@ from composer.utils import dist from omegaconf import DictConfig from omegaconf import OmegaConf as om +from torch.distributed.checkpoint import LoadPlanner, SavePlanner from torch.optim.optimizer import Optimizer from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -33,9 +35,9 @@ from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.eval.datasets.in_context_learning_evaluation import \ - get_icl_task_dataloader -from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +from llmfoundry.eval.datasets.in_context_learning_evaluation import ( + get_icl_task_dataloader, +) from llmfoundry.utils.config_utils import to_dict_container, to_list_container from 
llmfoundry.utils.registry_utils import construct_from_registry @@ -62,7 +64,7 @@ def build_evaluators( eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], *, tokenizer: PreTrainedTokenizerBase, - device_eval_batch_size: int, + device_eval_batch_size: Union[int, float], icl_seq_len: int, icl_subset_num_batches: Optional[int], ) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]: @@ -78,6 +80,10 @@ def build_evaluators( logger_keys = [] eval_gauntlet_callback = None if icl_tasks_config is not None: + if not isinstance(device_eval_batch_size, int): + raise ValueError( + 'device_eval_batch_size should be an int for icl tasks.', + ) icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet( icl_tasks_config, eval_gauntlet_config, @@ -94,7 +100,7 @@ def build_evaluators( def build_eval_loaders( eval_loader_config: Union[Dict[str, Any], List[Dict[str, Any]]], tokenizer: PreTrainedTokenizerBase, - device_eval_batch_size: int, + device_eval_batch_size: Union[int, float], ) -> List[Evaluator]: evaluators: List[Evaluator] = [] if isinstance(eval_loader_config, list): @@ -181,6 +187,46 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb +def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner: + """Builds a load planner from the registry. + + Args: + name (str): Name of the load planner to build. + kwargs (Any): Other relevant keyword arguments. + + Returns: + LoadPlanner: The load planner. + """ + return construct_from_registry( + name=name, + registry=registry.load_planners, + partial_function=True, + pre_validation_function=LoadPlanner, + post_validation_function=None, + kwargs=kwargs, + ) + + +def build_save_planner(name: str, **kwargs: Any) -> SavePlanner: + """Builds a save planner from the registry. + + Args: + name (str): Name of the save planner to build. + kwargs (Any): Other relevant keyword arguments. + + Returns: + savePlanner: The save planner. + """ + return construct_from_registry( + name=name, + registry=registry.save_planners, + partial_function=True, + pre_validation_function=SavePlanner, + post_validation_function=None, + kwargs=kwargs, + ) + + def build_composer_model( name: str, cfg: Dict[str, Any], @@ -248,7 +294,7 @@ def build_callback( raise ValueError( f'`train_config` is a reserved keyword for callbacks with config. 
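Tying the planner registry and the new builder together, a hedged end-to-end sketch; the `torch_default` name is an example choice, and `DefaultLoadPlanner` is torch's stock planner rather than anything shipped by this PR:

```python
from torch.distributed.checkpoint import DefaultLoadPlanner

from llmfoundry import registry
from llmfoundry.utils.builders import build_load_planner

registry.load_planners.register('torch_default', func=DefaultLoadPlanner)
load_planner = build_load_planner('torch_default')  # -> DefaultLoadPlanner()
```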
Please remove it from the kwargs.', ) - kwargs['train_config'] = train_config + kwargs['train_config'] = copy.deepcopy(train_config) registry_to_use = registry.callbacks_with_config return construct_from_registry( @@ -461,8 +507,15 @@ def build_tokenizer( with dist.local_rank_zero_download_and_wait(signal_file_path): pass - if tokenizer_name.startswith('tiktoken'): - tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) + if tokenizer_name in registry.tokenizers: + tokenizer = construct_from_registry( + name=tokenizer_name, + registry=registry.tokenizers, + partial_function=True, + pre_validation_function=PreTrainedTokenizerBase, + post_validation_function=None, + kwargs=tokenizer_kwargs, + ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, @@ -545,22 +598,10 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg["icl_task_type"]}.', ) - if 'prompt_string' not in icl_cfg: - icl_cfg['prompt_string'] = '' - if 'example_delimiter' not in icl_cfg: - icl_cfg['example_delimiter'] = '\n' - if 'continuation_delimiter' not in icl_cfg: - icl_cfg['continuation_delimiter'] = ' ' if 'max_seq_len' not in icl_cfg: icl_cfg['max_seq_len'] = default_max_seq_len if 'batch_size' not in icl_cfg: icl_cfg['batch_size'] = default_batch_size - if 'pass_at_k' not in icl_cfg: - icl_cfg['pass_at_k'] = 1 - if 'fewshot_random_seed' not in icl_cfg: - icl_cfg['fewshot_random_seed'] = 1234 - if 'generations_per_sample' not in icl_cfg: - icl_cfg['generations_per_sample'] = 1 if 'num_beams' in icl_cfg: raise ValueError( @@ -579,6 +620,7 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): pad_tok_id = tokenizer.eos_token_id else: pad_tok_id = tokenizer.pad_token_id + label = f'{icl_cfg["label"]}/{num_fewshot}-shot' metric_names = list(icl_cfg['metric_names']) # TODO: fix Composer bug when copying local paths and destination exists @@ -589,38 +631,51 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): hf_parsing_map = icl_cfg.get('hf_parsing_map', {}) hf_loading_vars = icl_cfg.get('hf_loading_vars', {}) - early_stopping_criteria = icl_cfg.get( 'early_stopping_criteria', - None, + [], ) + # TODO: fix manual removal of non-constructor fields + icl_constructor_kwargs = copy.deepcopy(icl_cfg) + icl_constructor_kwargs.pop('label', None) + icl_constructor_kwargs.pop('metric_names', None) + icl_constructor_kwargs.pop('icl_task_type', None) + icl_constructor_kwargs.pop('batch_size', None) + icl_constructor_kwargs.pop('has_categories', None) + + # Add custom constructor arguments + icl_constructor_kwargs['pad_tok_id'] = pad_tok_id + icl_constructor_kwargs['num_fewshot'] = num_fewshot + + # Support backwards compatibility for the naming of "prelimiter" as "question_prelimiter" + if 'question_prelimiter' in icl_constructor_kwargs: + if 'prelimiter' in icl_constructor_kwargs: + raise ValueError( + 'Both "question_prelimiter" and "prelimiter" are specified in the ICL task config. 
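Returning to the tokenizer dispatch above: a name present in the registry is built from it, and anything else still falls back to `AutoTokenizer.from_pretrained`. A usage sketch (the kwargs are illustrative and the `tiktoken` package must be installed):

```python
from llmfoundry.utils.builders import build_tokenizer

tokenizer = build_tokenizer(
    tokenizer_name='tiktoken',  # registered name -> TiktokenTokenizerWrapper
    tokenizer_kwargs={
        'encoding_name': 'cl100k_base',
        'model_max_length': 2048,
    },
)
```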
' + + + 'Please only specify one of them, as they map to the same argument.', + ) + else: + icl_constructor_kwargs['prelimiter' + ] = icl_constructor_kwargs.pop( + 'question_prelimiter', + ) + assert early_stopping_criteria is None or isinstance( early_stopping_criteria, list, ) + dataloaders = get_icl_task_dataloader( - icl_cfg['icl_task_type'], - icl_cfg['dataset_uri'], - tokenizer, + icl_task_type=icl_cfg['icl_task_type'], + dataset_uri=icl_cfg['dataset_uri'], + tokenizer=tokenizer, batch_size=icl_cfg['batch_size'], - max_seq_len=icl_cfg['max_seq_len'], - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=icl_cfg['prompt_string'], - example_delimiter=icl_cfg['example_delimiter'], hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=icl_cfg['continuation_delimiter'], - question_prelimiter=icl_cfg.get('question_prelimiter', ''), - destination_path=destination_path, - fewshot_random_seed=icl_cfg['fewshot_random_seed'], - pass_at_k=icl_cfg['pass_at_k'], - generations_per_sample=icl_cfg['generations_per_sample'], has_categories=icl_cfg.get('has_categories', False), - cot_delimiter=icl_cfg.get('cot_delimiter', ''), - generation_kwargs=icl_cfg.get('generation_kwargs', {}), - early_stopping_criteria=early_stopping_criteria, - do_normalization=icl_cfg.get('do_normalization', True), + destination_path=destination_path, + kwargs=icl_constructor_kwargs, ) if 'has_categories' in icl_cfg and icl_cfg[ 'has_categories'] and isinstance(dataloaders, dict): diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 905afd6edb..5c65a7475e 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -177,6 +177,7 @@ def _convert_weight_to_ft_each( tensor_name (str): Name of the weight tensor. Used in naming the output file. config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters. data (np.ndarray): Tensor data in np.ndarray format. + np_weight_data_type (np.dtype): Data type of the numpy array `data`. Returns: None: Writes to a file in `save_dir`. 
File name is based on the `tensor_name` diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index b6a5acf6d9..eb54fabc3d 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -30,6 +30,7 @@ from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights +from llmfoundry.registry import config_transforms log = logging.getLogger(__name__) @@ -48,7 +49,7 @@ class EvalConfig: # Eval Config required parameters: models: List[Dict[str, Any]] = MISSING max_seq_len: int = MISSING - device_eval_batch_size: int = MISSING + device_eval_batch_size: Union[int, float] = MISSING # Eval Config optional parameters: code_paths: Optional[List[str]] = None @@ -67,6 +68,7 @@ class EvalConfig: # Logging parameters python_log_level: Optional[str] = 'debug' loggers: Optional[Dict[str, Any]] = None + console_log_interval: Union[int, str] = '1ba' log_config: bool = True # Model/run parameters @@ -99,12 +101,14 @@ class TrainConfig: optimizer: Dict[str, Any] = MISSING scheduler: Dict[str, Any] = MISSING train_loader: Dict[str, Any] = MISSING - device_train_batch_size: int = MISSING - device_eval_batch_size: int = MISSING + device_train_batch_size: Union[int, float] = MISSING + device_eval_batch_size: Union[int, float] = MISSING max_duration: Union[int, str] = MISSING eval_interval: Union[int, str] = MISSING max_seq_len: int = MISSING - seed: int = MISSING + + # Seed + seed: int = 17 # Precision precision: str = 'amp_bf16' @@ -114,7 +118,7 @@ class TrainConfig: # Cuda allocation configuration max_split_size_mb: Optional[int] = None - expandable_segments: bool = False + expandable_segments: bool = True cuda_load_lazy: bool = False # Distributed training parameters @@ -157,10 +161,13 @@ class TrainConfig: load_strict_model_weights: bool = True load_ignore_keys: Optional[List[str]] = None save_ignore_keys: Optional[List[str]] = None + only_hf_checkpoint: bool = False + only_composer_checkpoint: bool = False # Dataloader - device_train_microbatch_size: Union[str, int] = 'auto' + device_train_microbatch_size: Union[str, int, float] = 'auto' global_train_batch_size: Optional[int] = None + spin_dataloaders: bool = True # Eval dataloader eval_subset_num_batches: int = -1 @@ -169,6 +176,7 @@ class TrainConfig: # Metadata metadata: Optional[Dict[str, Any]] = None + flatten_metadata: bool = True run_name: Optional[str] = None # Resumption @@ -180,6 +188,10 @@ class TrainConfig: # Variables to ignore variables: Optional[Dict[str, Any]] = None + # Fields created by `update_batch_size_info` + n_gpus: int = MISSING + device_train_grad_accum: Union[str, int] = MISSING + TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)} @@ -233,16 +245,63 @@ def to_container( T = TypeVar('T') +def apply_transforms_to_config( + cfg: Dict[str, Any], + transforms: Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]], + List[str], str]], +) -> Dict[str, Any]: + """Applies a list of transforms to a config. + + Args: + cfg (Dict[str, Any]): The config to transform. + transforms (Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]], List[str], str]]): A list of + transform functions or strings representing transform functions to apply to the config. If a single string + with the value ``all`` is provided, all registered transforms will be applied. + + Returns: + Dict[str, Any]: The transformed config. 
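A small usage sketch of `apply_transforms_to_config` with an inline callable; a list of registered names, or the string `'all'`, works the same way. The transform itself is an example (the seed value mirrors the new TrainConfig default):

```python
from typing import Any, Dict

from llmfoundry.utils.config_utils import apply_transforms_to_config


def add_default_seed(cfg: Dict[str, Any]) -> Dict[str, Any]:
    cfg.setdefault('seed', 17)
    return cfg


cfg = apply_transforms_to_config({'max_seq_len': 2048}, [add_default_seed])
assert cfg['seed'] == 17
```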
+ """ + if transforms is None or ( + isinstance(transforms, list) and len(transforms) == 0 + ): + return cfg + + transform_functions = [] + if isinstance(transforms, list): + for transform in transforms: + if isinstance(transform, str): + transform_functions.append(config_transforms.get(transform)) + elif callable(transform): + transform_functions.append(transform) + else: + raise ValueError( + f'Invalid transform: {transform}. Must be a string or callable.', + ) + elif isinstance(transforms, str) and transforms == 'all': + transform_functions = [ + config_transforms.get(transform) + for transform in config_transforms.get_all() + ] + else: + raise ValueError( + f'Invalid transforms: {transforms}. Must be a list of strings or callables, or ``all``.', + ) + + for transform in transform_functions: + cfg = transform(cfg) + + return cfg + + def make_dataclass_and_log_config( cfg: DictConfig, dataclass_constructor: Callable[..., T], dataclass_fields: Set[str], - transforms: Optional[List[Callable[[Dict[str, Any]], Dict[str, - Any]]]] = None, + transforms: Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]], + List[str], str]] = None, icl_tasks_required: bool = False, ) -> Tuple[Dict[str, Any], T]: """Converts a DictConfig to a dataclass and creates a logged config.""" - # Resolve all interpolation variables as early as possible unstructured_config = om.to_container(cfg, resolve=True) assert isinstance(unstructured_config, dict) assert all(isinstance(k, str) for k in unstructured_config.keys()) @@ -273,14 +332,14 @@ def make_dataclass_and_log_config( 'icl_tasks must be specified in the config', ) - # Create copy of config for logging - logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config) - # Apply transforms to the unstructured config before constructing dataclass - for transform in transforms or []: - unstructured_config = transform(unstructured_config) + unstructured_config = apply_transforms_to_config( + unstructured_config, + transforms, + ) - logged_cfg.update(unstructured_config, merge=True) + # Create copy of config for logging + logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config) arg_config_keys = set(unstructured_config.keys()) extraneous_keys = set.difference(arg_config_keys, dataclass_fields) @@ -288,12 +347,10 @@ def make_dataclass_and_log_config( if 'variables' not in unstructured_config: unstructured_config['variables'] = {} - for key in extraneous_keys: - warnings.warn( - f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Interpreting {key} as a variable for logging purposes. Top-level variables are deprecated and will not be supported in future releases. Please place any variables under the `variables` key.', - category=DeprecationWarning, + if len(extraneous_keys) > 0: + raise ValueError( + f'Unused parameters {sorted(extraneous_keys)} found in cfg. Please check your yaml to ensure these parameters are necessary. 
Please place any variables under the `variables` key.', ) - unstructured_config['variables'][key] = unstructured_config.pop(key) dataclass_dict_config: DictConfig = om.structured( dataclass_constructor(**unstructured_config), @@ -365,20 +422,20 @@ def calculate_batch_size_info( data_replication_degree: int = 1, ) -> Tuple[Union[int, float], Union[int, float, Literal['auto']], Union[ int, Literal['auto']]]: - if dist.get_world_size() % data_replication_degree != 0: + + world_size = dist.get_world_size() + if world_size % data_replication_degree != 0: raise ValueError( - f'World size {dist.get_world_size()} is not divisible by data replication degree {data_replication_degree}.', + f'World size {world_size} is not divisible by data replication degree {data_replication_degree}.', ) - if global_batch_size % ( - dist.get_world_size() // data_replication_degree - ) != 0: + if global_batch_size % (world_size // data_replication_degree) != 0: raise ValueError( - f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + f'Global batchsize {global_batch_size} is not divisible by {(world_size // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' - + f'to be divisible by world size, {dist.get_world_size()}.', + + f'to be divisible by world size, {world_size}.', ) - device_batch_size = global_batch_size / dist.get_world_size() + device_batch_size = global_batch_size / world_size if device_batch_size == round(device_batch_size): device_batch_size = round(device_batch_size) if device_microbatch_size == 'auto': @@ -399,14 +456,23 @@ def calculate_batch_size_info( return device_batch_size, device_microbatch_size, device_grad_accum -# Coming soon: this conversion math will be done inside Composer Trainer -def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: - data_replication_degree = 1 - device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( - cfg['global_train_batch_size'], - cfg['device_train_microbatch_size'], - data_replication_degree=data_replication_degree, - ) +def update_config_with_batch_size_info( + cfg: Dict[str, Any], + device_train_batch_size: Union[int, float], + device_train_microbatch_size: Union[int, float, Literal['auto']], + device_train_grad_accum: Union[int, Literal['auto']], +) -> Dict[str, Any]: + """Update the config with batch size information. + + Args: + cfg (Dict[str, Any]): The config to update. + device_train_batch_size (Union[int, float]): The batch size of the training dataset for each device. + device_train_microbatch_size (Union[int, float, Literal['auto']]): The microbatch size of the training dataset for each device. + device_train_grad_accum (Union[int, Literal['auto']]): The gradient accumulation settings for each device. + + Returns: + Dict[str, Any]: The updated config. 
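For orientation, the arithmetic behind `calculate_batch_size_info`, written out for a hypothetical 8-GPU, no-replication run; all numbers are illustrative:

```python
import math

world_size = 8
data_replication_degree = 1
global_train_batch_size = 256
device_train_microbatch_size = 16

# The global batch must divide evenly across
# (world_size // data_replication_degree) ranks.
device_train_batch_size = global_train_batch_size // world_size  # 32
device_train_grad_accum = math.ceil(
    device_train_batch_size / device_train_microbatch_size,
)  # 2 gradient-accumulation steps per optimizer step
```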
+ """ cfg['n_gpus'] = dist.get_world_size() cfg['device_train_batch_size'] = device_train_batch_size cfg['device_train_microbatch_size'] = device_train_microbatch_size @@ -421,6 +487,22 @@ def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: return cfg +def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: + data_replication_degree = 1 + device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( + cfg['global_train_batch_size'], + cfg['device_train_microbatch_size'], + data_replication_degree=data_replication_degree, + ) + cfg = update_config_with_batch_size_info( + cfg, + device_train_batch_size, + device_train_microbatch_size, + device_train_grad_accum, + ) + return cfg + + def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): # Restrict model init_device to 'meta' and 'cpu', # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors @@ -451,7 +533,6 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): fsdp_config['sync_module_states'] = True # Set defaults for mixed initialization - fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) # Set ffn_config.device_mesh to fsdp_config.device_mesh @@ -591,8 +672,28 @@ def _process_data_source( # Check for HF path elif 'hf_name' in dataset and dataset['hf_name']: hf_path = dataset['hf_name'] - backend, _, _ = parse_uri(hf_path) - if backend: + backend, _, uc_path = parse_uri(hf_path) + unsupported_file = True + if backend == 'dbfs': + assert cfg_split + from llmfoundry.data.finetuning.tasks import SUPPORTED_EXTENSIONS + possible_files = [ + f'{cfg_split}{ext}' for ext in SUPPORTED_EXTENSIONS + ] + for file in possible_files: + path = os.path.join(uc_path, file) + # Ensure path starts with '/' + if not path.startswith('/'): + path = '/' + path + if _verify_uc_path(path): + data_paths.append(('uc_volume', path, true_split)) + unsupported_file = False + break + if unsupported_file: + log.warning( + f'{hf_path} does not contain a supported file extension.', + ) + elif backend: hf_path = os.path.join(hf_path, cfg_split) if cfg_split else hf_path data_paths.append((backend, hf_path, true_split)) elif os.path.exists(hf_path): @@ -663,3 +764,94 @@ def log_dataset_uri(cfg: Dict[str, Any]) -> None: mlflow.log_input( mlflow.data.meta_dataset.MetaDataset(source, name=split), ) + + +def _verify_uc_path(path: str) -> bool: + """Verify a UC path exists. + + Args: + path (str): UnityCatalog path + Returns: + (bool): If path exists or not + """ + from databricks.sdk.errors.platform import NotFound, PermissionDenied + w = None + try: + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + except ImportError: + log.warning( + 'Cannot verify the path of `UCVolumeDatasetSource` because of missing' + \ + '`databricks-sdk`. Please install `databricks-sdk` via ' + \ + '`pip install -U databricks-sdk`. This does not block creating ' + \ + '`UCVolumeDatasetSource`, but your `UCVolumeDatasetSource` might be invalid.', + ) + return False + except Exception as e: + log.warning( + f'Error occured when attempting to connect with Databricks WorkspaceClient. ' + \ + f'Error details: {str(e)}. This does not block creating `UCVolumeDatasetSource`, ' + \ + f'but your `UCVolumeDatasetSource` might be invalid.', + ) + + if w: + try: + w.files.get_metadata(path) + except (NotFound, PermissionDenied): + try: + # Check if `self.path` points to a valid UC directory. 
+ w.files.get_directory_metadata(path) + return True + except (NotFound, PermissionDenied): + # Neither file nor directory exists, we throw an exception. + return False + except Exception as e: + log.warning( + f'Error occured when verifying path of `UCVolumeDatasetSource`. ' + \ + f'Error details: {str(e)}. This does not block creating `UCVolumeDatasetSource`, ' + \ + f'but your `UCVolumeDatasetSource` might be invalid.', + ) + return False + + +def set_config_overrides( + config: PretrainedConfig, + config_overrides: Dict[str, Any], +): + # set config overrides + for k, v in config_overrides.items(): + if not hasattr(config, k): + raise ValueError( + f'config does not have attribute "{k}" to override ({k}: {v}).', + ) + + attr = getattr(config, k) + # attempt to disallow typos in nested configs + if isinstance(attr, Mapping): + extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] + if extra_keys: + raise ValueError( + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + + f'Expected (a subset of) keys: {list(attr.keys())}.', + ) + getattr(config, k).update(v) + # necessary case to allow for rope_scaling to be overriden in llama config + elif attr is None and isinstance(v, Mapping): + setattr(config, k, {}) + getattr(config, k).update(v) + elif isinstance(attr, PretrainedConfig): + if not isinstance(v, Mapping): + raise ValueError( + f'Expected a dictionary for config override {k}, but got {v}.', + ) + + for _k, _v in v.items(): + if not hasattr(attr, _k): + raise ValueError( + f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', + ) + setattr(attr, _k, _v) + else: + setattr(config, k, v) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 76f378f8c6..140bf8540b 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -28,6 +28,7 @@ 'InputFolderMissingDataError', 'OutputFolderNotEmptyError', 'MisconfiguredHfDatasetError', + 'DatasetTooSmallError', 'RunTimeoutError', ] @@ -348,6 +349,14 @@ def __init__(self, dataset_name: str, split: str) -> None: super().__init__(message, dataset_name=dataset_name, split=split) +class DatasetTooSmallError(UserError): + """Error thrown when the dataset is too small to be processed.""" + + def __init__(self) -> None: + message = f'Your dataset is too small and produced no complete samples during preprocessing. Please provide more data.' + super().__init__(message) + + class RunTimeoutError(InternalError): """Error thrown when a run times out.""" diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 3f7b3a0f55..d3a2ddc0d4 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -243,10 +243,7 @@ def edit_files_for_hf_compatibility( # If the config file exists, the entrypoint files would be specified in the auto map entrypoint_files = set() if config_file_exists: - for key, value in config.get('auto_map', {}).items(): - # Only keep the modeling entrypoints, e.g. 
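Stepping back to `set_config_overrides` added above, a hedged sketch on an offline-constructed Hugging Face config; `GPT2Config` and the values are examples. Unknown attribute names raise, and mapping-valued attributes are updated in place rather than replaced:

```python
from transformers import GPT2Config

from llmfoundry.utils.config_utils import set_config_overrides

config = GPT2Config()
set_config_overrides(config, {'n_positions': 2048, 'resid_pdrop': 0.0})
assert config.n_positions == 2048
```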
AutoModelForCausalLM - if 'model' not in key.lower(): - continue + for value in config.get('auto_map', {}).values(): split_path = value.split('.') if len(split_path) > 1: entrypoint_files.add(split_path[0] + '.py') @@ -279,12 +276,21 @@ def edit_files_for_hf_compatibility( os.path.splitext(os.path.basename(f))[0] for f in files_processed_and_queued } + # Filter out __init__ files + all_relative_imports = { + relative_import for relative_import in all_relative_imports + if relative_import not in {'__init__', 'modeling_mpt'} + } for entrypoint in entrypoint_files: + file_path = os.path.join(folder, entrypoint) + if not os.path.exists(file_path): + continue existing_relative_imports = get_all_relative_imports( os.path.join(folder, entrypoint), ) - # Add in self so we don't create a circular import - existing_relative_imports.add(os.path.splitext(entrypoint)[0]) + # Add in all entrypoints so we don't create a circular import + for sub_entrypoint in entrypoint_files: + existing_relative_imports.add(os.path.splitext(sub_entrypoint)[0]) missing_relative_imports = all_relative_imports - existing_relative_imports add_relative_imports( os.path.join(folder, entrypoint), diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index c11a47929f..9609982fda 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -69,13 +69,13 @@ def download_from_hf_hub( Safetensors weights will be downloaded unless `prefer_safetensors` is set to False. Args: - repo_id (str): The Hugging Face Hub repo ID. + model (str): The Hugging Face Hub repo ID. save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. tokenizer_only (bool): If true, only download tokenizer files. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the - `HUGGING_FACE_HUB_TOKEN` environment variable. + `HF_TOKEN` environment variable. Raises: RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized. @@ -157,7 +157,7 @@ def _recursive_download( Args: session: A requests.Session through which to make requests to the remote server. - url (str): The base URL where the files are located. + base_url (str): The base URL where the files are located. path (str): The path from the base URL to the files to download. The full URL for the download is equal to '/'. save_dir (str): The directory to save downloaded files to. diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 3ea7cc58a7..f96e72b3a2 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -127,6 +127,7 @@ def construct_from_registry( before constructing the item to return. This should throw an exception if validation fails. Defaults to None. post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after constructing the item to return. This should throw an exception if validation fails. Defaults to None. + kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. Raises: ValueError: If the validation functions failed or the registered item is invalid @@ -176,6 +177,7 @@ def import_file(loc: Union[str, Path]) -> ModuleType: """Import module from a file. Used to run arbitrary python code. + Args: name (str): Name of module to load. 
loc (str / Path): Path to the file. diff --git a/mcli/mcli-1b-eval.yaml b/mcli/mcli-1b-eval.yaml index fc72bac974..4bfa301f8e 100644 --- a/mcli/mcli-1b-eval.yaml +++ b/mcli/mcli-1b-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -9,7 +9,7 @@ integrations: command: | cd llm-foundry/scripts/ composer eval/eval.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest name: mpt-1b-eval compute: diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index 512ddc90c8..2dc83d36a9 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -17,7 +17,7 @@ command: | --out_root ./my-copy-c4 --splits train_small val_small \ --concat_tokens 8192 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest name: mpt-1b-ctx-8k-gpus-8 compute: diff --git a/mcli/mcli-1b.yaml b/mcli/mcli-1b.yaml index 9850860358..69b2295011 100644 --- a/mcli/mcli-1b.yaml +++ b/mcli/mcli-1b.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -21,7 +21,7 @@ command: | eval_loader.dataset.split=val_small \ max_duration=100ba \ eval_interval=0 -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest name: mpt-1b-gpus-8 compute: diff --git a/mcli/mcli-benchmark-mpt.yaml b/mcli/mcli-benchmark-mpt.yaml index a4b3f52ba7..7a3ea2cbe9 100644 --- a/mcli/mcli-benchmark-mpt.yaml +++ b/mcli/mcli-benchmark-mpt.yaml @@ -6,12 +6,12 @@ compute: # cluster: TODO # Name of the cluster to use for this run # gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] diff --git a/mcli/mcli-convert-composer-to-hf.yaml b/mcli/mcli-convert-composer-to-hf.yaml index bebdf42926..fefaf8e1a3 100644 --- a/mcli/mcli-convert-composer-to-hf.yaml +++ b/mcli/mcli-convert-composer-to-hf.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: . 
ssh_clone: false # Should be true if using a private repo @@ -13,7 +13,7 @@ command: | --hf_output_path s3://bucket/folder/hf/ \ --output_precision bf16 \ -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest name: convert-composer-hf compute: diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml index 3e24bba9ae..e58d42483a 100644 --- a/mcli/mcli-hf-eval.yaml +++ b/mcli/mcli-hf-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -16,7 +16,7 @@ gpu_num: 8 # gpu_type: # cluster: # replace with your cluster here! -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest # The below is injected as a YAML file: /mnt/config/parameters.yaml parameters: diff --git a/mcli/mcli-hf-generate.yaml b/mcli/mcli-hf-generate.yaml index c3bf6d48cc..02c49d84c3 100644 --- a/mcli/mcli-hf-generate.yaml +++ b/mcli/mcli-hf-generate.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -35,7 +35,7 @@ command: | "Here's a quick recipe for baking chocolate chip cookies: Start by" \ "The best 5 cities to visit in Europe are" -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest name: hf-generate compute: diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index 932d013442..47c163faf8 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -9,7 +9,7 @@ integrations: command: | cd llm-foundry/scripts composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest name: llama2-finetune compute: @@ -36,7 +36,7 @@ parameters: init_device: mixed pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf pretrained: true - # Note: you must have set the HUGGING_FACE_HUB_TOKEN environment variable and have access to the llama2 models + # Note: you must have set the HF_TOKEN environment variable and have access to the llama2 models use_auth_token: true use_flash_attention_2: true diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml index 9a589cbf84..c372014165 100644 --- a/mcli/mcli-openai-eval.yaml +++ b/mcli/mcli-openai-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu,openai] ssh_clone: false # Should be true if using a private repo @@ -16,7 +16,7 @@ gpu_num: # gpu_type: # cluster: # replace with your cluster here! 
-image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest # The below is injected as a YAML file: /mnt/config/parameters.yaml parameters: diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml index b3ad09ca28..a4496503cd 100644 --- a/mcli/mcli-pretokenize-oci-upload.yaml +++ b/mcli/mcli-pretokenize-oci-upload.yaml @@ -1,5 +1,5 @@ name: c4-2k-pre-tokenized -image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.3.1_cu121-latest compute: gpus: 8 # Number of GPUs to use @@ -14,7 +14,7 @@ integrations: - oci-cli==3.23.2 - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: . ssh_clone: false # Should be true if using a private repo diff --git a/pyproject.toml b/pyproject.toml index 53007cafaf..fdbabfff96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,23 +11,44 @@ skip = [ "env", "wandb", "runs", "build", "node_modules" ] include_trailing_comma = true split_on_trailing_comma = true +# Ruff global +[tool.ruff] +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] + +# Ruff linter [tool.ruff.lint] select = [ "C4", - # TODO port pydocstyle - # "D", # pydocstyle "LOG", "PERF", "PLE", "COM812", + "D", # pydocstyle ] -[tool.ruff] -exclude = [ - "build/**", - "docs/**", - "node_modules/**", + +extend-select = ["D404"] # pydocstyle + +ignore = [ + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", + "D400", + "D401", + "D415", ] +[tool.ruff.lint.pydocstyle] +convention = "google" + + # Coverage [tool.coverage.run] parallel = true @@ -79,7 +100,7 @@ reportMissingImports = "none" # Pytest [tool.pytest.ini_options] # By default, skip gpu tests -addopts = "--tb=short -m 'not gpu'" +addopts = "--tb=short -m 'not gpu' --color=yes" markers = [ # For distributed testing @@ -506,8 +527,3 @@ ignore_patterns = [ "wandb/**/*.py", "build/**/*.py", ] - -[tool.pydocstyle] -convention="google" -add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" -add_select="D404" diff --git a/scripts/data_prep/README.md b/scripts/data_prep/README.md index 7881298b2f..3601cc865f 100644 --- a/scripts/data_prep/README.md +++ b/scripts/data_prep/README.md @@ -35,6 +35,23 @@ python convert_dataset_json.py \ Where `--path` can be a single json file, or a folder containing json files. `--split` denotes the intended split (hf defaults to `train`). +### Raw text files + +Using the `convert_text_to_mds.py` script, we convert a [text file](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt) containing the complete works of William Shakespeare. + + +```bash +# Convert json dataset to StreamingDataset format +mkdir shakespeare && cd shakespeare +curl -O https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt +cd .. 
+python convert_text_to_mds.py \ + --output_folder my-copy-shakespeare \ + --input_folder shakespeare \ + --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b \ + --compression zstd +``` + ## Converting a finetuning dataset Using the `convert_finetuning_dataset.py` script you can run a command such as: diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index d7aaa52193..3b893868b2 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -2,28 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """Streaming dataset conversion scripts for C4 and The Pile.""" -import json -import os -import platform from argparse import ArgumentParser, Namespace -from dataclasses import dataclass, field -from enum import Enum -from typing import Dict, Iterable, Optional, Union -import datasets as hf_datasets -import psutil -from streaming import MDSWriter -from torch.utils.data import DataLoader, Dataset, IterableDataset -from tqdm import tqdm -from transformers import PreTrainedTokenizerBase - -from llmfoundry.data import ConcatTokensDataset, NoConcatDataset -from llmfoundry.utils.builders import build_tokenizer - - -class ConcatMode(Enum): - NO_CONCAT = 'NO_CONCAT' - CONCAT_TOKENS = 'CONCAT_TOKENS' +from llmfoundry.command_utils import convert_dataset_hf_from_args def parse_args() -> Namespace: @@ -62,394 +43,22 @@ def parse_args() -> Namespace: parser.add_argument('--num_workers', type=int, required=False, default=None) parsed = parser.parse_args() - - if parsed.tokenizer_kwargs is not None: - parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs) - else: - parsed.tokenizer_kwargs = {} - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - # Make sure we have needed concat options - if ( - parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None - ): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer', - ) - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -@dataclass -class DataSplitConstants: - hf_split: str - folder_split: str - raw_samples: Optional[int] - truncated_samples: Union[int, None] - - -@dataclass -class DatasetConstants: - chars_per_sample: int - chars_per_token: int - splits: Dict[str, DataSplitConstants] = field(default_factory=dict) - - def __iter__(self): - for v in self.splits.values(): - yield v - - -class TrainSmallConstants(DataSplitConstants): - - def __init__( - self, - hf_split: str = 'train', - folder_split: str = 'train_small', - raw_samples: int = 100000, - truncated_samples: int = 100000, - ): - super().__init__(hf_split, folder_split, raw_samples, truncated_samples) - - -class ValSmallConstants(DataSplitConstants): - - def __init__( - self, - hf_split: str = 'validation', - folder_split: str = 'val_small', - raw_samples: int = 10000, - truncated_samples: int = 10000, - ): - super().__init__(hf_split, folder_split, raw_samples, truncated_samples) - - -class ValXSmallConstants(DataSplitConstants): - - def __init__( - self, - hf_split: str = 'validation', - folder_split: str = 'val_xsmall', - raw_samples: int = 3000, - truncated_samples: int = 
3000, - ): - super().__init__(hf_split, folder_split, raw_samples, truncated_samples) - - -pileconstants = DatasetConstants( - chars_per_sample=6212, # Computed over validation set - chars_per_token=4, # OpenAI estimate -) -pileconstants.splits['train'] = DataSplitConstants( - hf_split='train', - folder_split='train', - raw_samples=210607728, - truncated_samples=None, -) -pileconstants.splits['train_small'] = DataSplitConstants( - hf_split='train', - folder_split='train_small', - raw_samples=100000, - truncated_samples=100000, -) -pileconstants.splits['val'] = DataSplitConstants( - hf_split='validation', - folder_split='val', - raw_samples=214670, - truncated_samples=None, -) -pileconstants.splits['val_small'] = DataSplitConstants( - hf_split='validation', - folder_split='val_small', - raw_samples=10000, - truncated_samples=10000, -) -pileconstants.splits['val_xsmall'] = DataSplitConstants( - hf_split='validation', - folder_split='val_xsmall', - raw_samples=3000, - truncated_samples=3000, -) - -c4constants = DatasetConstants( - chars_per_sample=2163, # Computed over validation set - chars_per_token=4, # OpenAI estimate -) -c4constants.splits['train'] = DataSplitConstants( - hf_split='train', - folder_split='train', - raw_samples=364868892, - truncated_samples=None, -) -c4constants.splits['train_small'] = DataSplitConstants( - hf_split='train', - folder_split='train_small', - raw_samples=100000, - truncated_samples=100000, -) -c4constants.splits['val'] = DataSplitConstants( - hf_split='validation', - folder_split='val', - raw_samples=364608, - truncated_samples=None, -) -c4constants.splits['val_small'] = DataSplitConstants( - hf_split='validation', - folder_split='val_small', - raw_samples=10000, - truncated_samples=10000, -) -c4constants.splits['val_xsmall'] = DataSplitConstants( - hf_split='validation', - folder_split='val_xsmall', - raw_samples=3000, - truncated_samples=3000, -) -c4constants.splits['val_xxsmall'] = DataSplitConstants( - hf_split='validation', - folder_split='val_xxsmall', - raw_samples=100, - truncated_samples=100, -) - -CONSTS = {'c4': c4constants, 'the_pile': pileconstants} - - -def build_hf_dataset( - dataset_name: str, - split: str, - mode: ConcatMode, - max_length: Optional[int] = None, - bos_text: str = '', - eos_text: str = '', - no_wrap: bool = False, - tokenizer: PreTrainedTokenizerBase = None, - data_subset: Union[str, None] = None, -) -> IterableDataset: - """Build an IterableDataset over the HF C4 or pile source data. - - Args: - dataset_name (str): Dataset name - split (str): Split name. - mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS - max_length (int): The length of concatenated tokens - bos_text (str): text to insert at the beginning of each sequence - eos_text (str): text to insert at the end of each sequence - no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries - tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use - data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. - Typically "all" (The Pile) or "en" (c4). - - Returns: - An IterableDataset. 
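The dataset constants above exist mainly to drive the progress bar: the helper defined just below estimates how many concatenated samples a split will yield from character counts alone. A worked example of that arithmetic using the C4 validation constants, assuming `--concat_tokens 2048` (the concrete numbers are illustrative):

```python
# C4 constants from this script: ~2163 chars/sample, ~4 chars/token (OpenAI estimate).
chars_per_sample, chars_per_token = 2163, 4
total_samples, max_length = 364_608, 2048  # c4 'val' split, --concat_tokens 2048

est_tokens_per_sample = chars_per_sample // chars_per_token          # 540
est_concat_samples = total_samples * est_tokens_per_sample // max_length
print(est_concat_samples)  # 96136 -- a progress-bar estimate, not an exact sample count
```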
- """ - hf_dataset = hf_datasets.load_dataset( - path=dataset_name, - name=data_subset, - split=split, - streaming=True, - ) - if mode == ConcatMode.NO_CONCAT: - dataset = NoConcatDataset(hf_dataset) - else: - if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase', - ) - if max_length is None: - raise ValueError(f'max_length must be set.') - if bos_text + eos_text == '': - test_tokens = tokenizer('test') - if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: - tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' - tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' - tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' - tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' - tok_error_msg += '--bos_text=<|endoftext|>.' - raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset( - hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap, - ) - return dataset - - -def _est_progress_denominator( - total_samples: int, - chars_per_sample: int, - chars_per_token: int, - mode: ConcatMode, - max_length: int, -): - est_tokens_per_sample = chars_per_sample // chars_per_token - if mode == ConcatMode.NO_CONCAT: - return total_samples - elif mode == ConcatMode.CONCAT_TOKENS: - return total_samples * est_tokens_per_sample // max_length - - -def build_dataloader( - dataset: Dataset, - batch_size: int, - num_workers: Optional[int], -) -> DataLoader: - if num_workers is None: - # Multiple workers is only supported on linux machines - if 'linux' or 'macos' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) - else: - num_workers = 0 - - # If using multiple workers, configure each worker to prefetch as many samples as it can, up to - # the aggregate device batch size - # If not using workers, the torch DataLoader expects the default value for prefetch_factor, - # which non-intuitively must be 2. - prefetch_factor = max( - 1, - 2 * batch_size // num_workers, - ) if num_workers > 0 else 2 - - return DataLoader( - dataset=dataset, - sampler=None, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) - - -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. - - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - - -def main(args: Namespace) -> None: - """Main: create C4/pile streaming dataset. - - Args: - args (Namespace): Commandline arguments. - """ - try: - dataset_constants = CONSTS[args.dataset] - except KeyError: - raise ValueError( - f'Constants for dataset "{args.dataset}" not found. 
Currently only "the_pile" and "c4" are supported.', - ) - - if args.concat_tokens is not None: - mode = ConcatMode.CONCAT_TOKENS - tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs) - # we will enforce length, so suppress warnings about sequences too long for the model - tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'bytes'} - else: - mode = ConcatMode.NO_CONCAT - tokenizer = None - columns = {'text': 'str'} - - for split_name in args.splits: - try: - split = dataset_constants.splits[split_name] - except KeyError: - raise KeyError(f'Constants not defined for split {split_name}.') - hf_split = split.hf_split - folder_split = split.folder_split - expected_num_samples = split.raw_samples - truncate_num_samples = split.truncated_samples - # Only generate the splits requested - if folder_split not in args.splits: - continue - - # Get samples - dataset = build_hf_dataset( - dataset_name=args.dataset, - data_subset=args.data_subset, - split=hf_split, - mode=mode, - max_length=args.concat_tokens, - bos_text=args.bos_text, - eos_text=args.eos_text, - no_wrap=args.no_wrap, - tokenizer=tokenizer, - ) - loader = build_dataloader( - dataset=dataset, - batch_size=512, - num_workers=args.num_workers, - ) - samples = generate_samples( - loader, - truncate_num_samples=truncate_num_samples, - ) - - if expected_num_samples is not None: - denominator = truncate_num_samples if truncate_num_samples is not None else _est_progress_denominator( - total_samples=expected_num_samples, - chars_per_sample=dataset_constants.chars_per_sample, - chars_per_token=dataset_constants.chars_per_token, - mode=mode, - max_length=args.concat_tokens, - ) - else: - denominator = None - - # Write samples - print(f'Converting {folder_split} to MDS format...') - print( - f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.', - ) - with MDSWriter( - columns=columns, - out=os.path.join(args.out_root, folder_split), - compression=args.compression, - ) as out: - if denominator is not None: - for sample in tqdm( - samples, - desc=folder_split, - total=denominator, - ): - out.write(sample) - else: - for sample in tqdm(samples, desc=folder_split): - out.write(sample) - - if __name__ == '__main__': - main(parse_args()) + args = parse_args() + convert_dataset_hf_from_args( + dataset=args.dataset, + data_subset=args.data_subset, + splits=args.splits, + out_root=args.out_root, + compression=args.compression, + concat_tokens=args.concat_tokens, + tokenizer=args.tokenizer, + tokenizer_kwargs=args.tokenizer_kwargs, + bos_text=args.bos_text, + eos_text=args.eos_text, + no_wrap=args.no_wrap, + num_workers=args.num_workers, + ) diff --git a/scripts/data_prep/convert_dataset_json.py b/scripts/data_prep/convert_dataset_json.py index fb117ddef3..5a6927ac75 100644 --- a/scripts/data_prep/convert_dataset_json.py +++ b/scripts/data_prep/convert_dataset_json.py @@ -2,24 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """Streaming dataset conversion scripts for json files.""" -import os from argparse import ArgumentParser, Namespace -from enum import Enum -from glob import glob -from typing import Dict, Iterable, Optional -import datasets as hf_datasets -from streaming import MDSWriter -from torch.utils.data import DataLoader, IterableDataset -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from llmfoundry.data import ConcatTokensDataset, NoConcatDataset - - -class ConcatMode(Enum): - NO_CONCAT = 'NO_CONCAT' - CONCAT_TOKENS = 
'CONCAT_TOKENS' +from llmfoundry.command_utils import convert_dataset_json_from_args def parse_args() -> Namespace: @@ -46,169 +31,19 @@ def parse_args() -> Namespace: parser.add_argument('--no_wrap', default=False, action='store_true') parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.split)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - # Make sure we have needed concat options - if ( - parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None - ): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer', - ) - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -def build_hf_dataset( - path: str, - split: str, - mode: ConcatMode, - max_length: Optional[int] = None, - bos_text: str = '', - eos_text: str = '', - no_wrap: bool = False, - tokenizer: PreTrainedTokenizerBase = None, -) -> IterableDataset: - """Build an IterableDataset over the HF C4 or pile source data. - - Args: - dataset_name (str): Dataset name - split (str): Split name. - mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS - max_length (int): The length of concatenated tokens - bos_text (str): text to insert at the beginning of each sequence - eos_text (str): text to insert at the end of each sequence - no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries - tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use - data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. - Typically "all" (The Pile) or "en" (c4). - - Returns: - An IterableDataset. - """ - if os.path.isdir(path): - data_files = glob(f'{path}/*') - else: - data_files = path - - hf_dataset = hf_datasets.load_dataset( - 'json', - data_files=data_files, - split=split, - ) - - if mode == ConcatMode.NO_CONCAT: - dataset = NoConcatDataset(hf_dataset) - else: - if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase', - ) - if max_length is None: - raise ValueError(f'max_length must be set.') - if bos_text + eos_text == '': - test_tokens = tokenizer('test') - if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: - tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' - tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' - tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' - tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' - tok_error_msg += '--bos_text=<|endoftext|>.' - raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset( - hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap, - ) - return dataset - - -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. 
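The tokenizer guard above is worth running by hand before relying on token concatenation with empty `--bos_text`/`--eos_text`: if the tokenizer inserts neither a BOS nor an EOS token, concatenated documents run together with no separator. A small sketch of the same probe, assuming `transformers` is installed; the tokenizer name is just an example:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')  # example tokenizer
ids = tok('test')['input_ids']

# Mirrors the scripts' guard: at least one of BOS/EOS must be inserted automatically,
# otherwise pass explicit separator text such as --eos_text '<|endoftext|>'.
inserts_separator = ids[0] == tok.bos_token_id or ids[-1] == tok.eos_token_id
print('tokenizer inserts a separator token:', inserts_separator)
```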
- - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - - -def main(args: Namespace) -> None: - """Main: create C4/pile streaming dataset. - - Args: - args (Namespace): Commandline arguments. - """ - if args.concat_tokens is not None: - mode = ConcatMode.CONCAT_TOKENS - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) - # we will enforce length, so suppress warnings about sequences too long for the model - tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'bytes'} - else: - mode = ConcatMode.NO_CONCAT - tokenizer = None - columns = {'text': 'str'} - - # Get samples - dataset = build_hf_dataset( +if __name__ == '__main__': + args = parse_args() + convert_dataset_json_from_args( path=args.path, + out_root=args.out_root, + compression=args.compression, + concat_tokens=args.concat_tokens, split=args.split, - mode=mode, - max_length=args.concat_tokens, + tokenizer=args.tokenizer, bos_text=args.bos_text, eos_text=args.eos_text, no_wrap=args.no_wrap, - tokenizer=tokenizer, ) - - print('here') - - # Write samples - print(f'Converting to MDS format...') - print( - f'Note that the progress bar is based on the dataset length before tokenization.', - ) - print(f'It will finish at a value below 100% if tokenizing') - with MDSWriter( - columns=columns, - out=os.path.join(args.out_root), - compression=args.compression, - ) as out: - for sample in tqdm(dataset): - out.write(sample) - - -if __name__ == '__main__': - main(parse_args()) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index d871761803..277a8c1ffc 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -4,571 +4,19 @@ import logging import os import re -import time -import urllib.parse -from argparse import ArgumentParser, Namespace -from collections import namedtuple -from concurrent.futures import ProcessPoolExecutor -from typing import Iterable, List, Optional, Tuple, Union -from uuid import uuid4 +from argparse import ArgumentParser -import google.protobuf.any_pb2 as any_pb2 -import lz4.frame -import pandas as pd -import pyarrow as pa -import pyspark.sql.connect.proto as pb2 -import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 -import requests -from databricks import sql -from databricks.connect import DatabricksSession -from databricks.sdk import WorkspaceClient from databricks.sql.client import Connection as Connection from databricks.sql.client import Cursor as Cursor -from packaging import version -from pyspark.sql import SparkSession -from pyspark.sql.connect.client.core import SparkConnectClient -from pyspark.sql.connect.client.reattach import \ - ExecutePlanResponseReattachableIterator -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.dataframe import DataFrame as SparkDataFrame -from pyspark.sql.types import Row -from llmfoundry.utils import maybe_create_mosaicml_logger -from llmfoundry.utils.exceptions import ( - ClusterDoesNotExistError, - FailedToConnectToDatabricksError, - FailedToCreateSQLConnectionError, -) +from 
llmfoundry.command_utils import convert_delta_to_json_from_args MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' -log = logging.getLogger(__name__) - -Result = namedtuple( - 'Result', - [ - 'url', - 'row_count', - 'compressed_size', - 'uncompressed_size', - ], -) # pyright: ignore - -# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. -# It allows the client to fetch the results in different formats from the server. -# To be able to use the code make sure this module is not overriden by DB Connect classes. - - -def to_cf(self: SparkConnectClient, - plan: pb2.Plan, - type: str = 'json') -> Tuple[List[Result], int, bool]: - """Executes the query plans and return as presigned URLS for cloud fetch. - - It can handle the current output formats that are supported by the server. - In contrast to the regular API methods of the client, this method does not - return the schema and drops all other responses. - - Args: - plan (pb2.Plan): The plan object to be executed by spark. - type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. - - Returns: - Tuple[List[Result], int, bool]: A tuple containing: - - A list of Result namedtuples, each containing a URL, row count, compressed size, - and uncompressed size of the part of the result. - - Total row count of all parts of the result. - - A boolean indicating whether the result has been truncated. - """ - req = self._execute_plan_request_with_metadata() - req.plan.CopyFrom(plan) - - # Add the request options - if type == 'json': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON - elif type == 'csv': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV - elif type == 'arrow': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW - else: - raise ValueError( - f'Only formats json, csv, and arrow are supported. Got invalid type {type}', - ) - - ro = cloud_pb2.ResultOptions( - type=cloud_pb2.ResultOptions.TYPE_CLOUD, - cloudOptions=cloud_pb2.ResultOptions.CloudOptions( - format=format, - useCompression=False, - ), - ) - cloud_option = any_pb2.Any() - cloud_option.Pack(ro) - req.request_options.append( - pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), - ) - - # Create the iterator - iterator = ExecutePlanResponseReattachableIterator( - req, - self._stub, - self._retry_policy, - self._builder.metadata(), - ) - # Iterate over the response - result = [] - row_count = 0 - is_overflow = False - - for response in iterator: - if response.HasField('extension') and response.extension.Is( - cloud_pb2.CloudResultBatch.DESCRIPTOR, - ): - batch = cloud_pb2.CloudResultBatch() - if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): - raise ValueError( - 'Response extension is not of type CloudResultBatch.', - ) - response.extension.Unpack(batch) - result += [ - Result( - b.url, - b.row_count, - b.compressed_size, - b.uncompressed_size, - ) for b in batch.results - ] - row_count += sum(result.row_count for result in batch.results) - is_overflow |= batch.truncated - return result, row_count, is_overflow - - -SparkConnectClient.to_cf = to_cf # pyright: ignore - - -def collect_as_cf(self: DataFrame, - type: str = 'json') -> Tuple[List[Result], int, bool]: - """Collects DataFrame execution plan as presigned URLs. - - This method is a wrapper around the `to_cf` method of SparkConnectClient. 
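This refactor moves the Delta-export logic below into `llmfoundry.command_utils.convert_delta_to_json_from_args`, which the `__main__` block at the bottom of the script now calls with the parsed CLI flags. A hedged sketch of calling it programmatically, reusing the same keyword names; every value here is a placeholder:

```python
from llmfoundry.command_utils import convert_delta_to_json_from_args

# Keyword names match the call at the bottom of this script; values are illustrative.
convert_delta_to_json_from_args(
    delta_table_name='main.my_schema.my_table',  # catalog.schema.table on Unity Catalog
    json_output_folder='/tmp/my_table_jsonl',    # local, empty output folder
    http_path=None,                              # SQL warehouse path; DB Connect may be used when omitted
    batch_size=1 << 30,
    processes=1,
    cluster_id='0123-456789-abcdefgh',           # placeholder cluster id
    use_serverless=False,
    json_output_filename='train-00000.jsonl',
)
```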
It takes the - execution plan of the current DataFrame, converts it to a protocol buffer format, and then - uses the `to_cf` method to execute the plan and fetch results as presigned URLs. - - Args: - type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. - - Returns: - Tuple[List[Result], int, bool]: A tuple containing: - - A list of Result namedtuples, each containing a URL, row count, compressed size, - and uncompressed size of the part of the result. - - Total row count of all parts of the result. - - A boolean indicating whether the result is truncated or overflowed. - """ - query = self._plan.to_proto(self._session.client) # pyright: ignore - return self._session.client.to_cf(query, type) # pyright: ignore - - -DataFrame.collect_cf = collect_as_cf # pyright: ignore - - -def iterative_combine_jsons(json_directory: str, output_file: str) -> None: - """Combine jsonl files in json_directory into one big jsonl file. - - This function does not work for nested subdirectories. - - Args: - json_directory(str): directory containing the JSONL files - output_file(str): path to the output combined JSONL file - """ - json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] - with open(output_file, 'w') as outfile: - for file_name in json_files: - with open(os.path.join(json_directory, file_name), 'r') as infile: - for line in infile: - outfile.write(line) - log.info('JSON files have been combined into a JSONL file.') - - -def run_query( - query: str, - method: str, - cursor: Optional[Cursor] = None, - spark: Optional[SparkSession] = None, - collect: bool = True, -) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: - """Run SQL query via databricks-connect or databricks-sql. - - Args: - query (str): sql query - method (str): select from dbsql and dbconnect - cursor (Optional[Cursor]): connection.cursor - spark (Optional[SparkSession]): spark session - collect (bool): whether to get the underlying data from spark dataframe - """ - if method == 'dbsql': - if cursor is None: - raise ValueError(f'cursor cannot be None if using method dbsql') - cursor.execute(query) - if collect: - return cursor.fetchall() - elif method == 'dbconnect': - if spark == None: - raise ValueError(f'sparkSession is required for dbconnect') - df = spark.sql(query) - if collect: - return df.collect() - return df - else: - raise ValueError(f'Unrecognized method: {method}') - - -def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: - for i, r in enumerate(signed): - yield (i, r.url, json_output_folder, columns) - - -def download( - ipart: int, - url: str, - json_output_folder: str, - columns: Optional[List] = None, - resp_format: str = 'arrow', - compressed: bool = False, -) -> None: - """Thread download presigned url and save to jsonl locally. - - Args: - ipart (int): presigned url id - url (str): presigned url - json_output_folder (str): directory to save the ipart_th segment of dataframe - columns (list): schema to save to json - resp_format (str): whether to use arrow or json when collect - compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. 
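For reference, the Arrow branch of the `download` helper below boils down to a few lines: optionally lz4-decompress the payload, read it as an Arrow IPC stream, and convert to pandas. A condensed sketch, assuming `lz4` and `pyarrow` are installed and `payload` holds the bytes fetched from one presigned URL:

```python
import lz4.frame
import pandas as pd
import pyarrow as pa


def arrow_payload_to_dataframe(payload: bytes, compressed: bool) -> pd.DataFrame:
    """Decode one cloud-fetch result part (Arrow IPC stream, optionally lz4-framed)."""
    data = lz4.frame.decompress(payload) if compressed else payload
    table = pa.ipc.open_stream(data).read_all()
    return table.to_pandas()
```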
- """ - resp = requests.get(url) - if resp.status_code == 200: - if resp_format == 'json': - data = resp.json() - pd.DataFrame(data, columns=columns).to_json( - os.path.join( - json_output_folder, - 'part_' + str(ipart) + '.jsonl', - ), - orient='records', - lines=True, - ) - return - - # When resp_format is arrow: - if compressed: - # The data is lz4 compressed arrow format. - # Decompress the data - decompressed_data = lz4.frame.decompress(resp.content) - # Convert the decompressed data into a PyArrow table - reader = pa.ipc.open_stream(decompressed_data) - else: - reader = pa.ipc.open_stream(resp.content) - table = reader.read_all() - - # Convert the PyArrow table into a pandas DataFrame - df = table.to_pandas() - df.to_json( - os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), - orient='records', - lines=True, - force_ascii=False, - ) - - -def download_starargs(args: Tuple) -> None: - return download(*args) - - -def fetch_data( - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], - start: int, - end: int, - order_by: str, - tablename: str, - columns_str: str, - json_output_folder: str, -) -> None: - """Fetches a specified range of rows from a given table to a json file. - - This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, - from a specified table and column set. The fetched data is then exported as a JSON file. - - Args: - method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. - cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. - sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. - start (int): The starting index for row fetching. - end (int): The ending index for row fetching. - order_by (str): The column name to use for ordering the rows. - tablename (str): The name of the table from which to fetch the data. - columns_str (str): The string representation of the columns to select from the table. - json_output_folder (str): The file path where the resulting JSON file will be saved. - - Returns: - None: The function doesn't return any value, but writes the result to a JSONL file. - """ - query = f""" - WITH NumberedRows AS ( - SELECT - *, - ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn - FROM - {tablename} - ) - SELECT {columns_str} - FROM NumberedRows - WHERE rn BETWEEN {start+1} AND {end}""" - - if method == 'dbconnect': - spark_df = run_query(query, method, cursor, sparkSession, collect=False) - if spark_df is None: - raise RuntimeError( - f'Expect spark dataframe with {query} but got None', - ) - pdf = spark_df.toPandas() # pyright: ignore - else: # method == 'dbsql': - ans = run_query(query, method, cursor, sparkSession, collect=True) - if ans is None: - raise RuntimeError(f'Got empty results with {query}') - records = [r.asDict() for r in ans] # pyright: ignore - pdf = pd.DataFrame.from_dict(records) - - pdf.to_json( - os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), - orient='records', - lines=True, - ) - - -def fetch( - method: str, - tablename: str, - json_output_folder: str, - batch_size: int = 1 << 30, - processes: int = 1, - sparkSession: Optional[SparkSession] = None, - dbsql: Optional[Connection] = None, -) -> None: - """Fetch UC delta table with databricks-connect as JSONL. 
- - Args: - method (str): dbconnect or dbsql - tablename (str): catalog.scheme.tablename on UC - json_output_folder (str): path to write the result json file to - batch_size (int): number of rows that dbsql fetches each time to avoid OOM - processes (int): max number of processes to use to parallelize the fetch - sparkSession (pyspark.sql.sparksession): spark session - dbsql (databricks.sql.connect): dbsql session - """ - cursor = dbsql.cursor() if dbsql is not None else None - - try: - ans = run_query( - f'SELECT COUNT(*) FROM {tablename}', - method, - cursor, - sparkSession, - ) - nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore - log.info(f'total_rows = {nrows}') - except Exception as e: - raise RuntimeError( - f'Error in get total rows from {tablename}. Restart sparkSession and try again', - ) from e - - try: - ans = run_query( - f'SHOW COLUMNS IN {tablename}', - method, - cursor, - sparkSession, - ) - columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore - order_by = columns[0] - columns_str = ','.join(columns) - log.info(f'order by column {order_by}') - except Exception as e: - raise RuntimeError( - f'Error in get columns from {tablename}. Restart sparkSession and try again', - ) from e - - if method == 'dbconnect' and sparkSession is not None: - log.info(f'{processes=}') - df = sparkSession.table(tablename) - - # Running the query and collecting the data as arrow or json. - signed, _, _ = df.collect_cf('arrow') # pyright: ignore - log.info(f'len(signed) = {len(signed)}') - - args = get_args(signed, json_output_folder, columns) - - # Stopping the SparkSession to avoid spilling connection state into the subprocesses. - sparkSession.stop() - - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_starargs, args)) - - elif method == 'dbsql' and cursor is not None: - for start in range(0, nrows, batch_size): - log.warning(f'batch {start}') - end = min(start + batch_size, nrows) - fetch_data( - method, - cursor, - sparkSession, - start, - end, - order_by, - tablename, - columns_str, - json_output_folder, - ) - - if cursor is not None: - cursor.close() - - -def validate_and_get_cluster_info( - cluster_id: str, - databricks_host: str, - databricks_token: str, - http_path: Optional[str], - use_serverless: bool = False, -) -> tuple: - """Validate and get cluster info for running the Delta to JSONL conversion. 
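The cluster validation below derives a comparable version number from the Databricks runtime string before enforcing the `MINIMUM_*_DBR_VERSION` constants defined near the top of this file. A worked example of that parsing, using a hypothetical runtime string:

```python
import re

from packaging import version

spark_version = '14.3.x-scala2.12'  # hypothetical runtime string from the Clusters API

# Same normalization as the function below: drop the scala suffix, strip letters,
# then trim trailing separators, leaving a PEP 440-comparable version.
stripped = re.sub(
    r'[a-zA-Z]', '',
    spark_version.split('-scala')[0].replace('x-snapshot', ''),
)
runtime_version = re.sub(r'[.-]*$', '', stripped)  # '14.3'

print(version.parse(runtime_version) >= version.parse('14.1'))  # True -> DB Connect OK
```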
- - Args: - cluster_id (str): cluster id to validate and fetch additional info for - databricks_host (str): databricks host name - databricks_token (str): databricks auth token - http_path (Optional[str]): http path to use for sql connect - use_serverless (bool): whether to use serverless or not - """ - method = 'dbsql' - dbsql = None - sparkSession = None - - if use_serverless: - method = 'dbconnect' - else: - w = WorkspaceClient() - res = w.clusters.get(cluster_id=cluster_id) - if res is None: - raise ClusterDoesNotExistError(cluster_id) - - stripped_runtime = re.sub( - r'[a-zA-Z]', - '', - res.spark_version.split('-scala') - [0].replace( # type: ignore - 'x-snapshot', '', - ), - ) - runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) - if version.parse( - runtime_version, - ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): - raise ValueError( - f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', - ) - - if http_path is None and version.parse( - runtime_version, - ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): - method = 'dbconnect' - - if method == 'dbconnect': - try: - if use_serverless: - session_id = str(uuid4()) - sparkSession = DatabricksSession.builder.host( - databricks_host, - ).token( - databricks_token, - ).header('x-databricks-session-id', session_id).getOrCreate() - - else: - sparkSession = DatabricksSession.builder.remote( - host=databricks_host, - token=databricks_token, - cluster_id=cluster_id, - ).getOrCreate() - - except Exception as e: - raise FailedToConnectToDatabricksError() from e - else: - try: - dbsql = sql.connect( - server_hostname=re.compile(r'^https?://').sub( - '', databricks_host).strip( - ), # sqlconnect hangs if hostname starts with https - http_path=http_path, - access_token=databricks_token, - ) - except Exception as e: - raise FailedToCreateSQLConnectionError() from e - return method, dbsql, sparkSession - - -def fetch_DT(args: Namespace) -> None: - """Fetch UC Delta Table to local as jsonl.""" - log.info(f'Start .... Convert delta to json') - - obj = urllib.parse.urlparse(args.json_output_folder) - if obj.scheme != '': - raise ValueError( - 'Check the json_output_folder and verify it is a local path!', - ) - - if os.path.exists(args.json_output_folder): - if not os.path.isdir(args.json_output_folder) or os.listdir( - args.json_output_folder, - ): - raise RuntimeError( - f'Output folder {args.json_output_folder} already exists and is not empty. 
Please remove it and retry.', - ) - - os.makedirs(args.json_output_folder, exist_ok=True) - - if not args.json_output_filename.endswith('.jsonl'): - raise ValueError('json_output_filename needs to be a jsonl file') - - log.info(f'Directory {args.json_output_folder} created.') - - method, dbsql, sparkSession = validate_and_get_cluster_info( - cluster_id=args.cluster_id, - databricks_host=args.DATABRICKS_HOST, - databricks_token=args.DATABRICKS_TOKEN, - http_path=args.http_path, - use_serverless=args.use_serverless, - ) - - fetch( - method, - args.delta_table_name, - args.json_output_folder, - args.batch_size, - args.processes, - sparkSession, - dbsql, - ) - - if dbsql is not None: - dbsql.close() - - # combine downloaded jsonl into one big jsonl for IFT - iterative_combine_jsons( - args.json_output_folder, - os.path.join(args.json_output_folder, args.json_output_filename), - ) +TABLENAME_PATTERN = re.compile(r'(\S+)\.(\S+)\.(\S+)') +log = logging.getLogger(__name__) if __name__ == '__main__': parser = ArgumentParser( @@ -631,18 +79,13 @@ def fetch_DT(args: Namespace) -> None: 'The name of the combined final jsonl that combines all partitioned jsonl', ) args = parser.parse_args() - mosaicml_logger = maybe_create_mosaicml_logger() - - try: - w = WorkspaceClient() - args.DATABRICKS_HOST = w.config.host - args.DATABRICKS_TOKEN = w.config.token - - tik = time.time() - fetch_DT(args) - log.info(f'Elapsed time {time.time() - tik}') - - except Exception as e: - if mosaicml_logger is not None: - mosaicml_logger.log_exception(e) - raise e + convert_delta_to_json_from_args( + delta_table_name=args.delta_table_name, + json_output_folder=args.json_output_folder, + http_path=args.http_path, + batch_size=args.batch_size, + processes=args.processes, + cluster_id=args.cluster_id, + use_serverless=args.use_serverless, + json_output_filename=args.json_output_filename, + ) diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index 523d45093d..b28e25786b 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -1,28 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json -import os -import platform -import warnings from argparse import ArgumentParser, Namespace -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Union -import datasets as hf_datasets -import psutil from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from streaming import MDSWriter -from torch.utils.data import DataLoader -from tqdm import tqdm -from llmfoundry.data.finetuning.collator import validate_target_settings -from llmfoundry.data.finetuning.tasks import ( - _get_example_type, - dataset_constructor, - is_valid_ift_example, - tokenize_formatted_example, -) -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.command_utils import convert_finetuning_dataset_from_args HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] @@ -116,236 +100,9 @@ def parse_args() -> Namespace: ) parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - if parsed.tokenizer_kwargs is not None: - parsed.tokenizer_kwargs = 
json.loads(parsed.tokenizer_kwargs) - else: - parsed.tokenizer_kwargs = {} - - if len(parsed.data_files) > 0 and len( - parsed.data_files, - ) != len(parsed.splits): - raise ValueError( - f'If data_files is set, data_files and splits must have the same length. Got {len(parsed.data_files)=} while {len(parsed.splits)=}', - ) - return parsed -def build_dataloader( - dataset: HFDataset, - batch_size: int, - num_workers: Optional[int] = None, -) -> DataLoader: - if num_workers is None: - # Multiple workers is only supported on linux machines - if 'linux' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) - else: - num_workers = 0 - - # If using multiple workers, configure each worker to prefetch as many samples as it can, up to - # the aggregate device batch size - # If not using workers, the torch DataLoader expects the default value for prefetch_factor, - # which non-intuitively must be 2. - # If on macOS, PyTorch requires prefetch_factor set to None since num_workers is always zero - if 'macos' in platform.platform().lower() and num_workers == 0: - prefetch_factor = None - else: - prefetch_factor = max( - 1, - 2 * batch_size // num_workers, - ) if num_workers > 0 else 2 - - return DataLoader( - dataset=dataset, - sampler=None, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) - - -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. - - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - - -def get_columns_and_format( - dataset: HFDataset, - tokenizing: bool, - preprocessing_fn: Callable, -): - ex = preprocessing_fn(next(iter(dataset))) - example_type = _get_example_type(ex) - if tokenizing: - return {'turns': 'json'}, example_type - if example_type == 'chat': - # Chat format - return {'messages': 'json'}, example_type - else: - # Prompt-response format - return {'prompt': 'str', 'response': 'str'}, example_type - - -def main(args: Namespace) -> None: - """Main: create a streaming dataset. - - Args: - args (Namespace): Commandline arguments. - """ - if args.skip_preprocessing: - preprocessing_fn = lambda x: x # Just an identity function - else: - preprocessor_str = args.preprocessor - preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( - preprocessor=preprocessor_str, - dataset_name=args.dataset, - ) - if preprocessing_fn is None: - raise ValueError( - '`args.preprocessor` was not set and no preprocessing function ' +\ - 'has been registered for `args.dataset`. 
If this was intentional ' +\ - '(e.g., because your dataset is already correctly formatted), ' +\ - 'include the "--skip-preprocessing" flag to avoid this error.', - ) - - # Make sure the target settings are valid - validate_target_settings( - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - ) - - tokenizer = None - tokenizer_kwargs = args.tokenizer_kwargs - tokenizer_kwargs.update({'model_max_length': args.max_seq_len}) - if args.tokenizer: - tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs) - - for i, split_name in enumerate(args.splits): - data_file = None - if len(args.data_files) > 0: - data_file = args.data_files[i] - dataset = hf_datasets.load_dataset( - path=args.dataset, - name=args.data_subset, - split=split_name, - data_files=data_file, - streaming=True, - ) - # Determine the output columns - columns, example_type = get_columns_and_format( - dataset=dataset, - tokenizing=tokenizer is not None, - preprocessing_fn=preprocessing_fn, - ) - # Prepare the iterables - if example_type == 'chat': - samples = iter(dataset) - else: - loader = build_dataloader( - dataset=dataset, - batch_size=512, - num_workers=args.num_workers, - ) - samples = generate_samples(loader) - - # Write samples - print(f'Converting {split_name} to MDS format...') - out = os.path.join(args.out_root, split_name) - if args.local is not None: - out = (os.path.join(args.local, split_name), out) - keep_local = True - else: - keep_local = False - with MDSWriter( - columns=columns, - out=out, - compression=args.compression, - keep_local=keep_local, - ) as out: - examples_removed = 0 - for sample in tqdm(samples, desc=split_name): - formatted_sample = preprocessing_fn(sample) - assert isinstance(formatted_sample, dict) - - # Use the _get_example_type utility to confirm that the formatted sample - # can be interpreted by the tokenization code - try: - example_type = _get_example_type(formatted_sample) - except Exception as e: - raise ValueError( - 'Encountered an error when checking example for proper formatting. 
' +\ - f'example={formatted_sample}', - ) from e - if tokenizer is not None: - sample = tokenize_formatted_example( - formatted_sample, - tokenizer=tokenizer, - ) - if not is_valid_ift_example( - args.max_seq_len, - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - example=sample, - ): - examples_removed += 1 - continue - - sample_to_write = {'turns': []} - for turn in sample['turns']: - turn_to_write = {} - for key in ['input_ids', 'labels']: - turn_to_write[key] = list(turn[key]) - sample_to_write['turns'].append(turn_to_write) - out.write(sample_to_write) - else: - if example_type == 'prompt_response': - encoded_sample = {} - for key in ['prompt', 'response']: - value = formatted_sample[key] - assert isinstance(value, str) - encoded_sample[key] = value.encode('utf-8') - out.write(encoded_sample) - else: - out.write(formatted_sample) - - if tokenizer is not None and examples_removed > 0: - warnings.warn( - f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, ' - + - 'the prompt or response was empty, or the response was all padding tokens.', - ) - - if __name__ == '__main__': """Example for converting Muennighoff/P3: @@ -355,4 +112,22 @@ def main(args: Namespace) -> None: >>> --preprocessor llmfoundry.data.finetuning.tasks:p3_preprocessing_function \ >>> --out_root s3:///muennighoff-p3 """ - main(parse_args()) + args = parse_args() + convert_finetuning_dataset_from_args( + dataset=args.dataset, + data_subset=args.data_subset, + splits=args.splits, + preprocessor=args.preprocessor, + data_files=args.data_files, + skip_preprocessing=args.skip_preprocessing, + out_root=args.out_root, + local=args.local, + compression=args.compression, + num_workers=args.num_workers, + tokenizer=args.tokenizer, + tokenizer_kwargs=args.tokenizer_kwargs, + max_seq_len=args.max_seq_len, + target_prompts=args.target_prompts, + target_responses=args.target_responses, + encoder_decoder=args.encoder_decoder, + ) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 365cc9b71d..c808fa871f 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -2,105 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import math -import os -import tempfile from argparse import ArgumentParser, Namespace -from concurrent.futures import ProcessPoolExecutor -from functools import partial -from glob import glob -from typing import Dict, Iterable, List, Tuple, cast -import numpy as np import psutil -from composer.utils import ( - ObjectStore, - maybe_create_object_store_from_uri, - parse_uri, -) -from streaming import MDSWriter -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.data.data import AbstractConcatTokensDataset -from llmfoundry.utils import maybe_create_mosaicml_logger -from llmfoundry.utils.data_prep_utils import ( - DownloadingIterable, - download_file, - merge_shard_groups, -) -from llmfoundry.utils.exceptions import ( - InputFolderMissingDataError, - OutputFolderNotEmptyError, -) +from llmfoundry.command_utils import convert_text_to_mds_from_args log = logging.getLogger(__name__) DONE_FILENAME = '.text_to_mds_conversion_done' -class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): - """An IterableDataset that returns token samples for MDSWriter from files. - - Returns dicts of {'tokens': bytes} - - Each file is considered a sequence. 
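This removed dataset class implements a simple buffering scheme: token ids accumulate in a list, and each time the buffer reaches `max_length` a fixed-size sample is emitted as raw bytes. A stripped-down sketch of just that core loop, with no BOS/EOS handling and illustrative names:

```python
from typing import Dict, Iterable, Iterator, List

import numpy as np


def emit_fixed_length_samples(
    token_chunks: Iterable[List[int]],
    max_length: int,
    wrap: bool = True,
) -> Iterator[Dict[str, bytes]]:
    """Yield {'tokens': bytes} samples of exactly max_length token ids."""
    buffer: List[int] = []
    for ids in token_chunks:
        buffer += ids
        while len(buffer) >= max_length:
            sample = buffer[:max_length]
            # With wrapping, leftover tokens start the next sample; otherwise drop them.
            buffer = buffer[max_length:] if wrap else []
            yield {'tokens': np.asarray(sample).tobytes()}
```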
- """ - - def __init__( - self, - files: Iterable[str], - tokenizer: PreTrainedTokenizerBase, - max_length: int, - bos_text: str, - eos_text: str, - no_wrap: bool, - ): - self.files = files - super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) - - def __iter__(self) -> Iterable[Dict[str, bytes]]: - - buffer = [] - for file in self.files: - with open(file, 'r') as f: - buffer += self.bos_tokens - first_chunk = True - # Read the file in 1MB chunks to avoid memory issues - for chunk in iter(partial(f.read, 1000000), ''): - # Tokenize the chunk - encoded = self.tokenizer( - chunk, - truncation=False, - padding=False, - ) - iids = encoded['input_ids'] - - # If this is not the first chunk, remove the BOS token - if not first_chunk: - if iids[0] == self.tokenizer.bos_token_id: - iids = iids[1:] - - # Add the tokens to the buffer - buffer += iids - while len(buffer) >= self.max_length: - concat_sample = buffer[:self.max_length] - buffer = buffer[self. - max_length:] if self.should_wrap else [] - yield {'tokens': np.asarray(concat_sample).tobytes()} - - first_chunk = False - - # Add the EOS token to the buffer to separate files. - buffer += self.eos_tokens - - # Yield any remaining samples of size max_length. - while len(buffer) >= self.max_length: - concat_sample = buffer[:self.max_length] - buffer = buffer[self.max_length:] if self.should_wrap else [] - yield {'tokens': np.asarray(concat_sample).tobytes()} - - def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( @@ -201,426 +113,23 @@ def parse_args() -> Namespace: help='Logging level for the script. Default is INFO.', ) parsed = parser.parse_args() - - # Set eos token. - if parsed.use_tokenizer_eos: - # Ensure that eos text is not specified twice. - if parsed.eos_text is not None: - parser.error( - 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.', - ) - tokenizer = AutoTokenizer.from_pretrained( - parsed.tokenizer, - trust_remote_code=parsed.trust_remote_code, - ) - parsed.eos_text = tokenizer.eos_token - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -def get_object_names(input_folder: str) -> List[str]: - """Get object names from a local or remote folder. - - Args: - input_folder (str): local or remote folder path. - """ - object_store = maybe_create_object_store_from_uri(input_folder) - if object_store is not None: - _, _, folder_prefix = parse_uri(input_folder) - names = [ - name for name in object_store.list_objects(folder_prefix) - if name.endswith('.txt') - ] - else: - # input_folder is a local folder - names = [ - text_file for dirpath, _, _ in os.walk(input_folder) - for text_file in glob(os.path.join(dirpath, '*.txt')) - ] - # return names, sizes - log.info(f'Found {len(names)} text files at {input_folder}') - - return names - - -def get_task_args( - object_names: List[str], - output_root: str, - input_folder: str, - n_groups: int, - tokenizer_name: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - trust_remote_code: bool, -) -> Iterable: - """Get download_and_convert arguments split across n_groups. - - Each group handles a portion of object_names. 
- - Args: - object_names (List[str]): Names of objects to process - output_root (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - n_groups (int): Number of groups to split the object names into - tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - num_objects = len(object_names) - objs_per_group = math.ceil(num_objects / n_groups) - for group, i in enumerate(range(0, num_objects, objs_per_group)): - output_subdir = os.path.join(output_root, str(group)) - yield ( - object_names[i:min(i + objs_per_group, num_objects)], - output_subdir, - input_folder, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - - -def download_and_convert_starargs(args: Tuple): - """Helper function to call download_and_convert with star args. - - This helps us use download_and_convert with multiprocessing. - """ - return download_and_convert(*args) - - -def download_and_convert( - file_names: List[str], - output_folder: str, - input_folder: str, - tokenizer_name: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - trust_remote_code: bool, -): - """Downloads and converts text files to MDS format. - - Args: - file_names (List[str]): Files to process - output_folder (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - object_store = maybe_create_object_store_from_uri(input_folder) - - # Download file_names - with tempfile.TemporaryDirectory() as tmp_dir: - downloading_iter = DownloadingIterable( - object_names=file_names, - output_folder=tmp_dir, - object_store=object_store, - ) - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - trust_remote_code=trust_remote_code, - ) - tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace - - # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up - # to the maximum sequence length - dataset = ConcatTokensFromFilesDataset( - files=downloading_iter, - max_length=concat_tokens, - tokenizer=tokenizer, - eos_text=eos_text, - bos_text=bos_text, - no_wrap=no_wrap, - ) - - columns = {'tokens': 'bytes'} - - log.info('Converting to MDS format...') - with MDSWriter( - out=output_folder, - columns=columns, - compression=compression, - ) as out: - for sample in tqdm(dataset): - out.write(sample) - - -def is_remote_path(path: str) -> bool: - """Checks whether a path is a remote path. 
- - Args: - path (str): path to check - """ - backend, _, _ = parse_uri(path) - return backend != '' - - -def is_already_processed( - output_root: str, - args_str: str, - object_names: List[str], -) -> bool: - """Determines whether a group of text files has already been processed. - - Checks the done fie at output root to determine this. - - Args: - output_root (str): Output folder where a done file may exist - args_str (str): String representation of the arguments - object_names (List[str]): Names of objects to convert to MDS format - """ - # Retrieve the done file contents - output_object_store = maybe_create_object_store_from_uri(output_root) - if output_object_store is not None: - # Download and read the done file from the remote object store - _, _, output_folder_prefix = parse_uri(output_root) - try: - with tempfile.TemporaryDirectory() as tmp_dir: - done_file = os.path.join(tmp_dir, DONE_FILENAME) - download_file( - object_store=output_object_store, - object_name=os.path.join( - output_folder_prefix, - DONE_FILENAME, - ), - output_filename=done_file, - ) - with open(done_file) as df: - done_file_contents = df.read().splitlines() - except FileNotFoundError: - return False - else: - # Read the local done file - done_file = os.path.join(output_root, DONE_FILENAME) - if not os.path.isfile(done_file): - return False - with open(done_file) as df: - done_file_contents = df.read().splitlines() - # Compare the arguments - prev_args_str = done_file_contents[0] - if prev_args_str != args_str: - return False - - # Compare file names - prev_names = done_file_contents[1:] - if len(prev_names) != len(object_names): - return False - for idx, prev_name in enumerate(prev_names): - if object_names[idx] != prev_name: - return False - return True - - -def write_done_file(folder: str, args_str: str, object_names: List[str]): - """Write a file to signify completion. - - This the done file includes the arguments to processing and - a list of objects that were processed. - - Args: - folder (str): Folder to write the done file to - args_str (str): String representation of arguments - object_names (List[str]): List of objects to convert to MDS format - """ - with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file: - done_file.write('\n'.join([args_str] + object_names) + '\n') - - -def convert_text_to_mds( - tokenizer_name: str, - output_folder: str, - input_folder: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - processes: int, - args_str: str, - reprocess: bool, - trust_remote_code: bool, -): - """Convert a folder of text files to MDS format. - - Args: - tokenizer_name (str): Name of tokenizer to use - output_folder (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - processes (int): The number of processes to use. 
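The done-file convention used by is_already_processed and write_done_file above is simple: the first line records the argument string, the remaining lines record the processed object names, and a later run skips work only when both match exactly. A minimal local-only sketch of that round trip (temporary directory, toy values, filename assumed for illustration):

import os
import tempfile
from typing import List

DONE_FILENAME = '.text_to_mds_conversion_done'  # name assumed; the script defines its own constant


def write_done(folder: str, args_str: str, names: List[str]) -> None:
    with open(os.path.join(folder, DONE_FILENAME), 'w') as f:
        f.write('\n'.join([args_str] + names) + '\n')


def already_done(folder: str, args_str: str, names: List[str]) -> bool:
    path = os.path.join(folder, DONE_FILENAME)
    if not os.path.isfile(path):
        return False
    with open(path) as f:
        lines = f.read().splitlines()
    return lines[0] == args_str and lines[1:] == names


with tempfile.TemporaryDirectory() as tmp:
    write_done(tmp, 'tokenizer=gpt2,concat_tokens=2048', ['a.txt', 'b.txt'])
    print(already_done(tmp, 'tokenizer=gpt2,concat_tokens=2048', ['a.txt', 'b.txt']))  # True
    print(already_done(tmp, 'tokenizer=gpt2,concat_tokens=4096', ['a.txt', 'b.txt']))  # False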
- args_str (str): String representation of the arguments - reprocess (bool): Whether to always reprocess the given folder of text files - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - is_remote_output = is_remote_path(output_folder) - - object_names = get_object_names(input_folder) - if len(object_names) == 0: - raise InputFolderMissingDataError(input_folder) - - # Check if the text files in the bucket have already been processed. - if not reprocess and is_already_processed( - output_folder, - args_str, - object_names, - ): - log.info( - f'Input folder {input_folder} is already processed at {output_folder} and ' - + - 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', - ) - return - - # Use a temporary local directory if the output is remote and there are more than 1 processes - local_output_folder = tempfile.TemporaryDirectory( - ).name if is_remote_output else output_folder - - if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: - raise OutputFolderNotEmptyError(output_folder) - - if processes > 1: - # Download and convert the text files in parallel - args = get_task_args( - object_names, - local_output_folder, - input_folder, - processes, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_and_convert_starargs, args)) - - # Merge the mds shards from each of the processes into a single folder - merge_shard_groups(local_output_folder) - else: - download_and_convert( - object_names, - local_output_folder, - input_folder, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - - # Write a done file with the args and object names - write_done_file(local_output_folder, args_str, object_names) - - if is_remote_output: - # Upload the local output to the remote location - output_object_store = cast( - ObjectStore, - maybe_create_object_store_from_uri(output_folder), - ) - _, _, output_folder_prefix = parse_uri(output_folder) - files_to_upload = os.listdir(local_output_folder) - - for file in files_to_upload: - assert not os.path.isdir(file) - remote_path = os.path.join(output_folder_prefix, file) - output_object_store.upload_object( - remote_path, - os.path.join(local_output_folder, file), - ) - - -def _args_str(original_args: Namespace) -> str: - """Create a string from the args to determine whether to reprocess. - - Args: - original_args (Namespace): Arguments to main function. - """ - # Take the arguments that influence the final result. - # reprocess and max_mds_writer_workers are not taken. - args = Namespace( - tokenizer_name=original_args.tokenizer, - output_folder=original_args.output_folder, - input_folder=original_args.input_folder, - concat_tokens=original_args.concat_tokens, - eos_text=original_args.eos_text, - bos_text=original_args.bos_text, - no_wrap=original_args.no_wrap, - compression=original_args.compression, - processes=original_args.processes, - ) - - return str(args) - - -def _configure_logging(logging_level: str): - """Configure logging. - - Args: - logging_level (str): Logging level. 
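The download_and_convert_starargs helper used with the executor above exists because ProcessPoolExecutor.map hands each element of the iterable to the worker as a single argument, so tuple arguments have to be unpacked inside a wrapper. A tiny demonstration of the same pattern with a toy worker function:

from concurrent.futures import ProcessPoolExecutor
from typing import Tuple


def work(name: str, repeat: int) -> str:
    return name * repeat


def work_starargs(args: Tuple[str, int]) -> str:
    """Unpack the tuple that executor.map passes as one argument."""
    return work(*args)


if __name__ == '__main__':
    tasks = [('a', 2), ('b', 3)]
    with ProcessPoolExecutor(max_workers=2) as executor:
        print(list(executor.map(work_starargs, tasks)))  # ['aa', 'bbb']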
- """ - logging.basicConfig( - format= - f'%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging_level = logging_level.upper() - logging.getLogger('llmfoundry').setLevel(logging_level) - logging.getLogger(__name__).setLevel(logging_level) - log.info(f'Logging level set to {logging_level}') - - if __name__ == '__main__': args = parse_args() - _configure_logging(args.logging_level) - - mosaicml_logger = maybe_create_mosaicml_logger() - - try: - convert_text_to_mds( - tokenizer_name=args.tokenizer, - output_folder=args.output_folder, - input_folder=args.input_folder, - concat_tokens=args.concat_tokens, - eos_text=args.eos_text, - bos_text=args.bos_text, - no_wrap=args.no_wrap, - compression=args.compression, - processes=args.processes, - reprocess=args.reprocess, - trust_remote_code=args.trust_remote_code, - args_str=_args_str(args), - ) - except Exception as e: - if mosaicml_logger is not None: - mosaicml_logger.log_exception(e) - raise e + convert_text_to_mds_from_args( + output_folder=args.output_folder, + input_folder=args.input_folder, + compression=args.compression, + concat_tokens=args.concat_tokens, + tokenizer_name=args.tokenizer, + bos_text=args.bos_text, + eos_text=args.eos_text, + use_tokenizer_eos=args.use_tokenizer_eos, + no_wrap=args.no_wrap, + processes=args.processes, + reprocess=args.reprocess, + trust_remote_code=args.trust_remote_code, + logging_level=args.logging_level, + ) diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 8a0a5c104f..caafda4b87 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -1,445 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - -import logging -import os import sys -import time -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import torch -from composer.core import Callback -from composer.loggers.logger_destination import LoggerDestination -from composer.trainer import Trainer -from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig -from omegaconf import OmegaConf as om -from rich.traceback import install - -from llmfoundry.utils import ( - find_mosaicml_logger, - log_eval_analytics, - maybe_create_mosaicml_logger, -) - -install() -from llmfoundry.utils.builders import ( - add_metrics_to_eval_loaders, - build_callback, - build_composer_model, - build_evaluators, - build_logger, - build_tokenizer, -) -from llmfoundry.utils.config_utils import ( - EVAL_CONFIG_KEYS, - EvalConfig, - log_config, - make_dataclass_and_log_config, - process_init_device, -) -from llmfoundry.utils.registry_utils import import_file - -log = logging.getLogger(__name__) - - -def evaluate_model( - tokenizer: Dict[str, Any], - model_name: str, - model: Dict[str, Any], - dist_timeout: Union[float, int], - run_name: str, - seed: int, - icl_tasks: Union[str, List[Dict[str, Any]]], - max_seq_len: int, - device_eval_batch_size: int, - eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], - eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], - fsdp_config: Optional[Dict[str, Any]], - loggers: List[LoggerDestination], - python_log_level: Optional[str], - precision: str, - eval_gauntlet_df: Optional[pd.DataFrame], - eval_subset_num_batches: int, - icl_subset_num_batches: Optional[int], - callback_configs: Optional[Dict[str, Any]], - metadata: Optional[Dict[str, str]], - logged_config: Dict[str, Any], - should_log_config: bool = True, - load_path: Optional[str] = None, -): - 
log.info(f'Evaluating model: {model_name}') - # Build tokenizer and model - tokenizer_cfg = tokenizer - tokenizer_name = tokenizer_cfg['name'] - tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) - tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - - evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks, - eval_gauntlet_config, - tokenizer=tokenizer, - device_eval_batch_size=device_eval_batch_size, - icl_seq_len=max_seq_len, - icl_subset_num_batches=icl_subset_num_batches, - ) - - # Callbacks - callbacks: List[Callback] = [ - build_callback(name=str(name), kwargs=callback_cfg) - for name, callback_cfg in callback_configs.items() - ] if callback_configs else [] - - if eval_gauntlet_callback is not None: - callbacks.append(eval_gauntlet_callback) - - if metadata is not None: - # Find the MosaicMLLogger - mosaicml_logger = find_mosaicml_logger(loggers) - - if mosaicml_logger is not None: - mosaicml_logger.log_metrics(metadata) - mosaicml_logger._flush_metadata(force_flush=True) - - if fsdp_config and model.get('load_in_8bit', False): - raise ValueError( - 'The FSDP config block is not supported when loading ' + - 'Hugging Face models in 8bit.', - ) - - init_context = process_init_device(model, fsdp_config) - - name = model.pop('name') - composer_model = build_composer_model( - name=name, - tokenizer=tokenizer, - init_context=init_context, - cfg=model, - ) - - # Now add the eval metrics - if eval_loader_config is not None: - train_metrics = composer_model.get_metrics(is_train=True) - evaluators = add_metrics_to_eval_loaders( - evaluators, - list(train_metrics.keys()), - ) - - if eval_gauntlet_df is None and eval_gauntlet_callback is not None: - eval_gauntlet_df = pd.DataFrame( - columns=['model_name'] + list(eval_gauntlet_callback.averages) + - [t['name'] for t in eval_gauntlet_callback.categories], - ) - - if name == 'mpt_causal_lm' and load_path is None: - raise ValueError( - 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' 
- + - ' Please check your yaml and the model_cfg to ensure that load_path is set.', - ) - - assert composer_model is not None - - log.info(f'Building trainer for {model_name}...') - trainer = Trainer( - run_name=run_name, - seed=seed, - model=composer_model, - callbacks=callbacks, - loggers=loggers, - precision=precision, - fsdp_config=fsdp_config, - load_path=load_path, - load_weights_only=True, - progress_bar=False, - log_to_console=True, - dist_timeout=dist_timeout, - python_log_level=python_log_level, - ) - - if should_log_config: - log.info('Evaluation config:') - log_config(logged_config) - - log.info(f'Starting eval for {model_name}...') - if torch.cuda.is_available(): - torch.cuda.synchronize() - a = time.time() - trainer.eval( - eval_dataloader=evaluators, - subset_num_batches=eval_subset_num_batches, - ) - if torch.cuda.is_available(): - torch.cuda.synchronize() - b = time.time() - - log.info(f'Ran {model_name} eval in: {b-a} seconds') - return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) - - -def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: - # Run user provided code if specified - for code_path in cfg.get('code_paths', []): - import_file(code_path) - - logged_cfg, eval_config = make_dataclass_and_log_config( - cfg, - EvalConfig, - EVAL_CONFIG_KEYS, - icl_tasks_required=True, - ) - - model_configs = eval_config.models - eval_gauntlet_config = eval_config.eval_gauntlet or eval_config.eval_gauntlet_str - - fsdp_config = eval_config.fsdp_config - - # Mandatory Evaluation Parameters - icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str - if icl_tasks is None: - raise ValueError('icl_tasks must be specified in the config') - - # Optional Evaluation Parameters with default values - eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders - default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name = eval_config.run_name if eval_config.run_name else default_run_name - - reproducibility.seed_all(eval_config.seed) - dist.initialize_dist(get_device(None), timeout=eval_config.dist_timeout) - - if eval_config.python_log_level is not None: - logging.basicConfig( - # Example of format string - # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here - format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging.getLogger('llmfoundry').setLevel( - eval_config.python_log_level.upper(), - ) - - # Default argument values for evaluate_model - eval_gauntlet_df = None - models_df = None - composite_scores = None - trainers = [] - - # Build loggers - loggers: List[LoggerDestination] = [ - build_logger(name, logger_cfg) - for name, logger_cfg in (eval_config.loggers or {}).items() - ] - - mosaicml_logger = find_mosaicml_logger(loggers) - if mosaicml_logger is None: - mosaicml_logger = maybe_create_mosaicml_logger() - # mosaicml_logger will be None if run isn't on MosaicML platform - if mosaicml_logger is not None: - loggers.append(mosaicml_logger) - - # mosaicml_logger will be None if the run isn't from the MosaicML platform - if mosaicml_logger is not None: - log_eval_analytics( - mosaicml_logger, - model_configs, - icl_tasks, - eval_gauntlet_config, - ) - - for model_cfg in model_configs: - - attn_config = model_cfg['model'].get('attn_config', None) - if attn_config is not None: - seq_parallel_world_size = attn_config.get( - 'seq_parallel_world_size', - None, - ) - if seq_parallel_world_size is not None and seq_parallel_world_size != 1: 
- raise ValueError( - 'Offline eval does not support sequence parallelism.', - ) - - (trainer, logger_keys, eval_gauntlet_callback, - eval_gauntlet_df) = evaluate_model( - dist_timeout=eval_config.dist_timeout, - run_name=run_name, - seed=eval_config.seed, - icl_tasks=icl_tasks, - max_seq_len=eval_config.max_seq_len, - device_eval_batch_size=eval_config.device_eval_batch_size, - eval_gauntlet_config=eval_gauntlet_config, - eval_loader_config=eval_loader_config, - fsdp_config=fsdp_config, - loggers=loggers, - python_log_level=eval_config.python_log_level, - precision=eval_config.precision, - eval_gauntlet_df=eval_gauntlet_df, - callback_configs=eval_config.callbacks, - eval_subset_num_batches=eval_config.eval_subset_num_batches, - icl_subset_num_batches=eval_config.icl_subset_num_batches, - metadata=eval_config.metadata, - logged_config=logged_cfg, - should_log_config=eval_config.log_config, - **model_cfg, - ) - trainers.append(trainer) - - if eval_gauntlet_callback is not None: - composite_scores = eval_gauntlet_callback.eval_after_all( - trainer.state, - trainer.logger, - ) - - benchmark_to_taxonomy = {} - if eval_gauntlet_callback is not None: - for t in eval_gauntlet_callback.categories: - for b in t['benchmarks']: - benchmark_to_taxonomy[b['name']] = t['name'] - - assert 'model_name' in model_cfg, 'model_name must be specified in model config' - model_results = calculate_markdown_results( - logger_keys, - trainer, - benchmark_to_taxonomy, - model_cfg['model_name'], - ) - - if models_df is None: - models_df = model_results - else: - models_df = pd.concat([models_df, model_results], ignore_index=True) - - if eval_gauntlet_df is not None and eval_gauntlet_callback is not None: - assert composite_scores is not None - row = {'model_name': model_cfg['model_name']} - row.update({ - k.split('/')[-1]: v for k, v in composite_scores.items() - }) - eval_gauntlet_df = pd.concat([ - eval_gauntlet_df, - pd.DataFrame([row]), - ], - ignore_index=True) - - print(f'Printing gauntlet results for all models') - - print( - eval_gauntlet_df.sort_values( - list(eval_gauntlet_callback.averages.keys())[0], - ascending=False, - ).to_markdown(index=False), - ) - print(f'Printing complete results for all models') - assert models_df is not None - print(models_df.to_markdown(index=False)) - - trainer.close() - - return trainers, eval_gauntlet_df - - -def calculate_markdown_results( - logger_keys: List[str], - trainer: Trainer, - benchmark_to_taxonomy: Dict[str, str], - model_name: str, -): - results = {} - - for key in logger_keys: - # dl_name is either 2-tuple (benchmark_name, num_fewshot) - # or 3-tuple (benchmark_name, num_fewshot, subcategory) - dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1] - if 'Accuracy' not in metric_name: - continue - - metric = trainer.state.eval_metrics.get('/'.join(dl_name), - {}).get(metric_name, None) - - if metric is None: - continue - if dl_name[1] not in results: - results[dl_name[1]] = {} - - if dl_name[0] not in results[dl_name[1]]: - results[dl_name[1]][dl_name[0]] = {} - - if metric_name not in results[dl_name[1]][dl_name[0]]: - results[dl_name[1]][dl_name[0]][metric_name] = [] - - results[dl_name[1]][dl_name[0]][metric_name].append({ - 'val': metric.compute(), - 'subcat': dl_name[-1] if len(dl_name) == 3 else 'no_subcat', - }) - - df = pd.DataFrame( - columns=[ - 'Category', - 'Benchmark', - 'Subtask', - 'Accuracy', - 'Number few shot', - 'Model', - ], - ) - - for num_shot in results: - for benchmark in results[num_shot]: - for metric in 
results[num_shot][benchmark]: - subscores = results[num_shot][benchmark][metric] - if len(subscores) == 1: - row = { - 'Category': benchmark_to_taxonomy.get(benchmark, ''), - 'Benchmark': benchmark, - 'Subtask': None, - 'Accuracy': subscores[0]['val'], - 'Number few shot': num_shot, - 'Model': model_name, - } - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) - else: - row = { - 'Category': - benchmark_to_taxonomy.get(benchmark, ''), - 'Benchmark': - benchmark, - 'Subtask': - 'Average', - 'Accuracy': - sum(s['val'] for s in subscores) / len(subscores), - 'Number few shot': - num_shot, - 'Model': - model_name, - } - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) - for sub in subscores: - row = { - 'Category': - benchmark_to_taxonomy.get(benchmark, ''), - 'Benchmark': - None, - 'Subtask': - sub['subcat'], - 'Accuracy': - sub['val'], - 'Number few shot': - num_shot, - 'Model': - model_name, - } - df = pd.concat([df, pd.DataFrame([row])], - ignore_index=True) - return df +from llmfoundry.command_utils import eval_from_yaml if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] - with open(yaml_path) as f: - yaml_cfg = om.load(f) - cli_cfg = om.from_cli(args_list) - cfg = om.merge(yaml_cfg, cli_cfg) - assert isinstance(cfg, DictConfig) - main(cfg) + eval_from_yaml(yaml_path, args_list) diff --git a/scripts/eval/yamls/long_context_tasks.yaml b/scripts/eval/yamls/long_context_tasks.yaml index daf958a340..153e3b9df6 100644 --- a/scripts/eval/yamls/long_context_tasks.yaml +++ b/scripts/eval/yamls/long_context_tasks.yaml @@ -3,7 +3,7 @@ icl_tasks: label: kv_pairs_beginning_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 2048 @@ -13,7 +13,7 @@ icl_tasks: label: kv_pairs_middle_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 2048 @@ -23,7 +23,7 @@ icl_tasks: label: kv_pairs_end_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 2048 @@ -33,7 +33,7 @@ icl_tasks: label: kv_pairs_beginning_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 4096 @@ -43,7 +43,7 @@ icl_tasks: label: kv_pairs_middle_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 4096 @@ -53,7 +53,7 @@ icl_tasks: label: kv_pairs_end_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 4096 @@ -63,7 +63,7 @@ icl_tasks: label: kv_pairs_beginning_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 8192 @@ -73,7 +73,7 @@ icl_tasks: label: kv_pairs_middle_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers 
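With the change above, scripts/eval/eval.py delegates YAML loading, CLI merging, and the whole evaluation loop to eval_from_yaml. A hedged example of the equivalent call from Python; the config path and the omegaconf-style overrides below are placeholders:

from llmfoundry.command_utils import eval_from_yaml

# Same call the slimmed-down script makes: a config file plus optional
# key=value overrides that are merged into it.
eval_from_yaml(
    'scripts/eval/yamls/hf_eval.yaml',          # placeholder config path
    ['device_eval_batch_size=4', 'seed=17'],    # placeholder overrides
)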
hf_loading_vars: name: kv_pairs context_length: 8192 @@ -83,7 +83,7 @@ icl_tasks: label: kv_pairs_end_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: kv_pairs context_length: 8192 @@ -93,7 +93,7 @@ icl_tasks: label: wikiqa_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: wikiqa context_length: 2048 @@ -102,7 +102,7 @@ icl_tasks: label: wikiqa_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: wikiqa context_length: 2048 @@ -111,7 +111,7 @@ icl_tasks: label: wikiqa_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: wikiqa context_length: 2048 @@ -120,7 +120,7 @@ icl_tasks: label: hotpotqa_beginning_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 2048 @@ -130,7 +130,7 @@ icl_tasks: label: hotpotqa_middle_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 2048 @@ -140,7 +140,7 @@ icl_tasks: label: hotpotqa_end_2k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 2048 @@ -150,7 +150,7 @@ icl_tasks: label: hotpotqa_beginning_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 4096 @@ -160,7 +160,7 @@ icl_tasks: label: hotpotqa_middle_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 4096 @@ -170,7 +170,7 @@ icl_tasks: label: hotpotqa_end_4k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 4096 @@ -180,7 +180,7 @@ icl_tasks: label: hotpotqa_beginning_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 8192 @@ -190,7 +190,7 @@ icl_tasks: label: hotpotqa_middle_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 8192 @@ -200,7 +200,7 @@ icl_tasks: label: hotpotqa_end_8k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 8192 @@ -210,7 +210,7 @@ icl_tasks: label: hotpotqa_beginning_16k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa 
context_length: 16384 @@ -220,7 +220,7 @@ icl_tasks: label: hotpotqa_beginning_32k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 32768 @@ -230,7 +230,7 @@ icl_tasks: label: hotpotqa_beginning_64k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 65536 @@ -240,7 +240,7 @@ icl_tasks: label: hotpotqa_middle_16k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 16384 @@ -250,7 +250,7 @@ icl_tasks: label: hotpotqa_middle_32k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 32768 @@ -260,7 +260,7 @@ icl_tasks: label: hotpotqa_middle_64k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 65536 @@ -270,7 +270,7 @@ icl_tasks: label: hotpotqa_end_16k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 16384 @@ -280,7 +280,7 @@ icl_tasks: label: hotpotqa_end_32k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 32768 @@ -290,7 +290,7 @@ icl_tasks: label: hotpotqa_end_64k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 65536 @@ -300,7 +300,7 @@ icl_tasks: label: kv_pairs_beginning_16k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 16384 @@ -310,7 +310,7 @@ icl_tasks: label: kv_pairs_beginning_32k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 32768 @@ -320,7 +320,7 @@ icl_tasks: label: kv_pairs_beginning_64k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 65536 @@ -330,7 +330,7 @@ icl_tasks: label: kv_pairs_middle_16k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 16384 @@ -340,7 +340,7 @@ icl_tasks: label: kv_pairs_middle_32k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 32768 @@ -350,7 +350,7 @@ icl_tasks: label: kv_pairs_middle_64k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: 
generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 65536 @@ -360,7 +360,7 @@ icl_tasks: label: kv_pairs_end_16k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 16384 @@ -370,7 +370,7 @@ icl_tasks: label: kv_pairs_end_32k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 32768 @@ -380,7 +380,7 @@ icl_tasks: label: kv_pairs_end_64k dataset_uri: hf://mosaicml/long_context_eval num_fewshot: [0] - icl_task_type: question_answering + icl_task_type: generation_task_with_answers hf_loading_vars: name: hotpotqa context_length: 65536 diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index e992371c32..7fb3d2af46 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -364,9 +364,7 @@ def main(args: Namespace) -> None: except Exception as e: raise RuntimeError( 'If you are having auth problems, try logging in via `huggingface-cli login` ' - + - 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' - + + + 'or by setting the environment variable `export HF_TOKEN=... ' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e @@ -389,9 +387,7 @@ def main(args: Namespace) -> None: raise RuntimeError( 'Unable to load HF model. ' + 'If you are having auth problems, try logging in via `huggingface-cli login` ' - + - 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' - + + + 'or by setting the environment variable `export HF_TOKEN=... ' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index eab46d7a69..b2e758b4ce 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -200,7 +200,7 @@ def main(args: Namespace) -> None: except Exception as e: raise RuntimeError( 'If you are having auth problems, try logging in via `huggingface-cli login` ' +\ - 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' +\ + 'or by setting the environment variable `export HF_TOKEN=... ' +\ 'using your access token from https://huggingface.co/settings/tokens.', ) from e @@ -236,9 +236,7 @@ def main(args: Namespace) -> None: raise RuntimeError( 'Unable to load HF model. ' + 'If you are having auth problems, try logging in via `huggingface-cli login` ' - + - 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' - + + + 'or by setting the environment variable `export HF_TOKEN=... 
' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py index 4e36c35e29..91b0c5a037 100644 --- a/scripts/misc/download_model.py +++ b/scripts/misc/download_model.py @@ -27,7 +27,8 @@ download_from_oras, ) -HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' +DEPRECATED_HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' +HF_TOKEN_ENV_VAR = 'HF_TOKEN' logging.basicConfig( format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', @@ -42,7 +43,10 @@ def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument( '--token', type=str, - default=os.getenv(HF_TOKEN_ENV_VAR), + default=os.getenv( + HF_TOKEN_ENV_VAR, + os.getenv(DEPRECATED_HF_TOKEN_ENV_VAR), + ), ) diff --git a/scripts/train/train.py b/scripts/train/train.py index c9e2d67bf4..728010d13a 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,558 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import gc -import logging -import os import sys -import time -import warnings -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.distributed -from composer import ComposerModel, Trainer -from composer.core.callback import Callback -from composer.profiler import ( - JSONTraceHandler, - Profiler, - TraceHandler, - cyclic_schedule, -) -from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig -from omegaconf import OmegaConf as om - -from llmfoundry.callbacks import AsyncEval -from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.eval.metrics.nlp import InContextLearningMetric -from llmfoundry.layers_registry import ffns_with_megablocks -from llmfoundry.utils import ( - find_mosaicml_logger, - log_train_analytics, - maybe_create_mosaicml_logger, -) -from llmfoundry.utils.builders import ( - add_metrics_to_eval_loaders, - build_algorithm, - build_callback, - build_composer_model, - build_evaluators, - build_logger, - build_optimizer, - build_scheduler, - build_tokenizer, -) -from llmfoundry.utils.config_utils import ( - TRAIN_CONFIG_KEYS, - TrainConfig, - log_config, - log_dataset_uri, - make_dataclass_and_log_config, - pop_config, - process_init_device, - update_batch_size_info, -) -from llmfoundry.utils.exceptions import ( - BaseContextualError, - EvalDataLoaderLocation, - TrainDataLoaderLocation, -) -from llmfoundry.utils.registry_utils import import_file - -log = logging.getLogger(__name__) - - -def validate_config(train_config: TrainConfig): - """Validates compatible model and dataloader selection.""" - # Validate the rest of the config - loaders = [train_config.train_loader] - if train_config.eval_loaders is not None: - for loader in (train_config.eval_loaders or []): # pyright - if 'label' not in loader or loader['label'] is None: - raise ValueError( - 'When specifying multiple evaluation datasets, each one must include the \ - `label` attribute.', - ) - loaders.append(loader) - if train_config.eval_loader is not None: - loaders.append(train_config.eval_loader) - for loader in loaders: - if loader['name'] == 'text': - if train_config.model['name'] == 'hf_t5': - raise ValueError( - f'Model type "{train_config.model["name"]}" is not supported when using the "text " ' +\ - f'dataloader. 
Only finetuning is supported.') - - if train_config.icl_tasks is not None or train_config.icl_tasks_str is not None: - if train_config.model['name'] == 'hf_t5': - raise ValueError( - 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".', - ) - - if ( - train_config.model.get('fc_type', 'torch') != 'te' and - 'te' not in train_config.model.get('ffn_config', - {}).get('ffn_type', 'mptmlp') and - 'fp8' in train_config.precision - ): - warnings.warn( - "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " - + - "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision.", - ) - - if ( - train_config.model.get('fc_type', 'torch') == 'te' or 'te' - in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') - ): - fsdp_config = train_config.fsdp_config - act_ckpt = fsdp_config.get( - 'activation_checkpointing', - False, - ) if fsdp_config else False - act_ckpt_reentrant = fsdp_config.get( - 'activation_checkpointing_reentrant', - False, - ) if fsdp_config else False - if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: - warnings.warn( - '`te.Linear` layers do not support activation_checkpointing with ' - + '`activation_checkpointing_reentrant = True`. ' + - 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.', - ) - assert train_config.fsdp_config is not None # pyright (this is known because fsdp_config is not None) - train_config.fsdp_config['activation_checkpointing_reentrant' - ] = False - - if train_config.model.get('ffn_config', - {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': - warnings.warn( - '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + - 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.', - ) - torch._dynamo.config.suppress_errors = True # type: ignore (third-party) - - if train_config.model.get('load_in_8bit', False): - raise ValueError( - '`load_in_8bit` is only supported for evaluation rather than training.', - ) - - if train_config.model.get('ffn_config', {}).get( - 'ffn_type', - 'mptmlp', - ) in ffns_with_megablocks: - moe_world_size = train_config.model.get('ffn_config', - {}).get('moe_world_size', 1) - use_orig_params = train_config.fsdp_config.get( - 'use_orig_params', - True, - ) if train_config.fsdp_config is not None else True - if moe_world_size > 1 and not use_orig_params: - raise ValueError( - f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.', - ) - - attn_config = train_config.model.get('attn_config', None) - if attn_config is not None: - seq_parallel_world_size = attn_config.get( - 'seq_parallel_world_size', - None, - ) - if seq_parallel_world_size is not None: - raise ValueError('Training does not support sequence parallelism.') - - -def _log_num_params(model: ComposerModel, logged_cfg: Dict[str, Any]): - # Log number of parameters - if hasattr(model, 'n_total_params'): - n_params = model.n_total_params - n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. 
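When the model does not expose n_total_params, the fallback a few lines below counts parameters directly from the torch module. A small self-contained illustration with one frozen layer (toy model, illustrative only):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
model[0].requires_grad_(False)  # freeze the first layer

n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params, n_trainable)  # 58 18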
- else: - n_params = sum(p.numel() for p in model.parameters()) - n_trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - if hasattr(model, 'n_active_params'): - n_active_params = model.n_active_params - else: - n_active_params = n_params - logged_cfg.update({ - 'n_params': n_params, - 'n_active_params': n_active_params, - 'n_trainable_params': n_trainable_params, - }) - - -def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): - """Initialize distributed and test setup with a barrier. - - Args: - dist_timeout (Union[int, float]): Timeout for initializing the process group - """ - log.debug('Initializing dist with device...') - dist.initialize_dist(get_device(None), timeout=dist_timeout) - log.debug('Testing barrier with device...') - dist.barrier() - log.debug('Barrier test passed with device.') - - -def main(cfg: DictConfig) -> Trainer: - code_paths = cfg.get('code_paths', []) - # Import any user provided code - for code_path in code_paths: - import_file(code_path) - - logged_cfg, train_cfg = make_dataclass_and_log_config( - cfg, - TrainConfig, - TRAIN_CONFIG_KEYS, - transforms=[update_batch_size_info], - ) - - # Set logging level - if train_cfg.python_log_level is not None: - logging.basicConfig( - # Example of format string - # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here - format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging.getLogger('llmfoundry').setLevel( - train_cfg.python_log_level.upper(), - ) # Foundry module - logging.getLogger(__name__).setLevel( - train_cfg.python_log_level.upper(), - ) # Train script - - _initialize_dist_with_barrier(dist_timeout=train_cfg.dist_timeout) - - # Filter deprecation warning from torch internal usage - warnings.filterwarnings( - action='ignore', - category=UserWarning, - message= - 'torch.distributed.*_base is a private function and will be deprecated.*', - ) - - # Check for incompatibilities between the model and data loaders - validate_config(train_cfg) - - cuda_alloc_conf = [] - # Get max split size mb - max_split_size_mb: Optional[int] = train_cfg.max_split_size_mb - if max_split_size_mb is not None: - cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') - - # Expandable segments - if train_cfg.expandable_segments: - cuda_alloc_conf.append('expandable_segments:True') - - if len(cuda_alloc_conf) > 0: - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf) - - # Set CUDA lazy loading - # This can save a bit of memory if not all modules are needed - cuda_load_lazy: bool = train_cfg.cuda_load_lazy - if cuda_load_lazy: - os.environ['CUDA_MODULE_LOADING'] = 'LAZY' - - # Set seed first - seed: int = train_cfg.seed - reproducibility.seed_all(seed) - - # Mandatory model training configs - model_config = train_cfg.model - train_loader_config = train_cfg.train_loader - - # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config - - eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders - icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str - eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str - - # Optional parameters will be set to default values if not specified. 
- default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name - is_state_dict_sharded: bool = ( - fsdp_config.get('state_dict_type', 'full') == 'sharded' - ) if fsdp_config else False - save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt' - save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' - - # Enable autoresume from model checkpoints if possible - autoresume_default: bool = False - if logged_cfg.get('run_name', None) is not None \ - and train_cfg.save_folder is not None \ - and not train_cfg.save_overwrite \ - and not train_cfg.save_weights_only: - autoresume_default = True - - if not train_cfg.autoresume and autoresume_default: - log.info( - 'As run_name, save_folder, and save_latest_filename are set, \ - changing autoresume default to True...', - ) - - # Warn if fsdp is enabled but user only has 1 GPU - if dist.get_world_size() == 1 and fsdp_config is not None: - warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.', - ) - fsdp_config = None - - # Initialize context - init_context = process_init_device(model_config, fsdp_config) - logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) - - # Build tokenizer - log.info('Building tokenizer...') - tokenizer_name = train_cfg.tokenizer['name'] - tokenizer_kwargs = train_cfg.tokenizer.get('kwargs', {}) - tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - - # Scheduler - scheduler_name: str = train_cfg.scheduler.pop('name') - scheduler = build_scheduler(scheduler_name, train_cfg.scheduler) - - # Loggers - loggers = [ - build_logger(str(name), logger_cfg) - for name, logger_cfg in train_cfg.loggers.items() - ] if train_cfg.loggers else [] - - mosaicml_logger = find_mosaicml_logger(loggers) - if mosaicml_logger is None: - mosaicml_logger = maybe_create_mosaicml_logger() - if mosaicml_logger is not None: - # mosaicml_logger will be None if run isn't on MosaicML platform - loggers.append(mosaicml_logger) - - if train_cfg.metadata is not None: - # Flatten the metadata for logging - logged_cfg.pop('metadata', None) - logged_cfg.update(train_cfg.metadata, merge=True) - if mosaicml_logger is not None: - mosaicml_logger.log_metrics(train_cfg.metadata) - mosaicml_logger._flush_metadata(force_flush=True) - - # Profiling - profiler: Optional[Profiler] = None - profiler_cfg = train_cfg.profiler - if profiler_cfg: - profiler_schedule_cfg: Dict = pop_config( - profiler_cfg, - 'schedule', - must_exist=True, - ) - profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) - # Only support json trace handler - profiler_trace_handlers: List[TraceHandler] = [] - profiler_trace_cfg: Optional[Dict] = pop_config( - profiler_cfg, - 'json_trace_handler', - must_exist=False, - default_value=None, - ) - if profiler_trace_cfg: - profiler_trace_handlers.append( - JSONTraceHandler(**profiler_trace_cfg), - ) - profiler = Profiler( - **profiler_cfg, - trace_handlers=profiler_trace_handlers, - schedule=profiler_schedule, - ) - - callback_configs = train_cfg.callbacks or {} - - # Callbacks - callbacks: List[Callback] = [ - build_callback( - name=str(name), - kwargs=callback_cfg, - train_config=logged_cfg, - ) for name, callback_cfg in callback_configs.items() - ] - - use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) - - algorithm_configs = 
train_cfg.algorithms or {} - - # Algorithms - algorithms = [ - build_algorithm(str(name), algorithm_cfg) - for name, algorithm_cfg in algorithm_configs.items() - ] - - # Dataloaders - log.info('Building train loader...') - try: - train_loader = build_dataloader( - train_loader_config, - tokenizer, - train_cfg.device_train_batch_size, - ) - except BaseContextualError as e: - if mosaicml_logger is not None: - e.location = TrainDataLoaderLocation - mosaicml_logger.log_exception(e) - raise e - - if mosaicml_logger is not None: - mosaicml_logger.log_metrics({'data_validated': time.time()}) - - ## Evaluation - if use_async_eval: - evaluators = [] - if train_cfg.eval_first: - warnings.warn( - 'AsyncEval callback does not support eval_first=True. Ignoring.', - ) - train_cfg.eval_first = False - - else: - try: - log.info('Building eval loader...') - eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len - evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks_config, - eval_gauntlet_config, - tokenizer=tokenizer, - device_eval_batch_size=train_cfg.device_eval_batch_size, - icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=train_cfg.icl_subset_num_batches, - ) - if eval_gauntlet_callback is not None: - callbacks.append(eval_gauntlet_callback) - except BaseContextualError as e: - if mosaicml_logger is not None: - e.location = EvalDataLoaderLocation - mosaicml_logger.log_exception(e) - raise e - - if mosaicml_logger is not None: - log_train_analytics( - mosaicml_logger, - model_config, - train_loader_config, - eval_loader_config, - train_cfg.callbacks, - tokenizer_name, - train_cfg.load_path, - icl_tasks_config, - eval_gauntlet_config, - ) - # Build Model - log.info('Initializing model...') - name = model_config.pop('name') - assert isinstance(name, str) - assert isinstance(model_config, dict) - model = build_composer_model( - name=name, - tokenizer=tokenizer, - init_context=init_context, - master_weights_dtype=model_config.get('master_weights_dtype', None), - cfg=model_config, - ) - - _log_num_params(model, logged_cfg) - - # Optimizer - optimizer_name: str = train_cfg.optimizer.pop('name') - optimizer_cfg = train_cfg.optimizer - optimizer = build_optimizer(model, optimizer_name, optimizer_cfg) - - # Now add the eval metrics - try: - if eval_loader_config is not None and not use_async_eval: - eval_metrics = model.get_metrics(is_train=False) - non_icl_metrics = [ - metric_name for metric_name, metric in eval_metrics.items() - if not isinstance(metric, InContextLearningMetric) - ] - evaluators = add_metrics_to_eval_loaders( - evaluators, - non_icl_metrics, - ) - except BaseContextualError as e: - if mosaicml_logger is not None: - e.location = EvalDataLoaderLocation - mosaicml_logger.log_exception(e) - raise e - - compile_config = train_cfg.compile_config - # Build the Trainer - log.info('Building trainer...') - trainer = Trainer( - run_name=run_name, - seed=seed, - model=model, - train_dataloader=train_loader, - eval_dataloader=evaluators, - optimizers=optimizer, - schedulers=scheduler, - max_duration=train_cfg.max_duration, - eval_interval=train_cfg.eval_interval, - eval_subset_num_batches=train_cfg.eval_subset_num_batches, - progress_bar=train_cfg.progress_bar, - log_to_console=train_cfg.log_to_console, - console_log_interval=train_cfg.console_log_interval, - loggers=loggers, - callbacks=callbacks, - precision=train_cfg.precision, - algorithms=algorithms, - 
device_train_microbatch_size=train_cfg.device_train_microbatch_size, - fsdp_config=fsdp_config, - save_folder=train_cfg.save_folder, - save_filename=save_filename, - save_latest_filename=save_latest_filename, - save_interval=train_cfg.save_interval, - save_num_checkpoints_to_keep=train_cfg.save_num_checkpoints_to_keep, - save_overwrite=train_cfg.save_overwrite, - save_weights_only=train_cfg.save_weights_only, - load_path=train_cfg.load_path, - load_weights_only=train_cfg.load_weights_only, - load_strict_model_weights=train_cfg.load_strict_model_weights, - load_ignore_keys=train_cfg.load_ignore_keys, - save_ignore_keys=train_cfg.save_ignore_keys, - autoresume=train_cfg.autoresume, - python_log_level=train_cfg.python_log_level, - dist_timeout=train_cfg.dist_timeout, - profiler=profiler, - compile_config=compile_config, - ) - - if train_cfg.log_config: - log.info('Logging config') - log_config(logged_cfg) - log_dataset_uri(logged_cfg) - torch.cuda.empty_cache() - gc.collect() - - # Eval first if requested - if train_cfg.eval_first and trainer.state.timestamp.batch.value == 0: - trainer.eval() - - log.info('Starting training...') - trainer.fit() - - log.info('Done.') - return trainer +from llmfoundry.command_utils import train_from_yaml if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] - - # Disable resolving environment variables through omegaconf. - om.clear_resolver('oc.env') - - # Load yaml and cli arguments. - with open(yaml_path) as f: - yaml_cfg = om.load(f) - cli_cfg = om.from_cli(args_list) - cfg = om.merge(yaml_cfg, cli_cfg) - om.resolve(cfg) - assert isinstance(cfg, DictConfig) - main(cfg) + train_from_yaml(yaml_path, args_list) diff --git a/setup.py b/setup.py index 78182976d4..19e5cee2d6 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import copy import os -import re +from typing import Any, Dict, Mapping import setuptools from setuptools import setup @@ -15,17 +15,15 @@ _REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) _PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) -# Read the repo version +# Read the llm-foundry version # We can't use `.__version__` from the library since it's not installed yet -with open(os.path.join(_PACKAGE_REAL_PATH, '__init__.py')) as f: +version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') +with open(version_path, encoding='utf-8') as f: + version_globals: Dict[str, Any] = {} + version_locals: Mapping[str, object] = {} content = f.read() -# regex: '__version__', whitespace?, '=', whitespace, quote, version, quote -# we put parens around the version so that it becomes elem 1 of the match -expr = re.compile( - r"""^__version__\s*=\s*['"]([0-9]+\.[0-9]+\.[0-9]+(?:\.\w+)?)['"]""", - re.MULTILINE, -) -repo_version = expr.findall(content)[0] + exec(content, version_globals, version_locals) + repo_version = str(version_locals['__version__']) # Use repo README for PyPi description with open('README.md', 'r', encoding='utf-8') as fh: @@ -54,27 +52,27 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,oci,gcs]>=0.22.0,<0.23', - 'mlflow>=2.12.1,<2.13', - 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.40,<4.41', - 'mosaicml-streaming>=0.7.6,<0.8', + 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24', + 'mlflow>=2.14.1,<2.15', + 'accelerate>=0.25,<0.34', # for HF inference `device_map` + 'transformers>=4.43.2,<4.44', + 'mosaicml-streaming>=0.8.0,<0.9', 'torch>=2.3.0,<2.4', 'datasets>=2.19,<2.20', 'fsspec==2023.6.0', # newer version results in a bug in datasets 
that duplicates data - 'sentencepiece==0.1.97', - 'einops==0.7.0', + 'sentencepiece==0.2.0', + 'einops==0.8.0', 'omegaconf>=2.2.3,<3', 'slack-sdk<4', 'mosaicml-cli>=0.6.10,<1', - 'onnx==1.14.0', - 'onnxruntime==1.15.1', + 'onnx==1.16.2', + 'onnxruntime==1.18.1', 'boto3>=1.21.45,<2', - 'huggingface-hub>=0.19.0,<0.23', + 'huggingface-hub>=0.19.0,<0.25', 'beautifulsoup4>=4.12.2,<5', # required for model download utils 'tenacity>=8.2.3,<9', 'catalogue>=2,<3', - 'typer[all]<1', + 'typer<1', ] extra_deps = {} @@ -84,7 +82,7 @@ 'pre-commit>=3.4.0,<4', 'pytest>=7.2.1,<8', 'pytest_codeblocks>=0.16.1,<0.17', - 'pytest-cov>=4,<5', + 'pytest-cov>=4,<6', 'pyright==1.1.256', 'toml>=0.10.2,<0.11', 'packaging>=21,<23', @@ -92,25 +90,25 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.22.0,<0.23', + 'mosaicml[databricks]>=0.23.4,<0.24', 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', 'lz4>=4,<5', ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.22.0,<0.23', + 'mosaicml[tensorboard]>=0.23.4,<0.24', ] # Flash 2 group kept for backwards compatibility extra_deps['gpu-flash2'] = [ - 'flash-attn==2.5.8', + 'flash-attn>=2.5.8,<3', ] extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2']) extra_deps['peft'] = [ - 'mosaicml[peft]>=0.22.0,<0.23', + 'mosaicml[peft]>=0.23.4,<0.24', ] extra_deps['openai'] = [ @@ -123,6 +121,11 @@ 'grouped-gemm==0.1.4', ] +extra_deps['databricks-serverless'] = { + dep for key, deps in extra_deps.items() for dep in deps + if 'gpu' not in key and 'megablocks' not in key and + 'databricks-connect' not in dep +} extra_deps['all-cpu'] = { dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key and 'megablocks' not in key diff --git a/tests/a_scripts/data_prep/test_convert_dataset_hf.py b/tests/a_scripts/data_prep/test_convert_dataset_hf.py index 4c5d1a6bba..e09c54ca70 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_hf.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_hf.py @@ -2,29 +2,26 @@ # SPDX-License-Identifier: Apache-2.0 import os -from argparse import Namespace from pathlib import Path -from scripts.data_prep.convert_dataset_hf import main as main_hf +from llmfoundry.command_utils import convert_dataset_hf def test_download_script_from_api(tmp_path: Path): # test calling it directly path = os.path.join(tmp_path, 'my-copy-c4-1') - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': ['val_xsmall'], - 'out_root': path, - 'compression': None, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=['val_xsmall'], + out_root=path, + compression=None, + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, + tokenizer=None, + tokenizer_kwargs={}, ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_dataset_json.py b/tests/a_scripts/data_prep/test_convert_dataset_json.py index 912e44cd0c..4f70a35637 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_json.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_json.py @@ -2,28 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 import os -from argparse import Namespace from pathlib import Path -from scripts.data_prep.convert_dataset_json import main as main_json +from llmfoundry.command_utils import convert_dataset_json def test_json_script_from_api(tmp_path: Path): # test calling it directly path = os.path.join(tmp_path, 'my-copy-arxiv-1') 
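The new databricks-serverless extra above is built by filtering the union of the other extras rather than listing packages by hand: anything from a 'gpu' or 'megablocks' group is dropped, as is the databricks-connect pin. A toy run of the same comprehension over a handful of entries from this setup.py:

extra_deps = {
    'gpu': ['flash-attn>=2.5.8,<3'],
    'megablocks': ['grouped-gemm==0.1.4'],
    'databricks': ['databricks-connect==14.1.0', 'lz4>=4,<5'],
    'tensorboard': ['mosaicml[tensorboard]>=0.23.4,<0.24'],
}

serverless = {
    dep for key, deps in extra_deps.items() for dep in deps
    if 'gpu' not in key and 'megablocks' not in key and
    'databricks-connect' not in dep
}
print(sorted(serverless))  # ['lz4>=4,<5', 'mosaicml[tensorboard]>=0.23.4,<0.24']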
- main_json( - Namespace( - **{ - 'path': 'scripts/data_prep/example_data/arxiv.jsonl', - 'out_root': path, - 'compression': None, - 'split': 'train', - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_json( + path='scripts/data_prep/example_data/arxiv.jsonl', + out_root=path, + compression=None, + split='train', + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index e4619b8a56..e623467bf7 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -1,17 +1,15 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -# copyright 2022 mosaicml llm foundry authors -# spdx-license-identifier: apache-2.0 - import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, mock_open, patch -from scripts.data_prep.convert_delta_to_json import ( +from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( download, fetch_DT, + format_tablename, iterative_combine_jsons, run_query, ) @@ -19,11 +17,19 @@ class TestConvertDeltaToJsonl(unittest.TestCase): - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sdk.WorkspaceClient', + ) def test_stream_delta_to_json( self, mock_workspace_client: Any, @@ -32,19 +38,15 @@ def test_stream_delta_to_json( mock_makedirs: Any, mock_sql_connect: Any, ): - - args = MagicMock() - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' - args.DATABRICKS_HOST = 'test_host' - args.DATABRICKS_TOKEN = 'test_token' - args.http_path = 'test_path' - args.batch_size = 1000 - args.partitions = 1 - args.cluster_id = '1234' - args.debug = False - args.use_serverless = False - args.json_output_filename = 'combined.jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' + DATABRICKS_HOST = 'test_host' + DATABRICKS_TOKEN = 'test_token' + http_path = 'test_path' + batch_size = 1000 + cluster_id = '1234' + use_serverless = False + json_output_filename = 'combined.jsonl' mock_cluster_get = MagicMock() mock_cluster_get.return_value = MagicMock( @@ -52,7 +54,17 @@ def test_stream_delta_to_json( ) mock_workspace_client.return_value.clusters.get = mock_cluster_get - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + batch_size=batch_size, + json_output_filename=json_output_filename, + ) mock_sql_connect.assert_called_once_with( server_hostname='test_host', http_path='test_path', @@ -65,7 +77,9 @@ def 
test_stream_delta_to_json( '/path/to/jsonl/combined.jsonl', ) - @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.listdir', + ) @patch( 'builtins.open', new_callable=mock_open, @@ -101,7 +115,9 @@ def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): """ self.assertEqual(mock_file().write.call_count, 2) - @patch('scripts.data_prep.convert_delta_to_json.SparkSession') + @patch( + 'pyspark.sql.SparkSession', + ) def test_run_query_dbconnect(self, mock_spark: Any): method = 'dbconnect' mock_cursor = None @@ -117,7 +133,9 @@ def test_run_query_dbconnect(self, mock_spark: Any): mock_spark.sql.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') - @patch('scripts.data_prep.convert_delta_to_json.Cursor') + @patch( + 'databricks.sql.client.Cursor', + ) def test_run_query_dbsql(self, mock_cursor: Any): method = 'dbsql' mock_cursor.fetchall.return_value = 'result' @@ -133,14 +151,18 @@ def test_run_query_dbsql(self, mock_cursor: Any): mock_cursor.execute.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') - @patch('scripts.data_prep.convert_delta_to_json.requests.get') - @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') @patch( - 'scripts.data_prep.convert_delta_to_json.os.path.join', + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.requests.get', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.pd.DataFrame.to_json', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.path.join', return_value='/fake/path/part_1.jsonl', ) @patch( - 'scripts.data_prep.convert_delta_to_json.time.sleep', + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.time.sleep', ) # Mock sleep to speed up the test def test_download_success( self, @@ -173,12 +195,22 @@ def test_download_success( mock_get.assert_called_once_with('http://fakeurl.com/data') - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_dbconnect_called( self, mock_fetch: Any, @@ -188,17 +220,14 @@ def test_dbconnect_called( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = None - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = None + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') 
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -208,19 +237,37 @@ def test_dbconnect_called( ) # Mock return value for getOrCreate mock_databricks_session.builder.remote.return_value = mock_remote - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_databricks_session.builder.remote.assert_called_once_with( - host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id, + host=DATABRICKS_HOST, + token=DATABRICKS_TOKEN, + cluster_id=cluster_id, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_dbr13( self, mock_fetch: Any, @@ -230,34 +277,49 @@ def test_sqlconnect_called_dbr13( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( - server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + server_hostname=DATABRICKS_HOST, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 
'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_dbr14( self, mock_fetch: Any, @@ -267,34 +329,49 @@ def test_sqlconnect_called_dbr14( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( - server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + server_hostname=DATABRICKS_HOST, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_https( self, mock_fetch: Any, @@ -304,34 +381,49 @@ def test_sqlconnect_called_https( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'https://test-host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( server_hostname='test-host', - 
http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_serverless( self, mock_fetch: Any, @@ -341,21 +433,40 @@ def test_serverless( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'https://test-host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = True + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'token' + use_serverless = True mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) assert not mock_sql_connect.called assert not mock_databricks_session.builder.remote.called + + def test_format_tablename(self): + self.assertEqual( + format_tablename('test_catalog.hyphenated-schema.test_table'), + '`test_catalog`.`hyphenated-schema`.`test_table`', + ) + self.assertEqual( + format_tablename('catalog.schema.table'), + '`catalog`.`schema`.`table`', + ) + self.assertEqual( + format_tablename('hyphenated-catalog.schema.test_table'), + '`hyphenated-catalog`.`schema`.`test_table`', + ) diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index df4309e13d..6ba14d62e4 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -9,16 +9,11 @@ from typing import Callable, Iterable, List from unittest.mock import Mock, patch -import numpy as np import pytest from streaming import StreamingDataset from transformers import AutoTokenizer -from llmfoundry.utils.exceptions import ( - InputFolderMissingDataError, - OutputFolderNotEmptyError, -) -from scripts.data_prep.convert_text_to_mds import ( +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( DONE_FILENAME, convert_text_to_mds, download_and_convert, @@ -26,6 +21,11 @@ merge_shard_groups, write_done_file, ) +from llmfoundry.utils.exceptions import ( + DatasetTooSmallError, + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) class MockObjectStore(): @@ -84,15 
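For reference, the new `test_format_tablename` assertions above pin down the backtick-quoting behaviour of `format_tablename`. A minimal sketch consistent with those assertions (the real helper lives in `llmfoundry.command_utils.data_prep.convert_delta_to_json` and may handle additional edge cases):

def format_tablename(full_name: str) -> str:
    # Quote catalog, schema and table individually so hyphenated identifiers
    # such as 'hyphenated-catalog.schema.test_table' survive SQL parsing.
    catalog, schema, table = full_name.split('.')
    return f'`{catalog}`.`{schema}`.`{table}`'

assert format_tablename('catalog.schema.table') == '`catalog`.`schema`.`table`'
assert (format_tablename('test_catalog.hyphenated-schema.test_table') ==
        '`test_catalog`.`hyphenated-schema`.`test_table`')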
+84,15 @@ def _assert_files_exist(prefix: str, files: List[str]): @pytest.mark.parametrize('processes', [1, 2, 3]) @patch.object(ProcessPoolExecutor, 'map', new=Mock(wraps=_mock_map)) @patch( - 'scripts.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', ) -@patch('scripts.data_prep.convert_text_to_mds.parse_uri') +@patch('llmfoundry.command_utils.data_prep.convert_text_to_mds.parse_uri') @patch( - 'scripts.data_prep.convert_text_to_mds.download_and_convert', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.download_and_convert', wraps=download_and_convert, ) @patch( - 'scripts.data_prep.convert_text_to_mds.merge_shard_groups', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.merge_shard_groups', wraps=merge_shard_groups, ) def test_single_and_multi_process( @@ -194,7 +194,7 @@ def call_convert_text_to_mds() -> None: n_tokens = 0 for i in range(dataset.num_samples): sample = dataset[i] - tokens = np.frombuffer(sample['tokens'], dtype=int) + tokens = sample['tokens'] if i == 0: # For the first sample, check that the decoded sample matches the text_content decoded = tokenizer.decode(tokens) assert decoded == text_content[:len(decoded)] @@ -268,6 +268,28 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path): ) +def test_dataset_too_small(tmp_path: pathlib.Path): + input_folder = tmp_path / 'input' + os.makedirs(input_folder, exist_ok=True) + with open(input_folder / 'test.txt', 'w') as f: + f.write('a') + with pytest.raises(DatasetTooSmallError): + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(tmp_path / 'output'), + input_folder=str(input_folder), + concat_tokens=2048, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=False, + trust_remote_code=False, + ) + + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) args_str = 'Namespace(x = 5)' diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index a56778538c..fc0dc8a882 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -11,10 +11,10 @@ from composer import Trainer from composer.loggers import InMemoryLogger +from llmfoundry.command_utils import evaluate from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model -from llmfoundry.utils.config_utils import to_dict_container -from scripts.eval.eval import main # noqa: E402 +from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS, to_dict_container from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg @@ -75,7 +75,7 @@ def test_icl_eval( eval_cfg = copy.deepcopy(eval_cfg) eval_cfg.models[0].load_path = mock_saved_model_path assert isinstance(eval_cfg, om.DictConfig) - main(eval_cfg) + evaluate(eval_cfg) out, _ = capfd.readouterr() expected_results = '| Category | Benchmark | Subtask | Accuracy | Number few shot | Model |\n|:----------------------------|:---------------|:----------|-----------:|:------------------|:---------|\n| language_understanding_lite | lambada_openai | | 0 | 0-shot | tiny_mpt |' assert expected_results in out @@ -134,7 +134,15 @@ def test_loader_eval( test_cfg.eval_interval = '1ba' test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})}) - trainers, eval_gauntlet_df = main(test_cfg) + # This test uses a training yaml with training-only keys present. 
+ # We exclude these keys before calling `evaluate` from the eval script. + allowed_keys = EVAL_CONFIG_KEYS + present_keys = set(test_cfg.keys()) + keys_to_pop = present_keys.difference(allowed_keys) + + [test_cfg.pop(key) for key in keys_to_pop] + + trainers, eval_gauntlet_df = evaluate(test_cfg) assert eval_gauntlet_df is None assert len(trainers) == 1 # one per model diff --git a/tests/a_scripts/eval/test_eval_inputs.py b/tests/a_scripts/eval/test_eval_inputs.py index 98b15743b3..86243ba154 100644 --- a/tests/a_scripts/eval/test_eval_inputs.py +++ b/tests/a_scripts/eval/test_eval_inputs.py @@ -2,14 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import copy import os -import warnings import omegaconf import pytest from omegaconf import DictConfig from omegaconf import OmegaConf as om -from scripts.eval.eval import main # noqa: E402 +from llmfoundry.command_utils import evaluate class TestHuggingFaceEvalYAMLInputs: @@ -42,12 +41,13 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None: omegaconf.errors.InterpolationKeyError, omegaconf.errors.MissingMandatoryValue, TypeError, + ValueError, )): cfg[p + '-mispelled'] = cfg.pop(p) - main(cfg) + evaluate(cfg) cfg[p] = cfg.pop(p + '-mispelled') - def test_optional_mispelled_params_raise_warning( + def test_optional_mispelled_params_raise_error( self, cfg: DictConfig, ) -> None: @@ -67,15 +67,8 @@ def test_optional_mispelled_params_raise_warning( orig_value = cfg.pop(param, None) updated_param = param + '-mispelling' cfg[updated_param] = orig_value - with warnings.catch_warnings(record=True) as warning_list: - try: - main(cfg) - except: - pass - assert any( - f'Unused parameter {updated_param} found in cfg.' in - str(warning.message) for warning in warning_list - ) + with pytest.raises(ValueError): + evaluate(cfg) # restore configs. cfg = copy.deepcopy(old_cfg) @@ -112,4 +105,4 @@ def test_empty_load_path_raises_error(self, cfg: DictConfig) -> None: + ' Please check your yaml and the model_cfg to ensure that load_path is set.' 
cfg.models[0].load_path = None with pytest.raises(ValueError, match=error_string): - main(cfg) + evaluate(cfg) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 0577e13a1f..cd47b2df7c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import json import math import os @@ -382,6 +383,14 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' + mlflow_logger_mock._enabled = True + mlflow_logger_mock.run_url = 'fake-url' + checkpointer_callback.transform_model_pre_registration = MagicMock( + wraps=checkpointer_callback.transform_model_pre_registration, + ) + checkpointer_callback.pre_register_edit = MagicMock( + wraps=checkpointer_callback.pre_register_edit, + ) trainer = Trainer( model=original_model, device='gpu', @@ -405,9 +414,14 @@ def test_huggingface_conversion_callback_interval( task='llm/v1/completions', input_example=ANY, metadata={}, + pip_requirements=ANY, ) + assert checkpointer_callback.transform_model_pre_registration.call_count == 1 + assert checkpointer_callback.pre_register_edit.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: + assert checkpointer_callback.transform_model_pre_registration.call_count == 0 + assert checkpointer_callback.pre_register_edit.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 @@ -468,6 +482,7 @@ def _get_model_and_tokenizer( model: str, max_seq_len: int, tie_word_embeddings: bool, + precision: str, ): if model == 'mpt': model_cfg = { @@ -482,6 +497,7 @@ def _get_model_and_tokenizer( 'attn_config': { 'attn_impl': 'torch', }, + 'fc_type': 'te' if precision == 'amp_fp8' else 'torch', 'loss_fn': 'torch_crossentropy', 'tie_word_embeddings': tie_word_embeddings, } @@ -530,7 +546,7 @@ def _get_model_and_tokenizer( tokenizer_name = 'EleutherAI/gpt-neo-125M' elif model == 'llama2': assert tie_word_embeddings is None - if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: + if 'HF_TOKEN' not in os.environ: pytest.skip( 'The CI cluster does not have access to the Llama models, so skip this test.', ) @@ -579,6 +595,7 @@ def _assert_mlflow_logger_calls( 'task': 'llm/v1/completions', 'input_example': default_input_example, 'metadata': {}, + 'pip_requirements': ANY, } mlflow_logger_mock.save_model.assert_called_with(**expectation) assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 @@ -783,8 +800,9 @@ def _assert_checkpoint_equivalence( ) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize( - 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', - [('1ba', '1ba', '1ba', 1, 1)], + 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints,trainer_precision', + [('1ba', '1ba', '1ba', 1, 1, 'amp_bf16'), + ('1ba', '1ba', '1ba', 1, 1, 'amp_fp8')], ) @patch('os.cpu_count', MagicMock(return_value=1)) @patch( @@ -801,10 +819,30 @@ def test_huggingface_conversion_callback( max_duration: str, expected_hf_checkpoints: int, expected_normal_checkpoints: int, + trainer_precision: str, 
peft_config: Optional[dict], ): if model == 'mptmoe' and fsdp_state_dict_type is None: pytest.skip('mptmoe requires FSDP') + if trainer_precision == 'amp_fp8': + # Check if transformer-engine is installed for FP8. + try: + import transformer_engine.pytorch as te + except ImportError: + pytest.skip( + 'Precision amp_fp8 requires transformer-engine to be installed', + ) + + # Check we are using mpt models only for FP8. + if (model == 'neo' or model == 'llama2'): + pytest.skip( + 'Precision amp_fp8 works only for mpt models, not hf models', + ) + + # Check that we are using H100 or later for FP8. + if not (torch.cuda.get_device_capability() >= (8, 9)): + pytest.skip('Amp FP8 requires a GPU with compute capability >= 8.9') + delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -825,9 +863,10 @@ def test_huggingface_conversion_callback( # Get small version of each model model_cfg, tokenizer_name = _get_model_and_tokenizer( - model, - max_seq_len, - tie_word_embeddings, + model=model, + max_seq_len=max_seq_len, + tie_word_embeddings=tie_word_embeddings, + precision=trainer_precision, ) assert model_cfg is not None assert tokenizer_name is not None @@ -883,7 +922,7 @@ def test_huggingface_conversion_callback( trainer = Trainer( model=original_model, device='gpu', - precision='amp_bf16', + precision=trainer_precision, fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None, train_dataloader=train_dataloader, save_folder=os.path.join(tmp_path, 'checkpoints'), @@ -900,24 +939,29 @@ def test_huggingface_conversion_callback( # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - with FSDP.summon_full_params( - trainer.state.model, - writeback=False, - recurse=True, - ): - _assert_checkpoint_equivalence( - tmp_path=tmp_path, - expected_normal_checkpoints=expected_normal_checkpoints, - expected_hf_checkpoints=expected_hf_checkpoints, - trainer=trainer, - batches_per_epoch=batches_per_epoch, - original_model=original_model, - precision=precision, - model=model, - tokenizer=tokenizer, - fsdp_state_dict_type=fsdp_state_dict_type, - peft_config=peft_config, - ) + + context_manager = te.onnx_export( # type: ignore + True, + ) if trainer_precision == 'amp_fp8' else contextlib.nullcontext() + with context_manager: + with FSDP.summon_full_params( + trainer.state.model, + writeback=False, + recurse=True, + ): + _assert_checkpoint_equivalence( + tmp_path=tmp_path, + expected_normal_checkpoints=expected_normal_checkpoints, + expected_hf_checkpoints=expected_hf_checkpoints, + trainer=trainer, + batches_per_epoch=batches_per_epoch, + original_model=original_model, + precision=precision, + model=model, + tokenizer=tokenizer, + fsdp_state_dict_type=fsdp_state_dict_type, + peft_config=peft_config, + ) dist.barrier() delete_transformers_cache() @@ -955,7 +999,7 @@ def test_convert_and_generate( om_cfg['model']['config_overrides']['hidden_size'] = 36 elif model == 'llama2': assert tie_word_embeddings is None - if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: + if 'HF_TOKEN' not in os.environ: pytest.skip( 'The CI cluster does not have access to the Llama models, so skip this test.', ) @@ -1149,7 +1193,7 @@ def test_convert_and_generate_meta( @pytest.mark.world_size(4) @pytest.mark.gpu @pytest.mark.parametrize('num_experts', [2, 4, 8]) -@pytest.mark.parametrize('sharding_strategy', ['FULL_SHARD', 'HYBRID_SHARD']) +@pytest.mark.parametrize('sharding_strategy', ['FULL_SHARD']) def test_mptmoe_huggingface_conversion_callback( tmp_path: 
pathlib.Path, num_experts: int, @@ -1251,7 +1295,6 @@ def test_mptmoe_huggingface_conversion_callback( make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size) dataloader_cfg = { - 'name': 'finetuning', 'dataset': { 'hf_name': tiny_dataset_folder_path, 'split': 'train', @@ -1320,7 +1363,6 @@ def test_mptmoe_huggingface_conversion_callback( save_weights_only=True, ) trainer.fit() - #self.state.outputs = self.state.model(self.state.batch) batch = trainer.state.batch model_output_logits = trainer.state.model(batch).logits @@ -1398,7 +1440,6 @@ def test_mptmoe_huggingface_conversion_callback( loaded_model_logits = loaded_model( input_ids=batch.get('input_ids', None), attention_mask=batch.get('attention_mask', None), - prefix_mask=batch.get('bidirectional_mask', None), sequence_id=batch.get('sequence_id', None), inputs_embeds=batch.get('inputs_embeds', None), ).logits diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 2be1d5139d..1f724a6070 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -11,12 +11,12 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.command_utils import TrainConfig # noqa: E402 +from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config from llmfoundry.utils.config_utils import ( make_dataclass_and_log_config, update_batch_size_info, ) -from scripts.train.train import TrainConfig # noqa: E402 -from scripts.train.train import TRAIN_CONFIG_KEYS, main, validate_config from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg from tests.fixtures.autouse import REPO_DIR @@ -82,7 +82,7 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): test_cfg.max_duration = '1ba' test_cfg.eval_interval = '1ba' test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) - trainer = main(test_cfg) + trainer = train(test_cfg) assert isinstance(trainer.logger.destinations, tuple) @@ -126,7 +126,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): test_cfg.max_duration = '1ba' test_cfg.eval_interval = '1ba' test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) - trainer = main(test_cfg) + trainer = train(test_cfg) assert isinstance(trainer.logger.destinations, tuple) @@ -201,7 +201,7 @@ def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): test_cfg.eval_interval = '1ba' test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) test_cfg.model['use_train_metrics'] = False - trainer = main(test_cfg) + trainer = train(test_cfg) # Check eval metrics exist inmemorylogger = trainer.logger.destinations[ diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index 5a3b21dc3b..73540afe2f 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -3,14 +3,13 @@ import copy import json import os -import warnings import omegaconf import pytest from omegaconf import DictConfig from omegaconf import OmegaConf as om -from scripts.train.train import main # noqa: E402 +from llmfoundry.command_utils import train def make_fake_index_file(path: str) -> None: @@ -63,8 +62,10 @@ def cfg(self, foundry_dir: str) -> DictConfig: def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None: """Check that mandatory misspelled inputs fail to train.""" cfg.trai_loader = cfg.pop('train_loader') - with pytest.raises((omegaconf.errors.MissingMandatoryValue, TypeError)): - main(cfg) + with pytest.raises( 
+ (omegaconf.errors.MissingMandatoryValue, TypeError, ValueError), + ): + train(cfg) def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: """Check that missing mandatory parameters fail to train.""" @@ -86,10 +87,10 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: omegaconf.errors.InterpolationKeyError, omegaconf.errors.MissingMandatoryValue, )): - main(cfg) + train(cfg) cfg[param] = orig_param - def test_optional_misspelled_params_raise_warning( + def test_optional_misspelled_params_raise_error( self, cfg: DictConfig, ) -> None: @@ -113,15 +114,8 @@ def test_optional_misspelled_params_raise_warning( orig_value = cfg.pop(param, None) updated_param = param + '-misspelling' cfg[updated_param] = orig_value - with warnings.catch_warnings(record=True) as warning_list: - try: - main(cfg) - except: - pass - assert any( - f'Unused parameter {updated_param} found in cfg.' in - str(warning.message) for warning in warning_list - ) + with pytest.raises(ValueError): + train(cfg) # restore configs. cfg = copy.deepcopy(old_cfg) @@ -136,7 +130,7 @@ def test_extra_params_in_optimizer_cfg_errors( cfg.eval_loader.dataset.local = data_local cfg.optimizer.beta2 = 'extra-parameter' with pytest.raises(TypeError): - main(cfg) + train(cfg) def test_invalid_name_in_optimizer_cfg_errors( self, @@ -149,7 +143,7 @@ def test_invalid_name_in_optimizer_cfg_errors( cfg.train_loader.dataset.local = data_local cfg.eval_loader.dataset.local = data_local with pytest.raises(ValueError) as exception_info: - main(cfg) + train(cfg) assert str(exception_info.value).startswith( "Cant't find 'invalid-optimizer' in registry llmfoundry -> optimizers.", ) @@ -160,7 +154,7 @@ def test_extra_params_in_scheduler_cfg_errors( ) -> None: cfg.scheduler.t_warmup_extra = 'extra-parameter' with pytest.raises(TypeError): - main(cfg) + train(cfg) def test_invalid_name_in_scheduler_cfg_errors( self, @@ -168,7 +162,7 @@ def test_invalid_name_in_scheduler_cfg_errors( ) -> None: cfg.scheduler.name = 'invalid-scheduler' with pytest.raises(ValueError) as exception_info: - main(cfg) + train(cfg) assert str(exception_info.value).startswith( "Cant't find 'invalid-scheduler' in registry llmfoundry -> schedulers.", ) @@ -187,7 +181,7 @@ def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None: second_eval_loader.label = 'eval_1' cfg.eval_loader = om.create([first_eval_loader, second_eval_loader]) with pytest.raises(ValueError) as exception_info: - main(cfg) + train(cfg) assert str( exception_info.value, ) == 'When specifying multiple evaluation datasets, each one must include the \ diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index bbdbf3d691..075698a4c0 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -1,14 +1,283 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from contextlib import nullcontext +from typing import Any, Callable, Optional +from unittest.mock import MagicMock + +import pytest +from composer.core import State +from composer.core.time import Time, TimeUnit +from composer.devices import DeviceCPU +from composer.loggers import Logger +from omegaconf import OmegaConf as om +from torch.utils.data import DataLoader + +from llmfoundry.data.text_data import StreamingTextDataset from llmfoundry.utils.builders import build_callback -def test_curriculum_learning_callback_builds(): - kwargs = 
{'dataset_index': 0} +@pytest.mark.parametrize( + 'datamix,duration', + [ + (None, '1ep'), + ({ + 'dataset': 'some_dataset', + }, '1ep'), + (None, '10tok'), + (None, ''), + ({}, '1ep'), + ], +) +def test_curriculum_learning_callback_init( + datamix: Optional[dict[str, Any]], + duration: str, + tiny_ft_dataloader_cfg: dict[str, Any], +): + test_cfg = _get_test_cfg() + test_cfg['train_loader'] = tiny_ft_dataloader_cfg + train_loader = test_cfg['train_loader'] if datamix is None else datamix + kwargs = { + 'schedule': [{ + 'duration': duration, + 'train_loader': train_loader, + }, { + 'duration': '2ep', + 'train_loader': {}, + }], + } + if duration == '': + del kwargs['schedule'][0]['duration'] + if datamix is not None and len(datamix) == 0: + del kwargs['schedule'][0]['train_loader'] + + context = nullcontext() + if datamix is not None or duration == '': + context = pytest.raises(ValueError) + with context: + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + assert callback is not None + + +@pytest.mark.parametrize('duration', ['1ep', '10tok', '2ep']) +def test_curriculum_learning_callback_before_load( + duration: str, + build_tiny_mpt: Callable, +): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': duration, + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + context = nullcontext() + if duration != '1ep': + context = pytest.raises(ValueError) + with context: + callback.before_load(state, logger) + + +def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + assert state.timestamp.iteration == 0 + callback.after_load(state, logger) + assert state.timestamp.iteration == 1 + + +def test_curriculum_learning_callback_iteration( + build_tiny_mpt: Callable, + monkeypatch: pytest.MonkeyPatch, +): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + ds_mock = MagicMock(spec=StreamingTextDataset) + monkeypatch.setattr( + 'llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, + **kwargs: ds_mock, + ) + dl_mock.dataset = ds_mock + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + 
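The curriculum-learning tests in this file all construct the callback with the same `schedule` shape, so a single consolidated sketch may help. This mirrors what the tests do rather than documenting a guaranteed API: `train_cfg` is assumed to be a resolved training config (loaded the same way `_get_test_cfg` at the bottom of this file does), and the init test above suggests the first schedule entry is expected to reuse the run's own `train_loader`, with later entries supplying the datamixes to switch to:

from omegaconf import OmegaConf as om

from llmfoundry.utils.builders import build_callback

# Load and resolve the same tiny pretraining config the tests use.
conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
train_cfg = om.to_container(om.load(conf_path), resolve=True)
train_cfg['device_train_batch_size'] = train_cfg['device_train_microbatch_size']

callback = build_callback(
    'curriculum_learning',
    kwargs={
        'schedule': [
            # First entry: duration of the initial datamix; reuses the
            # top-level train_loader.
            {'duration': '1ep', 'train_loader': train_cfg['train_loader']},
            # Subsequent entries: the datamixes to switch to, in order.
            {'duration': '2ep', 'train_loader': train_cfg['train_loader']},
        ],
    },
    train_config=train_cfg,
)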
test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + + callback.init(state, logger) + callback.iteration_start(state, logger) + assert state._iteration_length == Time(1, TimeUnit.EPOCH) + callback.iteration_end(state, logger) + callback.iteration_start(state, logger) + assert state._iteration_length == Time(2, TimeUnit.EPOCH) + + +def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + callback.iteration_start(state, logger) + callback.iteration_end(state, logger) + assert callback.state_dict() == { + 'schedule': kwargs['schedule'], + 'schedule_index': 1, + } + + +def test_curriculum_learning_callback_load_state_dict( + build_tiny_mpt: Callable, +): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + callback = build_callback( 'curriculum_learning', kwargs=kwargs, - train_config={'train_loader': {}}, + train_config=test_cfg, ) - assert callback is not None + callback.iteration_start(state, logger) + callback.iteration_end(state, logger) + assert callback.state_dict() == { + 'schedule': kwargs['schedule'], + 'schedule_index': 1, + } + + +def _get_test_cfg() -> dict[str, Any]: + conf_path = 'scripts/train/yamls/pretrain/testing.yaml' + with open(conf_path) as f: + test_cfg = om.load(f) + batch_size = test_cfg['device_train_microbatch_size' + ] # pyright: ignore [reportGeneralTypeIssues] + test_cfg['device_train_batch_size' + ] = batch_size # pyright: ignore [reportGeneralTypeIssues] + return om.to_container( + test_cfg, + resolve=True, + ) # pyright: ignore [reportGeneralTypeIssues] diff --git a/tests/callbacks/test_eval_output_logging_callback.py b/tests/callbacks/test_eval_output_logging_callback.py index 7778e39fe3..b5006f6fb2 100644 --- a/tests/callbacks/test_eval_output_logging_callback.py +++ b/tests/callbacks/test_eval_output_logging_callback.py @@ -1,13 +1,19 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import json +import re +from typing import Any +from unittest import mock +import pytest import torch import transformers from 
composer.core.state import State from composer.core.time import Timestamp from composer.loggers import InMemoryLogger, Logger +from composer.models import HuggingFaceModel from torch.utils.data import DataLoader from torchmetrics import Metric @@ -50,6 +56,23 @@ def update_curr_eval(self, dataloader: DataLoader, dataloader_label: str): self._dataloader_label = dataloader_label +class MockHFModel(HuggingFaceModel): + + def __init__(self, *args: Any, **kargs: Any): + pass + + +class RegexMatcher: + + def __init__(self, pattern: str): + self.pattern = re.compile(pattern) + + def __eq__(self, other: str): + if not isinstance(other, str): + return False + return bool(self.pattern.match(other)) + + def mock_lm_computation( metric: Metric, tokenizer: transformers.AutoTokenizer, @@ -158,8 +181,45 @@ def mock_mc_computation( metric.compute() +@pytest.mark.parametrize('is_hf_model', [True, False]) +@pytest.mark.parametrize('has_tokenizer', [True, False]) +@pytest.mark.parametrize('log_output_text', [True, False, None]) +def test_init( + is_hf_model: bool, + has_tokenizer: bool, + log_output_text: bool, +): + state = MockState() + in_memory_logger = InMemoryLogger() + logger = Logger(state, in_memory_logger) + + expected_error = log_output_text is True and not ( + is_hf_model and has_tokenizer + ) + exptected_log_output_text = ( + log_output_text is not False and is_hf_model and has_tokenizer + ) + + eval_output_logging = EvalOutputLogging( + loggers_to_use=['InMemoryLogger'], + log_output_text=log_output_text, + ) + + state = mock.Mock(model=MockHFModel() if is_hf_model else mock.Mock()) + state.dataloader.dataset = mock.Mock( + spec=['tokenizer'] if has_tokenizer else [], + ) + with pytest.raises( + ValueError, + ) if expected_error else contextlib.nullcontext(): + eval_output_logging.init(state, logger) + assert eval_output_logging.log_output_text == exptected_log_output_text + + +@pytest.mark.parametrize('log_output_text', [True, False]) def test_eval_output_logging_lm( tiny_gpt2_tokenizer: transformers.AutoTokenizer, + log_output_text: bool, ): # this test simulates an unrolled version of the eval loop occurring twice state = MockState() @@ -170,7 +230,11 @@ def test_eval_output_logging_lm( state.add_metric('lm_acc', lm_metric) # Construct the callback - eval_output_logging = EvalOutputLogging(loggers_to_use=['InMemoryLogger']) + eval_output_logging = EvalOutputLogging( + loggers_to_use=['InMemoryLogger'], + log_output_text=log_output_text, + ) + eval_output_logging.init(mock.Mock(model=MockHFModel()), logger) for _ in range(2): state.update_curr_eval( @@ -193,23 +257,28 @@ def test_eval_output_logging_lm( assert f'lm_acc_step_0' in in_memory_logger.tables # Only want one table - we log once to a single step value during eval_end() assert len(in_memory_logger.tables) == 1 - assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['columns'] == [ + logged_data = json.loads(in_memory_logger.tables[f'lm_acc_step_0']) + assert logged_data['columns'] == [ 'context', 'label', 'output', 'result', 'metric_name', + *(['outputs'] if log_output_text else []), 'input', 'run_name', ] + # We use the same data in each batch - assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['data'] == [ + assert logged_data['data'] == [ [ 'The dog is', ' furry', ' furry', 1, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' dog is furry(\[PAD\])+I'),) + if log_output_text else []), 'The dog is furry', 'mock_name', ], @@ -219,6 +288,8 @@ def test_eval_output_logging_lm( '[PAD]', 0, 
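The `RegexMatcher` helper defined near the top of this test module compares equal to any string that matches its pattern, which is what lets the expected `data` rows in these assertions check decoded model output without hard-coding the exact amount of `[PAD]` padding. A small self-contained illustration (the class body reproduces the helper from the diff; the sample strings are illustrative):

import re

class RegexMatcher:

    def __init__(self, pattern: str):
        self.pattern = re.compile(pattern)

    def __eq__(self, other: str):
        if not isinstance(other, str):
            return False
        return bool(self.pattern.match(other))

# Equality holds for any string the pattern matches from its start...
assert RegexMatcher(r' dog is furry(\[PAD\])+I') == ' dog is furry[PAD][PAD]I'
# ...so a matcher can stand in for one column of an expected row when the
# whole row is compared with ==.
assert [1, RegexMatcher(r' dog is furry(\[PAD\])+I')] == [1, ' dog is furry[PAD]I']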
'InContextLearningLMAccuracy', + *((RegexMatcher(r' love to eat(\[PAD\])+I'),) + if log_output_text else []), 'I love to eat pie', 'mock_name', ], @@ -228,6 +299,8 @@ def test_eval_output_logging_lm( ' long lines', 1, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' hate long lines(\[PAD\])+The'),) + if log_output_text else []), 'I hate long lines', 'mock_name', ], @@ -237,6 +310,8 @@ def test_eval_output_logging_lm( ' snowy', 1, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' weather is snowy(\[PAD\])+The'),) + if log_output_text else []), 'The weather is snowy', 'mock_name', ], @@ -246,6 +321,8 @@ def test_eval_output_logging_lm( ' furry', 1, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' dog is furry(\[PAD\])+I'),) + if log_output_text else []), 'The dog is furry', 'mock_name', ], @@ -255,6 +332,8 @@ def test_eval_output_logging_lm( '[PAD]', 0, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' love to eat(\[PAD\])+I'),) + if log_output_text else []), 'I love to eat pie', 'mock_name', ], @@ -264,6 +343,8 @@ def test_eval_output_logging_lm( ' long lines', 1, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' hate long lines(\[PAD\])+The'),) + if log_output_text else []), 'I hate long lines', 'mock_name', ], @@ -273,6 +354,8 @@ def test_eval_output_logging_lm( ' snowy', 1, 'InContextLearningLMAccuracy', + *((RegexMatcher(r' weather is snowy(\[PAD\])+The'),) + if log_output_text else []), 'The weather is snowy', 'mock_name', ], @@ -291,7 +374,11 @@ def test_eval_output_logging_mc( state.add_metric('mc_acc', mc_metric) # Construct the callback - eval_output_logging = EvalOutputLogging(loggers_to_use=['InMemoryLogger']) + eval_output_logging = EvalOutputLogging( + loggers_to_use=['InMemoryLogger'], + log_output_text=True, + ) + eval_output_logging.init(mock.Mock(model=MockHFModel()), logger) for _ in range(2): state.update_curr_eval( MockDataLoader(tiny_gpt2_tokenizer), @@ -314,7 +401,8 @@ def test_eval_output_logging_mc( assert f'mc_acc_step_0' in in_memory_logger.tables # Only want one table - we log once to a single step value during eval_end() assert len(in_memory_logger.tables) == 1 - assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['columns'] == [ + logged_data = json.loads(in_memory_logger.tables[f'mc_acc_step_0']) + assert logged_data['columns'] == [ 'context', 'correct_choice', 'correct_choice_idx', @@ -323,11 +411,12 @@ def test_eval_output_logging_mc( 'all_choices', 'result', 'metric_name', + 'outputs', 'input', 'run_name', ] # We use the same data for each batch - assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['data'] == [ + assert logged_data['data'] == [ [ 'Q: How do you cook a cake?', ' A: turn on the oven', @@ -340,6 +429,9 @@ def test_eval_output_logging_mc( ], 1, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: turn on the oven', 'mock_name', ], @@ -355,6 +447,9 @@ def test_eval_output_logging_mc( ], 0, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: do a backflip', 'mock_name', ], @@ -370,6 +465,9 @@ def test_eval_output_logging_mc( ], 1, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? 
A: turn on the oven', 'mock_name', ], @@ -385,6 +483,9 @@ def test_eval_output_logging_mc( ], 0, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: do a backflip', 'mock_name', ], diff --git a/tests/callbacks/test_loss_perp_v_len_callback.py b/tests/callbacks/test_loss_perp_v_len_callback.py index 46bde1c2f1..4c487560d2 100644 --- a/tests/callbacks/test_loss_perp_v_len_callback.py +++ b/tests/callbacks/test_loss_perp_v_len_callback.py @@ -1,5 +1,6 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from typing import Any from unittest.mock import MagicMock import pytest @@ -14,6 +15,7 @@ from omegaconf import OmegaConf as om from llmfoundry import registry +from llmfoundry.callbacks.loss_perp_v_len_callback import LossPerpVLen from llmfoundry.data.text_data import ( StreamingTextDataset, build_text_dataloader, @@ -172,3 +174,259 @@ def test_loss_perp_v_len_callback( ) / torch.sum(current_metric_dict['sum_length']) assert torch.allclose(loss, mean_loss_seq_id) assert torch.allclose(loss, mean_loss) + + +def test_metric(): + batch_size = 2 + seq_len = 100 + labels = torch.tensor([[ + 1, + ] * seq_len] * batch_size) + logits = torch.tensor([[ + 1, + ] * seq_len] * batch_size) + sequence_id = torch.tensor([[ + 0, + ] * 10 + [ + 1, + ] * 90, [ + 0, + ] * 50 + [ + 1, + ] * 50]) + loss = torch.rand([batch_size, seq_len]) + perplexity = torch.exp(loss) + + def mock_loss_fn(input_logits: Any, input_labels: Any): + del input_logits, input_labels + return loss + + loss_v_len_metric = LossPerpVLen(ignore_index=-100) + loss_v_len_metric.update( + labels=labels, + logits=logits, + sequence_id=sequence_id, + loss_fn=mock_loss_fn, + ) + metric_dict = loss_v_len_metric.compute() + + assert torch.all(metric_dict['sum_length'] == 2 * torch.ones([100])) + assert torch.all( + metric_dict['sum_length_seq_id'] == torch.tensor([ + 4, + ] * 10 + [ + 3, + ] * 40 + [ + 1, + ] * 40 + [ + 0, + ] * 10), + ) + assert torch.all(metric_dict['mean_loss_v_len'] == torch.mean(loss, dim=0)) + assert torch.all( + metric_dict['mean_perplexity_v_len'] == torch.mean(perplexity, dim=0), + ) + + expected_mean_loss_seq_id_v_len_0 = ( + loss[0][:10] + loss[0][10:20] + loss[1][0:10] + loss[1][50:60] + ) / 4 + expected_mean_loss_seq_id_v_len_1 = ( + loss[0][20:60] + loss[1][10:50] + loss[1][60:100] + ) / 3 + expected_mean_loss_seq_id_v_len_2 = loss[0][60:100] + expected_mean_loss_seq_id_v_len_3 = -1 + + assert torch.all( + metric_dict['mean_loss_seq_id_v_len'][0:10] == + expected_mean_loss_seq_id_v_len_0, + ) + assert torch.all( + metric_dict['mean_loss_seq_id_v_len'][10:50] == + expected_mean_loss_seq_id_v_len_1, + ) + assert torch.all( + metric_dict['mean_loss_seq_id_v_len'][50:90] == + expected_mean_loss_seq_id_v_len_2, + ) + assert torch.all( + metric_dict['mean_loss_seq_id_v_len'][90:100] == + expected_mean_loss_seq_id_v_len_3, + ) + + expected_mean_perplexity_seq_id_v_len_0 = ( + perplexity[0][:10] + perplexity[0][10:20] + perplexity[1][0:10] + + perplexity[1][50:60] + ) / 4 + expected_mean_perplexity_seq_id_v_len_1 = ( + perplexity[0][20:60] + perplexity[1][10:50] + perplexity[1][60:100] + ) / 3 + expected_mean_perplexity_seq_id_v_len_2 = perplexity[0][60:100] + expected_mean_perplexity_seq_id_v_len_3 = -1 + + assert torch.all( + metric_dict['mean_perplexity_seq_id_v_len'][0:10] == + expected_mean_perplexity_seq_id_v_len_0, + ) + assert torch.all( + 
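The expected `sum_length_seq_id` tensor in `test_metric` above, `[4]*10 + [3]*40 + [1]*40 + [0]*10`, appears to count, for each within-sequence position, how many individual sequences in the batch are long enough to reach that position. A quick sanity check of that arithmetic under the sequence lengths encoded by `sequence_id` (sample 0 contributes sequences of length 10 and 90, sample 1 contributes two sequences of length 50):

import torch

seq_lengths = [10, 90, 50, 50]
counts = torch.zeros(100, dtype=torch.long)
for n in seq_lengths:
    counts[:n] += 1  # a sequence of length n covers positions 0..n-1

expected = torch.tensor([4] * 10 + [3] * 40 + [1] * 40 + [0] * 10)
assert torch.equal(counts, expected)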
metric_dict['mean_perplexity_seq_id_v_len'][10:50] == + expected_mean_perplexity_seq_id_v_len_1, + ) + assert torch.all( + metric_dict['mean_perplexity_seq_id_v_len'][50:90] == + expected_mean_perplexity_seq_id_v_len_2, + ) + assert torch.all( + metric_dict['mean_perplexity_seq_id_v_len'][90:100] == + expected_mean_perplexity_seq_id_v_len_3, + ) + + +def test_valid_labels(): + batch_size = 1 + seq_len = 100 + ignore_labels_len = 10 + labels = torch.tensor([[ + 1, + ] * (seq_len - ignore_labels_len) + [ + -100, + ] * ignore_labels_len] * batch_size) + logits = torch.tensor([[ + 1, + ] * seq_len] * batch_size) + sequence_id = torch.tensor([[ + 0, + ] * seq_len]) + loss = torch.rand([batch_size, seq_len]) + + def mock_loss_fn(input_logits: Any, input_labels: Any): + del input_logits, input_labels + return loss + + loss_v_len_metric = LossPerpVLen(ignore_index=-100) + loss_v_len_metric.update( + labels=labels, + logits=logits, + sequence_id=sequence_id, + loss_fn=mock_loss_fn, + ) + metric_dict = loss_v_len_metric.compute() + assert torch.all(metric_dict['sum_length'][-ignore_labels_len:] == 0) + assert torch.all(metric_dict['sum_length_seq_id'][-ignore_labels_len:] == 0) + assert torch.all(metric_dict['mean_loss_v_len'][-ignore_labels_len:] == -1) + assert torch.all( + metric_dict['mean_perplexity_v_len'][-ignore_labels_len:] == -1, + ) + assert torch.all( + metric_dict['mean_loss_seq_id_v_len'][-ignore_labels_len:] == -1, + ) + assert torch.all( + metric_dict['mean_perplexity_seq_id_v_len'][-ignore_labels_len:] == -1, + ) + + +def test_padding(): + batch_size = 2 + seq_len = 100 + + labels_no_pad = torch.tensor([[ + 1, + ] * seq_len] * batch_size) + logits_no_pad = torch.tensor([[ + 1, + ] * seq_len] * batch_size) + sequence_id_no_pad = torch.tensor([[ + 0, + ] * 10 + [ + 1, + ] * 90, [ + 0, + ] * 50 + [ + 1, + ] * 50]) + loss_no_pad = torch.rand([batch_size, seq_len]) + + def mock_loss_fn_no_pad(input_logits: Any, input_labels: Any): + del input_logits, input_labels + return loss_no_pad + + loss_v_len_metric_no_pad = LossPerpVLen(ignore_index=-100) + loss_v_len_metric_no_pad.update( + labels=labels_no_pad, + logits=logits_no_pad, + sequence_id=sequence_id_no_pad, + loss_fn=mock_loss_fn_no_pad, + ) + metric_dict_no_pad = loss_v_len_metric_no_pad.compute() + + pad_len = 10 + labels_pad = torch.tensor([[ + 1, + ] * seq_len + [ + -100, + ] * pad_len] * batch_size) + logits_pad = torch.tensor([[ + 1, + ] * (seq_len + pad_len)] * batch_size) + sequence_id_pad = torch.tensor([[ + 0, + ] * 10 + [ + 1, + ] * 90 + [ + -1, + ] * pad_len, [ + 0, + ] * 50 + [ + 1, + ] * 50 + [ + -1, + ] * pad_len]) + loss_pad = torch.cat([loss_no_pad, + torch.rand([batch_size, pad_len])], + dim=-1) + + def mock_loss_fn_pad(input_logits: Any, input_labels: Any): + del input_logits, input_labels + return loss_pad + + loss_v_len_metric_pad = LossPerpVLen(ignore_index=-100) + loss_v_len_metric_pad.update( + labels=labels_pad, + logits=logits_pad, + sequence_id=sequence_id_pad, + loss_fn=mock_loss_fn_pad, + ) + metric_dict_pad = loss_v_len_metric_pad.compute() + + assert torch.all(metric_dict_pad['sum_length'][-pad_len:] == 0) + assert torch.all(metric_dict_pad['sum_length_seq_id'][-pad_len:] == 0) + assert torch.all(metric_dict_pad['mean_loss_v_len'][-pad_len:] == -1) + assert torch.all(metric_dict_pad['mean_perplexity_v_len'][-pad_len:] == -1) + assert torch.all(metric_dict_pad['mean_loss_seq_id_v_len'][-pad_len:] == -1) + assert torch.all( + metric_dict_pad['mean_perplexity_seq_id_v_len'][-pad_len:] == -1, + ) + + assert 
torch.all( + metric_dict_pad['sum_length'][:-pad_len] == + metric_dict_no_pad['sum_length'], + ) + assert torch.all( + metric_dict_pad['sum_length_seq_id'][:-pad_len] == + metric_dict_no_pad['sum_length_seq_id'], + ) + assert torch.all( + metric_dict_pad['mean_loss_v_len'][:-pad_len] == + metric_dict_no_pad['mean_loss_v_len'], + ) + assert torch.all( + metric_dict_pad['mean_perplexity_v_len'][:-pad_len] == + metric_dict_no_pad['mean_perplexity_v_len'], + ) + assert torch.all( + metric_dict_pad['mean_loss_seq_id_v_len'][:-pad_len] == + metric_dict_no_pad['mean_loss_seq_id_v_len'], + ) + assert torch.all( + metric_dict_pad['mean_perplexity_seq_id_v_len'][:-pad_len] == + metric_dict_no_pad['mean_perplexity_seq_id_v_len'], + ) diff --git a/tests/callbacks/test_system_metrics_monitor.py b/tests/callbacks/test_system_metrics_monitor.py new file mode 100644 index 0000000000..47095604eb --- /dev/null +++ b/tests/callbacks/test_system_metrics_monitor.py @@ -0,0 +1,15 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from composer.callbacks import SystemMetricsMonitor + +from llmfoundry.utils.builders import build_callback + + +def test_system_metrics_monitor_callback_builds(): + callback = build_callback( + 'system_metrics_monitor', + kwargs={}, + train_config={'train_loader': {}}, + ) + assert isinstance(callback, SystemMetricsMonitor) diff --git a/tests/data/test_data_encodings.py b/tests/data/test_data_encodings.py new file mode 100644 index 0000000000..a45bfbcb88 --- /dev/null +++ b/tests/data/test_data_encodings.py @@ -0,0 +1,205 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import pathlib + +import numpy as np +import pytest +import torch +from streaming import MDSWriter + +from llmfoundry.data import SUPPORTED_MDS_ENCODING_TYPES, StreamingTextDataset +from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset + + +@pytest.mark.parametrize( + 'token_encoding_type', + SUPPORTED_MDS_ENCODING_TYPES + ['default'], +) +@pytest.mark.parametrize('use_bytes', [True, False]) +@pytest.mark.parametrize('samples', [10]) +@pytest.mark.parametrize('max_seq_len', [2048]) +def test_encoding_types_text( + tmp_path: pathlib.Path, + token_encoding_type: str, + use_bytes: bool, + samples: int, + max_seq_len: int, +): + dataset_local_path = str(tmp_path) + if token_encoding_type != 'default': + encoding_dtype = getattr(np, token_encoding_type) + else: + encoding_dtype = None + + if use_bytes: + columns = { + 'tokens': 'bytes', + } + else: + columns = { + 'tokens': + 'ndarray:' + token_encoding_type + if token_encoding_type != 'default' else 'ndarray', + } + + with MDSWriter(out=dataset_local_path, columns=columns) as writer: + for _ in range(samples): + if token_encoding_type != 'default': + tokens = np.random.randint( + 0, + np.iinfo(encoding_dtype).max, + max_seq_len, + dtype=encoding_dtype, + ) + else: + tokens = np.random.randint( + 0, + 200, + max_seq_len, + ) + if use_bytes: + tokens = tokens.tobytes() + writer.write({'tokens': tokens}) + + if use_bytes and token_encoding_type != 'default': + dataset = StreamingTextDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + max_seq_len=max_seq_len, + local=dataset_local_path, + batch_size=1, + ) + else: + # There should be no need to pass in the token encoding type if writing out ndarrays, + # or if using the default token encoding type. 
+ dataset = StreamingTextDataset( + tokenizer=None, + max_seq_len=max_seq_len, + local=dataset_local_path, + batch_size=1, + ) + + for _, sample in enumerate(dataset): + # StreamingTextDataset should return an int64 torch Tensor + assert sample.dtype == torch.int64 + assert sample.shape == (max_seq_len,) + + +@pytest.mark.parametrize( + 'token_encoding_type', + SUPPORTED_MDS_ENCODING_TYPES + ['default'], +) +@pytest.mark.parametrize('use_bytes', [True, False]) +@pytest.mark.parametrize('samples', [10]) +@pytest.mark.parametrize('max_seq_len', [2048]) +def test_encoding_types_finetuning( + tmp_path: pathlib.Path, + token_encoding_type: str, + use_bytes: bool, + samples: int, + max_seq_len: int, +): + dataset_local_path = str(tmp_path) + if token_encoding_type != 'default': + encoding_dtype = getattr(np, token_encoding_type) + else: + encoding_dtype = None + + if use_bytes: + columns = { + 'input_ids': 'bytes', + 'labels': 'bytes', + } + else: + columns = { + 'input_ids': + 'ndarray:' + token_encoding_type + if token_encoding_type != 'default' else 'ndarray', + 'labels': + 'ndarray:' + token_encoding_type + if token_encoding_type != 'default' else 'ndarray', + } + + with MDSWriter(out=dataset_local_path, columns=columns) as writer: + for _ in range(samples): + if token_encoding_type != 'default': + input_ids = np.random.randint( + 0, + np.iinfo(encoding_dtype).max, + max_seq_len, + dtype=encoding_dtype, + ) + labels = np.random.randint( + 0, + np.iinfo(encoding_dtype).max, + max_seq_len, + dtype=encoding_dtype, + ) + else: + input_ids = np.random.randint( + 0, + 200, + max_seq_len, + ) + labels = np.random.randint( + 0, + 200, + max_seq_len, + ) + if use_bytes: + input_ids = input_ids.tobytes() + labels = labels.tobytes() + writer.write({'input_ids': input_ids, 'labels': labels}) + + if use_bytes and token_encoding_type != 'default': + dataset = StreamingFinetuningDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + local=dataset_local_path, + max_seq_len=max_seq_len, + batch_size=1, + ) + else: + # There should be no need to pass in the token encoding type if writing out ndarrays, + # or if using the default token encoding type. + dataset = StreamingFinetuningDataset( + tokenizer=None, + local=dataset_local_path, + max_seq_len=max_seq_len, + batch_size=1, + ) + + for _, sample in enumerate(dataset): + # StreamingFinetuningDataset puts samples in a list, and converts arrays to lists too. 
+ assert isinstance(sample['turns'][0]['input_ids'][0], int) + assert len(sample['turns'][0]['input_ids']) == max_seq_len + assert isinstance(sample['turns'][0]['labels'][0], int) + assert len(sample['turns'][0]['labels']) == max_seq_len + + +@pytest.mark.parametrize( + 'token_encoding_type', + ['int17', 'float32', 'complex', 'int4'], +) +@pytest.mark.parametrize('use_finetuning', [True, False]) +def test_unsupported_encoding_type( + token_encoding_type: str, + use_finetuning: bool, +): + with pytest.raises(ValueError, match='The token_encoding_type*'): + if use_finetuning: + StreamingFinetuningDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + local='dataset/path', + max_seq_len=2048, + batch_size=1, + ) + else: + StreamingTextDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + max_seq_len=2048, + local='dataset/path', + batch_size=1, + ) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index f567aeb3ba..a9ec0d8a62 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -5,10 +5,9 @@ import pathlib import random import shutil -from argparse import Namespace from contextlib import nullcontext as does_not_raise from pathlib import Path -from typing import ContextManager, Literal, Optional, Union +from typing import Any, Callable, ContextManager, Dict, Literal, Optional, Union from unittest.mock import MagicMock, patch import catalogue @@ -22,6 +21,9 @@ from streaming import MDSWriter from streaming.base.util import clean_stale_shared_memory +from llmfoundry.command_utils import convert_dataset_hf +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \ + get_columns_and_format from llmfoundry.data import build_dataloader, build_finetuning_dataloader from llmfoundry.data.finetuning.collator import ( _HF_IGNORE_INDEX, @@ -29,6 +31,7 @@ ) from llmfoundry.data.finetuning.tasks import ( DOWNLOADED_FT_DATASETS_DIRPATH, + HUGGINGFACE_FOLDER_EXTENSIONS, SUPPORTED_EXTENSIONS, dataset_constructor, is_valid_ift_example, @@ -54,9 +57,6 @@ NotEnoughDatasetSamplesError, UnknownExampleTypeError, ) -# yapf: enable -from scripts.data_prep.convert_dataset_hf import main as main_hf -from scripts.data_prep.convert_finetuning_dataset import get_columns_and_format from tests.data_utils import ( make_tiny_conversation_ft_dataset, make_tiny_ft_dataset, @@ -114,8 +114,8 @@ def build_mock_ft_streaming_dataset( columns = {'input_ids': 'bytes', 'labels': 'bytes'} else: columns = { - 'input_ids': 'ndarray:uint32', - 'labels': 'ndarray:uint32', + 'input_ids': 'ndarray:int32', + 'labels': 'ndarray:int32', } else: columns = {'prompt': 'str', 'response': 'str'} @@ -142,7 +142,7 @@ def build_mock_ft_streaming_dataset( else: sample_to_write[key] = np.asarray( sample[key], - dtype=np.uint32, + dtype=np.int32, ) output_writer.write(sample_to_write) else: @@ -203,42 +203,34 @@ def test_correct_padding( path = get_abs_data_path(data_local) shutil.rmtree(path, ignore_errors=True) if pretokenize: - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [split], - 'out_root': path, - 'compression': None, - 'concat_tokens': 2048, - 'tokenizer': tokenizer_name, - 'tokenizer_kwargs': {}, - 'bos_text': bos_text, - 'eos_text': eos_text, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=[split], + out_root=path, + compression=None, + concat_tokens=2048, + tokenizer=tokenizer_name, + tokenizer_kwargs={}, + bos_text=bos_text, + 
eos_text=eos_text, + no_wrap=False, + num_workers=None, ) else: - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [split], - 'out_root': path, - 'compression': None, - 'concat_tokens': None, - 'tokenizer': tokenizer_name, - 'tokenizer_kwargs': {}, - 'bos_text': bos_text, - 'eos_text': eos_text, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=[split], + out_root=path, + compression=None, + concat_tokens=None, + tokenizer=tokenizer_name, + tokenizer_kwargs={}, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=False, + num_workers=None, ) if not os.path.isdir(path): raise RuntimeError(f'c4 dataset at {path} not set up as expected') @@ -471,14 +463,15 @@ def test_finetuning_dataloader_safe_load( ) # If no raised errors, we should expect downloaded files with only safe file types. - if expectation == does_not_raise(): + if isinstance(expectation, does_not_raise): download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name) downloaded_files = [ file for _, _, files in os.walk(download_dir) for file in files ] assert len(downloaded_files) > 0 assert all( - Path(file).suffix in SUPPORTED_EXTENSIONS + Path(file).suffix in SUPPORTED_EXTENSIONS + + HUGGINGFACE_FOLDER_EXTENSIONS or file == '.gitignore' for file in downloaded_files ) @@ -1228,6 +1221,21 @@ def test_token_counting_func_dataloader_setting( 'timeout': 0, } + def build_from_hf( + self, # type: ignore + dataset_name: str, + split: str, + safe_load: bool = False, + max_seq_len: int = 2048, + preprocessing_fn: Optional[Callable] = None, + tokenizer: transformers.PreTrainedTokenizerBase = None, + target_prompts: str = 'last', + target_responses: str = 'none', + decoder_only_format: bool = True, + hf_kwargs: Optional[Dict[str, Any]] = None, + ): + return [] + if dataloader_type == 'finetuning-hf': cfg = DictConfig({ 'dataset': { @@ -1243,8 +1251,7 @@ def test_token_counting_func_dataloader_setting( }) monkeypatch.setattr( 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', - lambda *args, - **kwargs: [], + build_from_hf, ) dl = build_finetuning_dataloader( tokenizer=gptt, diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index b910b8c5ff..d181dbde0b 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -14,6 +14,7 @@ from torch.utils.data import DataLoader from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.utils.builders import build_tokenizer @@ -206,6 +207,15 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): if batch_ix >= 3: break + assert isinstance(loader, DataLoader) + assert isinstance(loader.dataset, StreamingFinetuningDataset) + assert loader.dataset.packing_ratio is not None + assert isinstance(loader.batch_size, int) + assert loader.dataset.packing_ratio == int(loader.batch_size / 6) + + state_dict = loader.dataset.state_dict(num_samples=2, from_beginning=False) + assert state_dict['sample_in_epoch'] == 2 * loader.dataset.packing_ratio + @pytest.mark.parametrize('packing_ratio', ['auto', 2.0]) @patch( diff --git a/tests/data_utils.py b/tests/data_utils.py index 9653d8579a..ea64943735 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -4,16 +4,16 @@ import json import os import shutil -from argparse import Namespace from pathlib import Path from 
typing import Dict, List, Optional from omegaconf import DictConfig from omegaconf import OmegaConf as om -from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402 -from scripts.data_prep.convert_dataset_json import \ - main as main_json # noqa: E402 +from llmfoundry.command_utils import ( + convert_dataset_hf, + convert_dataset_json, +) def make_tiny_ft_dataset( @@ -230,23 +230,19 @@ def create_c4_dataset_xxsmall(path: Path) -> str: downloaded_split = 'val_xxsmall' # very fast to convert # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [downloaded_split], - 'out_root': c4_dir, - 'compression': None, - 'concat_tokens': 2048, - 'tokenizer': 'EleutherAI/gpt-neox-20b', - 'tokenizer_kwargs': {}, - 'bos_text': '', - 'eos_text': '<|endoftext|>', - 'no_wrap': False, - 'num_workers': 8, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=[downloaded_split], + out_root=c4_dir, + compression=None, + concat_tokens=2048, + tokenizer='EleutherAI/gpt-neox-20b', + tokenizer_kwargs={}, + bos_text='', + eos_text='<|endoftext|>', + no_wrap=False, + num_workers=8, ) # copy the small downloaded_split to other c4 splits for mocking purposes @@ -269,20 +265,16 @@ def create_arxiv_dataset(path: Path) -> str: if not os.getcwd().endswith('scripts'): arxiv_path = os.path.join('scripts', arxiv_path) - main_json( - Namespace( - **{ - 'path': arxiv_path, - 'out_root': arxiv_dir, - 'compression': None, - 'split': downloaded_split, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_json( + path=arxiv_path, + out_root=arxiv_dir, + compression=None, + split=downloaded_split, + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, ) return arxiv_dir diff --git a/tests/eval/test_in_context_learning_datasets.py b/tests/eval/test_in_context_learning_datasets.py index a3c3e88364..81769a18e6 100644 --- a/tests/eval/test_in_context_learning_datasets.py +++ b/tests/eval/test_in_context_learning_datasets.py @@ -37,6 +37,7 @@ InContextLearningLMAccuracy, InContextLearningMultipleChoiceAccuracy, ) +from llmfoundry.utils.builders import build_icl_evaluators def test_strip_data(): @@ -1090,15 +1091,22 @@ def test_mc_task_dataloader_subcategories( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=2, - prompt_string= - 'The following are multiple choice questions (with answers).\n', - example_delimiter='\n', - continuation_delimiter='Answer: ', - destination_path=str(tmp_path / 'icl.jsonl'), has_categories=True, + destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'num_fewshot': + 2, + 'max_seq_len': + seqlen, + 'pad_tok_id': + tokenizer.eos_token_id, + 'prompt_string': + 'The following are multiple choice questions (with answers).\n', + 'example_delimiter': + '\n', + 'continuation_delimiter': + 'Answer: ', + }, ) assert isinstance(dls, dict) @@ -1142,13 +1150,15 @@ def test_lm_task_dataloader_extra_space( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=10, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': 
tokenizer.eos_token_id, + 'num_fewshot': 10, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1192,13 +1202,15 @@ def test_lm_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1241,14 +1253,16 @@ def test_schema_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - question_prelimiter=prelimiter, - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'prelimiter': prelimiter, + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) @@ -1300,13 +1314,15 @@ def test_schema_task_dataloader_sentpiece_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) @@ -1358,13 +1374,15 @@ def test_lm_task_dataloader_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1410,13 +1428,15 @@ def test_mc_task_dataloader_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1473,13 +1493,15 @@ def test_mc_split_batch( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - 
prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1550,13 +1572,15 @@ def test_qa_split_batch( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=8, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) # pyright @@ -1612,14 +1636,16 @@ def test_qa_task_dataloader_w_null_eos( dataset_uri, tokenizer, batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': '\nA:', + }, ) @@ -1647,14 +1673,16 @@ def test_qa_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': '\nA:', + }, ) assert isinstance(dl, DataSpec) @@ -1714,15 +1742,17 @@ def test_qa_task_with_cot_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter="\nA: Let's think step by step. ", - cot_delimiter=' #### ', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': "\nA: Let's think step by step. 
", + 'cot_delimiter': ' #### ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1779,14 +1809,16 @@ def test_mc_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - question_prelimiter=prelimiter, - example_delimiter=example_delimiter, - continuation_delimiter='\nA: ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'prelimiter': prelimiter, + 'example_delimiter': example_delimiter, + 'continuation_delimiter': '\nA: ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1851,13 +1883,15 @@ def test_lm_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) evaluator = Evaluator( @@ -1903,13 +1937,15 @@ def test_schema_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -1968,14 +2004,16 @@ def test_mc_task_evaluation_subcategories( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), has_categories=True, + kwargs={ + 'max_seq_len': max_seq_len, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dls, dict) @@ -2039,13 +2077,15 @@ def test_mc_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=64, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 64, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2107,13 +2147,15 @@ def test_qa_task_evaluation_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 
num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2168,14 +2210,16 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter="A: Let's think step by step. ", - cot_delimiter=' #### ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': "A: Let's think step by step. ", + 'cot_delimiter': ' #### ', + }, ) evaluator = Evaluator( @@ -2228,13 +2272,15 @@ def test_qa_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2288,14 +2334,16 @@ def test_qa_task_with_cot_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter="A: Let's think step by step", - cot_delimiter=' #### ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': "A: Let's think step by step", + 'cot_delimiter': ' #### ', + }, ) evaluator = Evaluator( @@ -2339,13 +2387,15 @@ def test_lm_spacing_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' UNIQUE ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' UNIQUE ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -2409,15 +2459,17 @@ def test_hf_dataloading_lm_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -2490,16 +2542,18 @@ def test_hf_dataloading_custom_parsing( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - 
pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Orbs: ', - continuation_delimiter='\nSpell:', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Orbs: ', + 'continuation_delimiter': '\nSpell:', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -2535,3 +2589,42 @@ def test_hf_dataloading_custom_parsing( ) assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:') assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:') + + +@pytest.mark.parametrize( + 'prelimiter_key_name', + ['prelimiter', 'question_prelimiter'], +) +def test_bc_question_prelimiter( + mpt_tokenizer: transformers.PreTrainedTokenizerBase, + prelimiter_key_name: str, +): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + dataset_uri = f'{local_data}/piqa_small.jsonl' + + icl_tasks = [ + { + 'dataset_uri': dataset_uri, + 'label': 'piqa', + 'icl_task_type': 'multiple_choice', + 'max_seq_len': 64, + 'pad_tok_id': mpt_tokenizer.eos_token_id, + 'num_fewshot': [0], + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + prelimiter_key_name: 'This is a question: ', + }, + ] + + evaluators, _ = build_icl_evaluators( + icl_tasks=icl_tasks, + tokenizer=mpt_tokenizer, + default_batch_size=2, + default_max_seq_len=128, + ) + + assert len(evaluators) == 1 + evaluator = evaluators[0] + assert evaluator.dataloader.dataloader.dataset.prelimiter == 'This is a question: ' # type: ignore diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py index ff437974bf..2c34dff817 100644 --- a/tests/fixtures/data.py +++ b/tests/fixtures/data.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path +from typing import Any from unittest.mock import MagicMock, patch from composer.utils import dist @@ -26,14 +27,11 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path: @fixture -@patch('os.cpu_count', MagicMock(return_value=1)) -def tiny_ft_dataloader( +def tiny_ft_dataloader_cfg( tiny_ft_dataset_path: Path, - mpt_tokenizer: PreTrainedTokenizerBase, max_seq_len: int = 128, - device_batch_size: int = 1, -) -> DataLoader: - dataloader_cfg = DictConfig({ +) -> dict[str, Any]: + return { 'dataset': { 'hf_name': str(tiny_ft_dataset_path), 'split': 'train', @@ -49,7 +47,17 @@ def tiny_ft_dataloader( 'prefetch_factor': 2, 'persistent_workers': False, 'timeout': 0, - }) + } + + +@fixture +@patch('os.cpu_count', MagicMock(return_value=1)) +def tiny_ft_dataloader( + mpt_tokenizer: PreTrainedTokenizerBase, + tiny_ft_dataloader_cfg: dict[str, Any], + device_batch_size: int = 1, +) -> DataLoader: + dataloader_cfg = DictConfig(tiny_ft_dataloader_cfg) dataloader = build_finetuning_dataloader( **dataloader_cfg, diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py index 1ca384171d..d0ec544de8 100644 --- a/tests/models/hf/test_hf_config.py +++ b/tests/models/hf/test_hf_config.py @@ -172,7 +172,7 @@ def test_hf_config_override( @pytest.mark.skipif( - 'HUGGING_FACE_HUB_TOKEN' not in os.environ, + 'HF_TOKEN' not in os.environ, reason='CI does not have access to llama2', ) def test_rope_scaling_override(): @@ -205,7 +205,7 @@ def 
test_rope_scaling_override(): @pytest.mark.skipif( - 'HUGGING_FACE_HUB_TOKEN' not in os.environ, + 'HF_TOKEN' not in os.environ, reason='CI does not have access to Dbrx', ) def test_nested_override(): diff --git a/tests/models/hf/test_hf_transform.py b/tests/models/hf/test_hf_transform.py new file mode 100644 index 0000000000..f479b50f73 --- /dev/null +++ b/tests/models/hf/test_hf_transform.py @@ -0,0 +1,76 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +import pytest +from composer.models.huggingface import maybe_get_underlying_model +from peft import PeftConfig, PeftModel +from transformers import LlamaForCausalLM, PreTrainedModel + +from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM +from llmfoundry.models.utils import init_empty_weights + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'peft_config', + [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'r': 2, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }, + ], +) +def test_hf_transform(peft_config: Optional[dict]): + model_cfg = { + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', + 'config_overrides': { + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + }, + 'pretrained': False, + 'peft_config': peft_config, + 'init_device': 'meta', + 'tokenizer': 'codellama/CodeLlama-7b-hf', + } + + class TransformedHFCausalLM(ComposerHFCausalLM): + + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + assert isinstance(model, LlamaForCausalLM) + with init_empty_weights(): + model.config.num_hidden_layers = 1 + new_model = type(model)(model.config) + return new_model + + def get_peft_config( + self, + peft_config_dict: Dict[str, Any], + ) -> PeftConfig: + peft_config_dict['target_modules'] = ['o_proj'] + return super().get_peft_config(peft_config_dict) + + composer_model = TransformedHFCausalLM(**model_cfg) + model = composer_model.model + inner_model = maybe_get_underlying_model(model) + + if peft_config: + peft_model = composer_model.model + assert isinstance(peft_model, PeftModel) + + target_modules = peft_model.peft_config[peft_model.active_adapter + ].target_modules + assert list(target_modules) == ['o_proj'] + + assert isinstance(inner_model, LlamaForCausalLM) + assert inner_model.config.num_hidden_layers == 1 diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py new file mode 100644 index 0000000000..bdffe2b49f --- /dev/null +++ b/tests/models/layers/test_attention.py @@ -0,0 +1,160 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from llmfoundry.models.layers.layer_builders import build_attention_layer + + +@pytest.mark.parametrize( + 'attn_name', + ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'], +) +@pytest.mark.parametrize('dim', [1024]) +def test_unfused_wqkv(attn_name: str, dim: int): + d_head = 128 + n_heads = dim // d_head + + generic_attn_kwargs = { + 'd_model': dim, + 'n_heads': n_heads, + 'fc_type': { + 'name': 'torch', + }, + 'device': 'cpu', + 'attn_pdrop': 0.0, + 'attn_impl': 'torch', + 'qk_ln': False, + 'qk_gn': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'sliding_window_size': -1, + } + + if attn_name == 'grouped_query_attention': + kv_n_heads = 2 + generic_attn_kwargs['kv_n_heads'] = kv_n_heads + elif attn_name == 'multiquery_attention': + kv_n_heads = 1 + elif 
attn_name == 'multihead_attention': + kv_n_heads = n_heads + else: + raise ValueError(f'Unknown attention name: {attn_name}') + + attn_config_fused = generic_attn_kwargs.copy() + attn_config_fused['fused_qkv'] = True + + attn_config_unfused = generic_attn_kwargs.copy() + attn_config_unfused['fused_qkv'] = False + + attn_fused = build_attention_layer( + name=attn_name, + attn_kwargs=attn_config_fused, + ) + attn_unfused = build_attention_layer( + name=attn_name, + attn_kwargs=attn_config_unfused, + ) + + # Make sure unfused attention has the same params as the fused one. + fused_wqkv = attn_fused.Wqkv.weight.detach().clone() + kv_heads_len = (fused_wqkv.shape[0] - dim) // 2 + Wq_shape_before = (attn_unfused.Wq.weight.shape, attn_unfused.Wq.bias.shape) + Wk_shape_before = (attn_unfused.Wk.weight.shape, attn_unfused.Wk.bias.shape) + Wv_shape_before = (attn_unfused.Wv.weight.shape, attn_unfused.Wv.bias.shape) + + attn_unfused.Wq.weight.data = fused_wqkv[:dim, :] + attn_unfused.Wk.weight.data = fused_wqkv[dim:dim + kv_heads_len, :] + attn_unfused.Wv.weight.data = fused_wqkv[dim + kv_heads_len:, :] + attn_unfused.out_proj.weight.data = attn_fused.out_proj.weight + attn_unfused.Wq.bias.data = attn_fused.Wqkv.bias[:dim] + attn_unfused.Wk.bias.data = attn_fused.Wqkv.bias[dim:dim + kv_heads_len] + attn_unfused.Wv.bias.data = attn_fused.Wqkv.bias[dim + kv_heads_len:] + attn_unfused.out_proj.bias.data = attn_fused.out_proj.bias + + # Make sure initialization fuse splits are as expected. + all_fuse_splits = ( + 0, + [i * d_head for i in range(1, n_heads + 2 * kv_n_heads)], + ) + q_fuse_splits = (0, [i * d_head for i in range(1, n_heads)]) + kv_fuse_splits = (0, [i * d_head for i in range(1, kv_n_heads)]) + + assert attn_fused.Wqkv._fused == all_fuse_splits + assert attn_unfused.Wq._fused == q_fuse_splits + assert attn_unfused.Wk._fused == kv_fuse_splits + assert attn_unfused.Wv._fused == kv_fuse_splits + + assert torch.allclose( + attn_fused.Wqkv.weight, + torch.cat( + [ + attn_unfused.Wq.weight, + attn_unfused.Wk.weight, + attn_unfused.Wv.weight, + ], + dim=0, + ), + ) + assert torch.allclose( + attn_fused.Wqkv.bias, + torch.cat( + [ + attn_unfused.Wq.bias, + attn_unfused.Wk.bias, + attn_unfused.Wv.bias, + ], + dim=0, + ), + ) + assert torch.allclose( + attn_fused.out_proj.weight, + attn_unfused.out_proj.weight, + ) + assert torch.allclose(attn_fused.out_proj.bias, attn_unfused.out_proj.bias) + + assert Wq_shape_before == ( + attn_unfused.Wq.weight.shape, + attn_unfused.Wq.bias.shape, + ) + assert Wk_shape_before == ( + attn_unfused.Wk.weight.shape, + attn_unfused.Wk.bias.shape, + ) + assert Wv_shape_before == ( + attn_unfused.Wv.weight.shape, + attn_unfused.Wv.bias.shape, + ) + + x1 = torch.randn(1, 1, dim) + x2 = x1.detach().clone() + x1.requires_grad = True + x2.requires_grad = True + + out_fused, _, _ = attn_fused(x1) + out_unfused, _, _ = attn_unfused(x2) + + assert torch.allclose(out_fused, out_unfused) + + # Dummy loss function is simply the sum. 
+ loss_fused = out_fused.sum() + loss_fused.backward() + + loss_unfused = out_unfused.sum() + loss_unfused.backward() + + assert isinstance(x1.grad, torch.Tensor) + assert isinstance(x2.grad, torch.Tensor) + assert torch.allclose(x1.grad, x2.grad) + combined_grad = torch.concat( + [ + attn_unfused.Wq.weight.grad, + attn_unfused.Wk.weight.grad, + attn_unfused.Wv.weight.grad, + ], + dim=0, + ) + assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor) + assert isinstance(combined_grad, torch.Tensor) + assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad) diff --git a/tests/models/layers/test_blocks.py b/tests/models/layers/test_blocks.py new file mode 100644 index 0000000000..fb6608152a --- /dev/null +++ b/tests/models/layers/test_blocks.py @@ -0,0 +1,69 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional +from unittest.mock import MagicMock + +import pytest +import torch + +from llmfoundry.models.layers import blocks +from llmfoundry.models.layers.blocks import MPTBlock + + +def test_default_attention_mask_slicing(): + attention_mask = torch.tensor([1, 1, 0, 1]).byte() + assert isinstance(attention_mask, torch.ByteTensor) + + block = MPTBlock( + d_model=4, + n_heads=1, + expansion_ratio=1, + ) + + output_mask = block.slice_attention_mask( + attention_mask=attention_mask, + seq_len=4, + ) + + assert torch.equal(output_mask, attention_mask) + + +def test_attention_mask_slicing_called(monkeypatch: pytest.MonkeyPatch): + m = torch.randn(2, 4, 4) + attention_mask = torch.tensor([1, 1, 1, 1]).byte() + dummy_return_mask = torch.tensor([1, 1, 1, 0]).byte() + assert isinstance(attention_mask, torch.ByteTensor) + assert isinstance(dummy_return_mask, torch.ByteTensor) + indices = torch.arange(4) + + unpad_mock = MagicMock(return_value=(m, indices, None, None)) + pad_mock = MagicMock(return_value=m) + monkeypatch.setattr(blocks, 'unpad_input', unpad_mock) + monkeypatch.setattr(blocks, 'pad_input', pad_mock) + + class MPTBlockTest(MPTBlock): + + def slice_attention_mask( + self, + attention_mask: Optional[torch.ByteTensor], + seq_len: int, + ) -> Optional[torch.ByteTensor]: + del seq_len + del attention_mask + return dummy_return_mask # type: ignore + + block = MPTBlockTest( + d_model=4, + n_heads=1, + expansion_ratio=1, + use_pad_tok_in_ffn=False, + ) + + block.apply_ffn( + attention_mask=attention_mask, + m=m, + ) + + assert unpad_mock.call_count == 1 + unpad_mock.assert_called_with(m, dummy_return_mask) diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py new file mode 100644 index 0000000000..bb78763f58 --- /dev/null +++ b/tests/models/layers/test_ffn.py @@ -0,0 +1,73 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +from llmfoundry.models.layers.ffn import quickgelu_activation +from llmfoundry.models.layers.layer_builders import build_ffn + + +@pytest.mark.gpu +def test_quickgelu_activation(): + d_model = 32 + expansion_ratio = 1 + no_bias = True + ffn_config = { + 'ffn_act_fn': { + 'name': 'quick_gelu', + }, + 'ffn_type': 'mptmlp', + } + rank: int = dist.get_rank() + device_str = f'cuda:{rank}' + device: torch.device = torch.device(device_str) + + ffn1 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + assert ( + ffn1.act == quickgelu_activation 
+ ), f'Expected quick_gelu activation function, got {ffn1.act}' + + ffn_config = { + 'ffn_act_fn': { + 'name': 'gelu', + }, + 'ffn_type': 'mptmlp', + } + ffn2 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + + def num_params(model: nn.Module) -> int: + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([p.numel() for p in model_parameters]) + + ffn1_numparams = num_params(ffn1) + ffn2_numparams = num_params(ffn2) + assert ( + ffn1_numparams == ffn2_numparams + ), 'Only activation paths should have changed, re-check modeling!' + + input_ = torch.rand(1, d_model, device=device) + output1 = ffn1(input_) + output2 = ffn2(input_) + assert ( + output1.numel() == output2.numel() + ), 'Only activation paths should have changed, re-check modeling!' + assert ( + not torch.allclose(output1, output2) + ), 'Functions are different, outputs should not match!' diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 669a6a93a1..4bfdfb84dc 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -251,12 +251,13 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens @@ -537,3 +538,213 @@ def test_grouped_query_invalid_heads(): with pytest.raises(ValueError, match=expected_error): _ = attention.GroupedQueryAttention(**cfg) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }], +) +def test_reuse_prev_layer_kv_cache( + pos_emb_config: dict, + device: str = 'cuda', +): + """Checks reusing previous layer's kv cache.""" + alibi = pos_emb_config['alibi'] + rope = pos_emb_config['rope'] + + cfg = { + 'attn_impl': 'flash', + 'd_model': 64, + 'n_heads': 4, + 'attn_pdrop': 0, + 'clip_qkv': True, + } + + n, s, f = 2, 4, cfg['d_model'] + assert cfg['d_model'] % cfg['n_heads'] == 0 + cfg['kv_n_heads'] = 2 + + sequence_id = torch.LongTensor([ + [0] * 2 + [1] * (s - 2), + [0] * 4 + [1] * (s - 4), + ]).to(device=device) + + # Computes its own kv cache + cfg['reuse_kv_layer_idx'] = None + attn0 = build_attention_layer( + name='grouped_query_attention', + attn_kwargs=cfg, # type: ignore + ).to(device) + + # Reuses layer 0's kv cache + cfg['reuse_kv_layer_idx'] = 0 + attn1 = build_attention_layer( + name='grouped_query_attention', + attn_kwargs=cfg, # type: ignore + ).to(device) + attn0_sd = attn0.state_dict() + attn0_sd['Wq.weight'] = attn0_sd['Wqkv.weight'][:cfg['d_model']] + attn0_sd['Wq.bias'] = attn0_sd['Wqkv.bias'][:cfg['d_model']] + del attn0_sd['Wqkv.weight'] + del attn0_sd['Wqkv.bias'] + attn1.load_state_dict(attn0_sd) + + attention_mask = torch.ones(n, s).to(device).bool() + + def gen_bias(attn_impl: str): + causal = True + attn_bias = None + bs = 
attention.attn_bias_shape( + attn_impl, + cfg['n_heads'], + s, + alibi, + use_sequence_id=True, + causal=causal, + ) + if bs is not None: + attn_bias = torch.zeros(*bs, device=device) + attn_bias = attention.build_attn_bias( + attn_impl, + attn_bias, + cfg['n_heads'], + s, + causal=causal, + alibi=alibi, + alibi_bias_max=8, + ) + + return attn_bias + + attention_mask_in_length = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=True, + attn_impl='flash', + attention_mask=attention_mask, + ) + + flash_attn_padding_info = gen_flash_attn_padding_info( + n, + s, + 0, + torch.device(device), + attention_mask_in_length, + attention_mask, + ) + + x0 = torch.randn(n, s, f).to(device) + x1 = x0.clone().detach() + x0.requires_grad = True + x1.requires_grad = True + + with torch.autocast(x0.device.type): + attn_bias_0 = gen_bias('flash') + alibi_slopes_0 = None + if alibi: + alibi_slopes_0 = gen_slopes( + n_heads=cfg['n_heads'], + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + rotary_emb_w_meta_info = None + if rope: + rotary_embedding = gen_rotary_embedding( + rope_impl=pos_emb_config['rope_impl'], + rope_theta=pos_emb_config['rope_theta'], + rope_dail_config=pos_emb_config.get('rope_dail_config', {}), + rope_hf_config=pos_emb_config.get('rope_hf_config', {}), + max_seq_len=s, + d_model=cfg['d_model'], + n_heads=cfg['n_heads'], + ).to(device) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + rotary_emb_w_meta_info = { + 'impl': + pos_emb_config['rope_impl'], + 'rotary_emb': + rotary_embedding, + 'offset_info': + pos if (pos_emb_config['rope_impl'] == 'hf') else 0, + 'seq_len': + s, + } + + y0, _, prev_layer_key_value = attn0( + x0, + past_key_value=(), + attn_bias=attn_bias_0, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + alibi_slopes=alibi_slopes_0, + ) + attn_bias_1 = gen_bias('flash') + alibi_slopes_1 = None + if alibi: + alibi_slopes_1 = gen_slopes( + n_heads=cfg['n_heads'], + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + + prev_layer_key_value = [ + t.clone().detach() for t in prev_layer_key_value + ] + y1, _, _ = attn1( + x1, + past_key_value=None, + attn_bias=attn_bias_1, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + alibi_slopes=alibi_slopes_1, + prev_layer_key_value=prev_layer_key_value, + ) + y0 *= attention_mask.unsqueeze(-1) + y1 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + loss1 = y1.sum() + + loss0.backward() + loss1.backward() + assert allclose_helper(y0, y1) + + torch_name_param_map = dict(attn1.named_parameters()) + for n, p in attn0.named_parameters(): + if 'Wq' in n: + tp = torch_name_param_map[n.replace('Wqkv', 'Wq')] + assert p.grad is not None + assert tp.grad is not None + assert allclose_helper(p[:cfg['d_model']], tp) + assert allclose_helper(p.grad[:cfg['d_model']], tp.grad) + else: + tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None + assert allclose_helper(p, tp) + assert allclose_helper(p.grad, tp.grad) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index a62a7dd114..ed40e7a88a 100644 --- a/tests/models/test_model.py +++ 
b/tests/models/test_model.py @@ -5,6 +5,7 @@ import os import pathlib import warnings +from functools import partial from typing import Any, Dict, List, Optional, Union, cast from unittest import mock @@ -13,10 +14,15 @@ import torch.nn as nn from accelerate import init_empty_weights from composer.core.precision import Precision, get_precision_context +from composer.distributed.dist_strategy import prepare_fsdp_module from composer.models.huggingface import maybe_get_underlying_model from composer.optim import DecoupledAdamW -from composer.trainer.dist_strategy import prepare_fsdp_module -from composer.utils import dist, get_device, reproducibility +from composer.utils import ( + FSDPConfig, + dist, + get_device, + reproducibility, +) from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import ( @@ -29,6 +35,7 @@ ) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from llmfoundry import ComposerHFCausalLM from llmfoundry.layers_registry import norms @@ -39,7 +46,8 @@ is_flash_v2_installed, ) from llmfoundry.models.layers.blocks import MPTBlock -from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM +from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel +from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -67,12 +75,17 @@ def _load_tokenizer_cfg(cfg: Union[Dict[str, Any], DictConfig]) -> Dict: def _get_objs( request: pytest.FixtureRequest, conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', + model_config_overrides: Optional[Dict] = None, + attn_impl: str = 'torch', ): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property', ) test_cfg = get_config(conf_path=conf_path) + if model_config_overrides is not None: + for k, v in model_config_overrides.items(): + test_cfg.model[k] = v # Read FSDP Config as a dict fsdp_config = test_cfg.get('fsdp_config', None) @@ -92,7 +105,7 @@ def _get_objs( device = 'cuda' if is_gpu else 'cpu' test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32' test_cfg.model.attn_config = { - 'attn_impl': 'torch', + 'attn_impl': attn_impl, } test_cfg.model.init_device = device test_cfg.device = device @@ -2538,7 +2551,14 @@ def test_hf_init( betas=(0.9, 0.99), ) - prepare_fsdp_module(model, optimizer, fsdp_config, precision, device, False) + prepare_fsdp_module( + model, + optimizer, + FSDPConfig(**fsdp_config), + precision, + device, + False, + ) model = HuggingFaceModelWithFSDP(model, tokenizer) @@ -2605,3 +2625,319 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): output = model(batch) assert not torch.isnan(output.logits).any() + + +def test_construct_blocks(): + n_layers = 13 + + config = MPTConfig( + d_model=32, + n_heads=16, + n_layers=n_layers, + expansion_ratio=2, + max_seq_len=64, + attn_config={ + 'attn_impl': 'flash', + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 4, + }, + ) + + # override architecture taken from https://research.character.ai/optimizing-inference/ + config.block_overrides = {} + config.block_overrides['overrides'] = { + 'reuse_kv_layer': { + 'attn_config': { + 'reuse_kv_layer_idx': -6, + }, + }, + 'sliding_window_layer': { + 'attn_config': { + 
'sliding_window_size': 1024,
+            },
+        },
+        'sliding_window_layer_reuse': {
+            'attn_config': {
+                'sliding_window_size': 1024,
+                'reuse_kv_layer_idx': -1,
+            },
+        },
+    }
+    config.block_overrides['order'] = [
+        {
+            'name': 'default',
+        },
+        {
+            'order': [
+                {
+                    'name': 'sliding_window_layer',
+                },
+                {
+                    'name': 'sliding_window_layer_reuse',
+                },
+                {
+                    'name': 'sliding_window_layer',
+                },
+                {
+                    'name': 'sliding_window_layer_reuse',
+                    'repeat': 2,
+                },
+                {
+                    'name': 'reuse_kv_layer',
+                },
+            ],
+            'repeat': 2,
+        },
+    ]
+
+    block_list = MPTModel(config).construct_blocks(config)
+
+    assert len(block_list) == n_layers
+    assert block_list[0].attn.sliding_window_size == -1
+    assert block_list[0].attn.reuse_kv_layer_idx is None
+
+    for layer_offset in [1, 7]:
+        assert block_list[layer_offset].attn.sliding_window_size == 1024
+        assert block_list[layer_offset].attn.reuse_kv_layer_idx is None
+        assert block_list[layer_offset + 1].attn.sliding_window_size == 1024
+        assert block_list[layer_offset +
+                          1].attn.reuse_kv_layer_idx == layer_offset
+
+        assert block_list[layer_offset + 2].attn.sliding_window_size == 1024
+        assert block_list[layer_offset + 2].attn.reuse_kv_layer_idx is None
+        assert block_list[layer_offset + 3].attn.sliding_window_size == 1024
+        assert block_list[layer_offset +
+                          3].attn.reuse_kv_layer_idx == layer_offset + 2
+        assert block_list[layer_offset + 4].attn.sliding_window_size == 1024
+        assert block_list[layer_offset +
+                          4].attn.reuse_kv_layer_idx == layer_offset + 2
+
+        assert block_list[layer_offset + 5].attn.sliding_window_size == -1
+        assert block_list[layer_offset + 5].attn.reuse_kv_layer_idx == 0
+
+
+@pytest.mark.gpu
+def test_reuse_prev_layer_kv_cache(
+    request: pytest.FixtureRequest,
+    batch_size: int = 2,
+):
+    conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
+    model_config_overrides = {
+        'block_overrides': {
+            'order': [
+                {
+                    'name': 'default',
+                },
+                {
+                    'name': 'kv_reuse_layer',
+                },
+            ],
+            'overrides': {
+                'kv_reuse_layer': {
+                    'attn_config': {
+                        'reuse_kv_layer_idx': -1,
+                    },
+                },
+            },
+        },
+        'use_cache': True,
+    }
+    test_cfg, model, _ = _get_objs(
+        request=request,
+        conf_path=conf_path,
+        model_config_overrides=model_config_overrides,
+        attn_impl='flash',
+    )
+
+    batch = gen_random_batch(batch_size, test_cfg)
+
+    assert batch['input_ids'].shape == torch.Size([
+        batch_size,
+        test_cfg.max_seq_len,
+    ])
+    model.train()
+
+    prev_layer_key_value_dict = {}
+
+    def mock_forward(b_forward, b_idx, *args, **kwargs):  # type: ignore
+        if 'prev_layer_key_value' in kwargs:
+            prev_layer_key_value_dict[b_idx] = kwargs['prev_layer_key_value']
+        return b_forward(*args, **kwargs)
+
+    for b_idx, block in enumerate(model.model.transformer.blocks):
+        block.forward = partial(mock_forward, block.forward, b_idx)
+
+    with get_precision_context(test_cfg.precision):
+        outputs = model(batch)
+        assert len(outputs.past_key_values) == 2
+        assert torch.all(
+            outputs.past_key_values[0][0] == outputs.past_key_values[1][0],
+        )
+        assert torch.all(
+            outputs.past_key_values[0][1] == outputs.past_key_values[1][1],
+        )
+        assert 0 not in prev_layer_key_value_dict
+        assert torch.all(
+            prev_layer_key_value_dict[1][0] == outputs.past_key_values[0][0],
+        )
+        assert torch.all(
+            prev_layer_key_value_dict[1][1] == outputs.past_key_values[0][1],
+        )
+
+
+def test_override_block_args():
+    block_args = {'a': 1, 'b': {'c': 3}, 'd': 4}
+    override_config = {'a': 2, 'b': {'c': 5}, 'e': 6}
+    allowed_block_overrides = {'a': None, 'b': {'c': None}, 'e': None}
+    new_config = MPTModel._override_block_args(
+        block_args,
+        override_config,
+        allowed_block_overrides,
+    )
+    assert new_config['a'] == 2
+    assert new_config['d'] == 4
+    assert new_config['e'] == 6
+    assert new_config['b']['c'] == 5
+
+
+def test_get_modules_order_expanded():
+    order = [
+        {
+            'name': 'default',
+        },
+        {
+            'name': 'layer_a',
+            'repeat': 2,
+        },
+        {
+            'order': [{
+                'name': 'layer_b',
+            },],
+            'repeat': 3,
+        },
+        {
+            'name': 'layer_c',
+            'repeat': 2,
+        },
+        {
+            'name': 'default',
+        },
+    ]
+    expected_list = [
+        'default',
+        'layer_a',
+        'layer_a',
+        'layer_b',
+        'layer_b',
+        'layer_b',
+        'layer_c',
+        'layer_c',
+        'default',
+    ]
+    assert expected_list == MPTModel._get_modules_order_expanded(order)
+
+
+@pytest.mark.parametrize('reuse_kv_layer_idx', [-2, -1, 0])
+def test_resolve_reuse_kv_layer_idx(reuse_kv_layer_idx: int):
+    layer_a_override = {
+        'key_1': 'value_a',
+        'attn_config': {
+            'key_2': 'value_b',
+        },
+    }
+    layer_b_override = {
+        'key_1': 'value_c',
+        'attn_config': {
+            'key_2': 'value_d',
+        },
+    }
+    layer_c_override = {
+        'key_1': 'value_c' if reuse_kv_layer_idx == -1 else 'value_a',
+        'attn_config': {
+            'key_2': 'value_d' if reuse_kv_layer_idx == -1 else 'value_b',
+            'reuse_kv_layer_idx': reuse_kv_layer_idx,
+        },
+    }
+    block_overrides = {
+        'overrides': {
+            'layer_a': layer_a_override,
+            'layer_b': layer_b_override,
+            'layer_c': layer_c_override,
+        },
+    }
+    model_modules_order_expanded = ['layer_a', 'layer_b', 'layer_c']
+    if reuse_kv_layer_idx == -1:
+        model_modules_order_expanded = [
+            'layer_a',
+            'layer_b',
+            'layer_c',
+            'layer_c',
+            'layer_c',
+            'layer_a',
+            'layer_c',
+        ]
+    reuse_kv_layer_idx_dict = {}
+
+    def _validate_helper(b_idx: int) -> int:
+        return MPTModel._resolve_reuse_kv_layer_idx(
+            overrides_definition=block_overrides['overrides'],
+            model_modules_order_expanded=model_modules_order_expanded,
+            b_idx=b_idx,
+            override_config=copy.deepcopy(
+                block_overrides['overrides'][model_modules_order_expanded[b_idx]
+                                            ],
+            ),
+            reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict,
+        )
+
+    if reuse_kv_layer_idx == -1:
+        assert _validate_helper(b_idx=2) == 1
+        assert _validate_helper(b_idx=3) == 1
+        assert _validate_helper(b_idx=4) == 1
+        with pytest.raises(
+            expected_exception=ValueError,
+            match=
+            'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer\.',  # type: ignore
+        ):
+            _validate_helper(b_idx=6)
+
+    elif reuse_kv_layer_idx == -2:
+        assert _validate_helper(b_idx=2) == 0
+    else:
+        with pytest.raises(
+            expected_exception=ValueError,
+            match=
+            'The relative index of kv layer to reuse, override_attn_config\[\"reuse_kv_layer_idx\"\]=0, should be negative\.',  # type: ignore
+        ):
+            _validate_helper(b_idx=2)
+
+
+def test_hf_rotary_child_class_builds():
+    rope_head_dim = 32
+    num_heads = 4
+    max_seq_len = 128
+    rope_theta = 10000
+    bsz = 4
+    value = torch.rand([bsz, num_heads, max_seq_len, rope_head_dim])
+    position_ids = torch.Tensor([
+        list(range(max_seq_len)),
+    ] * bsz)
+
+    rot_emb_mp = LlamaRotaryEmbeddingFoundry(
+        rope_head_dim,
+        max_seq_len,
+        rope_theta,
+        device='cpu',
+    )
+    cos_mp, sin_mp = rot_emb_mp(value, position_ids)
+
+    rot_emb = LlamaRotaryEmbedding(
+        rope_head_dim,
+        max_seq_len,
+        rope_theta,
+        device='cpu',
+    )
+    cos, sin = rot_emb(value, position_ids)
+
+    assert torch.all(cos == cos_mp)
+    assert torch.all(sin == sin_mp)
diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py
index 6a41e64f48..34fb23f670 100644
--- a/tests/models/test_rope_dail_vs_hf.py
+++ b/tests/models/test_rope_dail_vs_hf.py
@@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     }
 
     dail_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
         rope_impl=dail_rope_config['rope_impl'],
         rope_theta=dail_rope_config['rope_theta'],
         rope_dail_config=dail_rope_config['rope_dail_config'],
         rope_hf_config={},
         max_seq_len=seq_len,
+        d_model=cfg.d_model,
+        n_heads=cfg.n_heads,
     ).to('cuda')
     dail_rope_w_meta_info = {
         'impl': 'dail',
@@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     }
 
     hf_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
         rope_impl=hf_rope_config['rope_impl'],
         rope_theta=hf_rope_config['rope_theta'],
         rope_dail_config={},
         rope_hf_config=hf_rope_config['rope_hf_config'],
         max_seq_len=seq_len,
+        d_model=cfg.d_model,
+        n_heads=cfg.n_heads,
     ).to('cuda')
     pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
     # adjust the position indices to account for padding tokens
diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py
new file mode 100644
index 0000000000..484ac2b23a
--- /dev/null
+++ b/tests/models/test_rope_scaling.py
@@ -0,0 +1,35 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
+
+rope_config = {
+    'rope_theta': 500000.0,
+    'rope_impl': 'hf',
+    'rope_hf_config': {
+        'factor': 8.0,
+        'low_freq_factor': 1.0,
+        'high_freq_factor': 4.0,
+        'original_max_position_embeddings': 8192,
+        'type': 'llama3',
+    },
+}
+
+rope_dail_config = {}
+
+
+def test_rope_scaling():
+    d_model = 128
+    n_heads = 32
+    max_seq_len = 65536
+
+    embedding = gen_rotary_embedding(
+        d_model=d_model,
+        n_heads=n_heads,
+        rope_dail_config=rope_dail_config,
+        max_seq_len=max_seq_len,
+        **rope_config,
+    )
+
+    assert isinstance(embedding, LlamaRotaryEmbedding)
diff --git a/tests/models/utils/test_config_moe_args.py b/tests/models/utils/test_config_moe_args.py
new file mode 100644
index 0000000000..426363d2c3
--- /dev/null
+++ b/tests/models/utils/test_config_moe_args.py
@@ -0,0 +1,30 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+import pytest
+
+from llmfoundry.models.utils.config_moe_args import (
+    config_megablocks_moe_args,
+    get_megablocks_device_mesh,
+)
+
+
+@pytest.mark.gpu
+def test_config_megablocks_moe_args_error():
+    ffn_config_base: dict[str, Any] = {
+        'moe_world_size': 1,
+        'lbl_process_group': 'not_real',
+        'ffn_type': 'mb_moe',
+        'fc_type': 'torch',
+    }
+
+    with pytest.raises(ValueError):
+        config_megablocks_moe_args(
+            ffn_config=ffn_config_base,
+            d_model=128,
+            expansion_ratio=4,
+            n_layers=2,
+            get_device_mesh=get_megablocks_device_mesh,
+        )
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 87881450d4..c4d1a1bcd5 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -24,6 +24,7 @@ def test_expected_registries_exist():
         'loggers',
         'optimizers',
         'schedulers',
+        'tokenizers',
         'callbacks',
         'algorithms',
         'callbacks_with_config',
@@ -42,6 +43,10 @@ def test_expected_registries_exist():
         'attention_classes',
         'attention_implementations',
         'fcs',
+        'icl_datasets',
+        'config_transforms',
+        'load_planners',
+        'save_planners',
     }
 
     assert existing_registries == expected_registry_names
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 89517e64ff..dc9bcd9baf 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,9 +1,18 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List
+from typing import Any, Dict, List
 
+import catalogue
 import pytest
+from omegaconf import DictConfig
+
+from llmfoundry.registry import config_transforms
+from llmfoundry.utils.config_utils import (
+    TRAIN_CONFIG_KEYS,
+    TrainConfig,
+    make_dataclass_and_log_config,
+)
 
 
 def generate_exclusive_test_params(param_names: List[str]):
@@ -25,3 +34,39 @@ def generate_exclusive_test_params(param_names: List[str]):
         param_values = list(params.values())
         param_id = f'{name}=True'
         yield pytest.param(*param_values, id=param_id)
+
+
+def test_config_transforms():
+    config = DictConfig({
+        'global_train_batch_size': 1,
+        'device_train_microbatch_size': 1,
+        'model': {},
+        'scheduler': {},
+        'max_seq_len': 128,
+        'train_loader': {},
+        'max_duration': 1,
+        'tokenizer': {},
+        'eval_interval': 1,
+        'seed': 1,
+        'optimizer': {},
+        'variables': {},
+    },)
+
+    def dummy_transform(config: Dict[str, Any]) -> Dict[str, Any]:
+        config['variables']['fake_key'] = 'fake_value'
+        return config
+
+    config_transforms.register('dummy_transform', func=dummy_transform)
+
+    _, parsed_config = make_dataclass_and_log_config(
+        config,
+        TrainConfig,
+        TRAIN_CONFIG_KEYS,
+        transforms='all',
+    )
+
+    assert isinstance(parsed_config.variables, Dict)
+    assert parsed_config.variables['fake_key'] == 'fake_value'
+
+    del catalogue.REGISTRY[
+        ('llmfoundry', 'config_transforms', 'dummy_transform')]
diff --git a/tests/tokenizers/test_registry.py b/tests/tokenizers/test_registry.py
new file mode 100644
index 0000000000..920c207a64
--- /dev/null
+++ b/tests/tokenizers/test_registry.py
@@ -0,0 +1,35 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, Optional
+
+from transformers import PreTrainedTokenizer
+
+from llmfoundry.registry import tokenizers
+from llmfoundry.utils import build_tokenizer
+
+
+class DummyTokenizer(PreTrainedTokenizer):
+    """A dummy tokenizer that inherits from ``PreTrainedTokenizer``."""
+
+    def __init__(
+        self,
+        model_name: Optional[str] = 'dummy',
+        **kwargs: Optional[Dict[str, Any]],
+    ):
+        """Dummy constructor that has no real purpose."""
+        super().__init__(
+            model_name=model_name,
+            eos_token='0',
+            pad_token='1',
+            **kwargs,
+        )
+
+    def get_vocab(self) -> Dict[str, int]:
+        return {}
+
+
+def test_tokenizer_registry():
+    tokenizers.register('dummy', func=DummyTokenizer)
+    tokenizer = build_tokenizer(tokenizer_name='dummy', tokenizer_kwargs={})
+    assert type(tokenizer) == DummyTokenizer
diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py
index dfcb5b327c..fb6cb0c5df 100644
--- a/tests/utils/test_builders.py
+++ b/tests/utils/test_builders.py
@@ -13,17 +13,24 @@
 from composer.callbacks import Generate
 from composer.core import Evaluator
 from composer.loggers import WandBLogger
+from torch.distributed.checkpoint.default_planner import (
+    DefaultLoadPlanner,
+    DefaultSavePlanner,
+)
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
+from llmfoundry.registry import load_planners, save_planners
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import (
     add_metrics_to_eval_loaders,
     build_callback,
     build_eval_loaders,
     build_evaluators,
+    build_load_planner,
     build_logger,
     build_optimizer,
+    build_save_planner,
     build_tokenizer,
 )
 
@@ -345,6 +352,34 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch):
     assert eval_loaders2[1].metric_names == []
 
 
+def test_build_load_planner():
+    # Dummy LoadPlanner for testing
+    class DummyLoadPlanner(DefaultLoadPlanner):
+
+        def __init__(self, is_test: bool):
+            self.is_test = is_test
+
+    load_planners.register('dummy', func=DummyLoadPlanner)
+    load_planner = build_load_planner('dummy', is_test=True)
+
+    assert isinstance(load_planner, DummyLoadPlanner)
+    assert load_planner.is_test is True
+
+
+def test_build_save_planner():
+    # Dummy SavePlanner for testing
+    class DummySavePlanner(DefaultSavePlanner):
+
+        def __init__(self, is_test: bool):
+            self.is_test = is_test
+
+    save_planners.register('dummy', func=DummySavePlanner)
+    save_planner = build_save_planner('dummy', is_test=True)
+
+    assert isinstance(save_planner, DummySavePlanner)
+    assert save_planner.is_test is True
+
+
 def test_add_metrics_to_eval_loaders():
     evaluators = [
         Evaluator(
diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py
new file mode 100644
index 0000000000..1b78d0077b
--- /dev/null
+++ b/tests/utils/test_config_utils.py
@@ -0,0 +1,15 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.utils.config_utils import update_config_with_batch_size_info
+
+
+def test_update_config_with_batch_size_info():
+    config = {}
+    config = update_config_with_batch_size_info(config, 1, 2, 3)
+
+    assert config['n_gpus'] == 1
+    assert config['device_train_batch_size'] == 1
+    assert config['device_train_microbatch_size'] == 2
+    assert config['device_train_grad_accum'] == 3
+    assert config['device_eval_batch_size'] == 2
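For context on the block_overrides tests added above: each entry in 'order' either names an override ('name') or nests another 'order' list, and 'repeat' multiplies that entry, so the whole structure flattens into one override name per transformer block. The snippet below is an illustrative sketch only, not code from this diff; the helper name expand_block_order is invented here, while the real logic is MPTModel._get_modules_order_expanded, which test_get_modules_order_expanded exercises.

# Illustrative sketch: a simplified stand-in for
# MPTModel._get_modules_order_expanded (behavior inferred from the tests
# above, not the actual implementation).
from typing import Any


def expand_block_order(order: list[dict[str, Any]]) -> list[str]:
    expanded: list[str] = []
    for entry in order:
        repeat = entry.get('repeat', 1)
        if 'name' in entry:
            # A named entry contributes `repeat` consecutive layers.
            expanded.extend([entry['name']] * repeat)
        else:
            # A nested 'order' expands recursively, then repeats as a group.
            expanded.extend(expand_block_order(entry['order']) * repeat)
    return expanded


if __name__ == '__main__':
    order = [
        {'name': 'default'},
        {'name': 'layer_a', 'repeat': 2},
        {'order': [{'name': 'layer_b'}], 'repeat': 3},
    ]
    # Prints: ['default', 'layer_a', 'layer_a', 'layer_b', 'layer_b', 'layer_b']
    print(expand_block_order(order))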