diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index ef85332047..062aa41bf4 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -19,7 +19,7 @@ defaults: working-directory: . jobs: code-quality: - runs-on: ubuntu-20.04 + runs-on: linux-ubuntu-latest timeout-minutes: 30 strategy: matrix: diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index fc511d7e60..cf3581f716 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -8,7 +8,7 @@ on: jobs: coverage: timeout-minutes: 5 - runs-on: ubuntu-latest + runs-on: linux-ubuntu-latest steps: - name: Checkout Repo uses: actions/checkout@v3 diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index e4e6f83551..0bb0b4087a 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -19,18 +19,11 @@ jobs: include: - name: "2.3.1_cu121" base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 - dep_groups: "[gpu]" + dep_groups: "[all]" - name: "2.3.1_cu121_aws" base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws - dep_groups: "[gpu]" + dep_groups: "[all]" steps: - - name: Maximize Build Space on Worker - uses: easimon/maximize-build-space@v4 - with: - overprovision-lvm: true - remove-dotnet: true - remove-android: true - remove-haskell: true - name: Checkout uses: actions/checkout@v3 @@ -47,6 +40,13 @@ jobs: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_PASSWORD }} + - name: Login to GHCR + uses: docker/login-action@v2 + with: + username: ${{ secrets.GHCR_USERNAME }} + password: ${{ secrets.GHCR_TOKEN }} + registry: ghcr.io + - name: Calculate Docker Image Variables run: | set -euxo pipefail @@ -60,13 +60,17 @@ jobs: if [ "${{ github.event_name }}" == "pull_request" ]; then echo "Triggered by pull_request event." STAGING_REPO="mosaicml/ci-staging" - IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" + GHCR_STAGING_REPO="ghcr.io/databricks-mosaic/ci-staging" + GHCR_IMAGE_TAG="${GHCR_STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" + IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_IMAGE_TAG}" IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache" else # Triggered by push or workflow_dispatch event echo "Triggered by ${{ github.event_name }} event, releasing to prod" PROD_REPO="mosaicml/llm-foundry" - IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" + GHCR_PROD_REPO="ghcr.io/databricks-mosaic/llm-foundry" + GHCR_IMAGE_TAG="${GHCR_PROD_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_PROD_REPO}:${{matrix.name}}-latest" + IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest,${GHCR_IMAGE_TAG}" IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" fi diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 2dd1c0edab..2c85719756 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -15,23 +15,28 @@ concurrency: cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: pytest-cpu: - uses: mosaicml/ci-testing/.github/workflows/pytest-cpu.yaml@v0.0.9 + name: ${{ matrix.name }} + runs-on: ubuntu-latest strategy: matrix: include: - name: "cpu-2.3.1" + pip_deps: "[all-cpu]" container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 markers: "not gpu" pytest_command: "coverage run -m pytest" - name: ${{ matrix.name }} - if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - name: ${{ matrix.name }} - pip_deps: "[all-cpu]" - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - safe_directory: llm-foundry + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Run PR CPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.1.0 + with: + name: ${{ matrix.name }} + container: ${{ matrix.container }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + safe_directory: llm-foundry coverage: uses: ./.github/workflows/coverage.yaml name: Coverage Results diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index c5638e403d..ba1a4f9ba4 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -9,12 +9,15 @@ on: - main - release/** workflow_dispatch: +# Cancel old runs when a new commit is pushed to the same branch if not on main or dev concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: pytest-gpu-1: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.9 + name: ${{ matrix.name }} + if: github.repository_owner == 'mosaicml' + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: @@ -22,24 +25,28 @@ jobs: - name: "gpu-2.3.1-1" container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" - pytest_command: "coverage run -m pytest" pip_deps: "[all]" + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 1 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} + pytest-gpu-2: name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - python-version: 3.9 - gpu_num: 1 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} - pytest-gpu-2: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.9 + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: @@ -47,24 +54,28 @@ jobs: - name: "gpu-2.3.1-2" container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" - pytest_command: "coverage run -m pytest" pip_deps: "[all]" + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 2 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} + pytest-gpu-4: name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - python-version: 3.9 - gpu_num: 2 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} - pytest-gpu-4: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.9 + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: @@ -72,19 +83,21 @@ jobs: - name: "gpu-2.3.1-4" container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" - pytest_command: "coverage run -m pytest" pip_deps: "[all]" - name: ${{ matrix.name }} - if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - python-version: 3.9 - gpu_num: 4 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 4 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 144e3f1ad3..c09f9bb7a5 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -14,7 +14,7 @@ jobs: name: Build and Publish llm-foundry PyPI Package needs: - code-quality - runs-on: ubuntu-latest + runs-on: linux-ubuntu-latest steps: - name: Checkout source uses: actions/checkout@v3 diff --git a/.github/workflows/smoketest.yaml b/.github/workflows/smoketest.yaml index 2163111710..d38849cddc 100644 --- a/.github/workflows/smoketest.yaml +++ b/.github/workflows/smoketest.yaml @@ -18,7 +18,7 @@ defaults: working-directory: . jobs: smoketest: - runs-on: ubuntu-20.04 + runs-on: linux-ubuntu-latest timeout-minutes: 20 strategy: matrix: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc2e3f55cd..b45021dd8c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,17 +77,6 @@ repos: hooks: - id: docformatter args: [--in-place, --wrap-summaries=80, --wrap-descriptions=80] -- repo: https://github.com/PyCQA/pydocstyle - hooks: - - id: pydocstyle - name: pydocstyle - entry: pydocstyle - language: python - types: [python] - exclude: (.ci|.github) - additional_dependencies: - - toml - rev: 6.1.1 - repo: https://github.com/adrienverge/yamllint.git rev: v1.28.0 hooks: diff --git a/Dockerfile b/Dockerfile index 9366d7dbcd..cee7063cdd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,8 @@ FROM $BASE_IMAGE ARG BRANCH_NAME ARG DEP_GROUPS +ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" + # Check for changes in setup.py. # If there are changes, the docker cache is invalidated and a fresh pip installation is triggered. ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py setup.py diff --git a/README.md b/README.md index 0299e43710..e8a6708c5a 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ DBRX is a state-of-the-art open source LLM trained by Databricks Mosaic team. It | DBRX Base | 32768 | https://huggingface.co/databricks/dbrx-base | | DBRX Instruct | 32768 | https://huggingface.co/databricks/dbrx-instruct | -Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). +Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/blob/main/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). For more information about the DBRX models, see https://github.com/databricks/dbrx. @@ -309,10 +309,15 @@ dependencies = [ "llm-foundry", ] +# Note: Even though in python code, this would be llmfoundry.registry.loggers, +# when specified in the entry_points, it has to be "llmfoundry_loggers". That is, +# the segments of the name should be joined by an _ in the entry_points section. [project.entry-points."llmfoundry_loggers"] my_logger = "foundry_registry.loggers:MyLogger" ``` +If developing new components via entrypoints, it is important to note that Python entrypoints are global to the Python environment. This means that if you have multiple packages that register components with the same key, the last one installed will be the one used. This can be useful for overriding components in LLM Foundry, but can also lead to unexpected behavior if not careful. Additionally, if you change the pyproject.toml, you will need to reinstall the package for the changes to take effect. You can do this quickly by installing with `pip install -e . --no-deps` to avoid reinstalling dependencies. + ### Direct call to register You can also register a component directly in your code: diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 8dbd180c0a..b851aaa559 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -50,6 +50,7 @@ tokenizers, utils, ) +from llmfoundry._version import __version__ from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric from llmfoundry.models.hf import ComposerHFCausalLM @@ -63,6 +64,7 @@ from llmfoundry.optim import DecoupledLionW __all__ = [ + '__version__', 'StreamingFinetuningDataset', 'StreamingTextDataset', 'InContextLearningDataset', @@ -87,5 +89,3 @@ 'tokenizers', 'utils', ] - -__version__ = '0.11.0.dev0' diff --git a/llmfoundry/_version.py b/llmfoundry/_version.py new file mode 100644 index 0000000000..4c11746b43 --- /dev/null +++ b/llmfoundry/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""The LLM Foundry Version.""" + +__version__ = '0.11.0.dev' diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 646d86c8d3..1b3c31e861 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -557,7 +557,8 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: installation_path = i['path'] if not found_llm_foundry: - from llmfoundry import __version__ as latest_foundry_version + from llmfoundry._version import \ + __version__ as latest_foundry_version # If github integration is not found, foundry is likely installed # through the run command. In this case, we'll add the integration diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 98a672f8db..449ab338bc 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -128,18 +128,17 @@ def after_load(self, state: State, logger: Logger): self._validate_dataloader(state.train_dataloader) # If checkpoint was saved before iteration was incremented, we need to increment it now + duration = self._schedule[self._schedule_index]['duration'] if (( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.TOKEN and state.timestamp.token_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.TOKEN and + state.timestamp.token_in_iteration >= duration.value ) or ( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.EPOCH and + state.timestamp.epoch_in_iteration >= duration.value )): log.warning(( - 'The CurriculumLearning callback has detected that the previous run did not correctly ' - 'increment the iteration.' + 'The CurriculumLearning callback has detected that the ' + 'previous run did not correctly increment the iteration.' )) self._schedule_index += 1 state.timestamp = state.timestamp.to_next_iteration() @@ -199,24 +198,13 @@ def load_state_dict(self, state: dict[str, Any]): f'Expected {saved_loader} but got {current_loader}', )) - # Ensure that the current datamix duration is greater than timestamp + # Ensure that the current datamix duration is in the correct units duration = self._schedule[self._schedule_index]['duration'] if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH: raise ValueError(( f'Duration must be in terms of tokens or epochs, but got ', f'{duration.unit}.', )) - if (( - duration.unit == TimeUnit.TOKEN and - duration > state['timestamp'].token_in_iteration - ) or ( - duration.unit == TimeUnit.EPOCH and - duration > state['timestamp'].epoch_in_iteration - )): - raise ValueError(( - 'The duration of the current datamix must be less or equal to ' - 'than the saved timestamp.' - )) def _build_train_loader( self, diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4de7f9f2c6..79dc73de98 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn from composer.core import Callback, Event, Precision, State, Time, TimeUnit -from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel from composer.utils import ( @@ -29,7 +28,12 @@ ) from composer.utils.misc import create_interval_scheduler from mlflow.transformers import _fetch_model_card, _write_license_information -from packaging import version +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import ( PretrainedConfig, PreTrainedModel, @@ -179,6 +183,7 @@ def __init__( 'bfloat16': torch.bfloat16, }[precision] self.flatten_imports = flatten_imports + self.using_peft = False # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name @@ -212,6 +217,14 @@ def __init__( ) self.mlflow_logging_config = mlflow_logging_config + if 'metadata' in self.mlflow_logging_config: + self.pretrained_model_name = self.mlflow_logging_config[ + 'metadata'].get( + 'pretrained_model_name', + None, + ) + else: + self.pretrained_model_name = None self.huggingface_folder_name_fstr = os.path.join( 'huggingface', @@ -274,6 +287,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '1GB', ) + + # Check if the model is using PEFT + if state.is_model_ddp: + composer_model = state.model.module + elif isinstance(state.model.model, FSDP): + composer_model = state.model + else: + composer_model = state.model + self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. timeout = 3600 @@ -362,6 +384,34 @@ def transform_config( copied_config.ffn_config['moe_world_size'] = 1 return copied_config + def pre_register_edit(self, local_save_path: str): + """Edit the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + local_save_path (str): The path to the model to be transformed. + """ + pass + + def transform_model_pre_registration( + self, + model: PreTrainedModel, + ) -> PreTrainedModel: + """Transform the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + model (PreTrainedModel): The model to be transformed. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -388,85 +438,63 @@ def _save_checkpoint(self, state: State, logger: Logger): temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if state.is_model_ddp: - composer_model = state.model.module original_model: PreTrainedModel = state.model.module.model state_dict_model = state.model.module.model original_tokenizer = state.model.module.tokenizer elif isinstance(state.model.model, FSDP): - composer_model = state.model original_model: PreTrainedModel = state.model.model.module state_dict_model = state.model.model original_tokenizer = state.model.tokenizer else: - composer_model = state.model original_model: PreTrainedModel = state.model.model state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - if version.parse(torch.__version__) > version.parse('2.2.9'): - from torch.distributed._tensor import DTensor - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - ) - cpu_offload = True - - # Add a dtensor->cpu tensor hook to avoid CUDA OOM - def dtensor_to_tensor_hook( - module: nn.Module, - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> Dict[str, Any]: - dtensor_fqns = [] - for fqn in state_dict.keys(): - tensor = state_dict[fqn] - if isinstance(tensor, DTensor): - dtensor_fqns.append(fqn) - tensor = tensor.full_tensor() # type: ignore - if dist.get_global_rank() == 0: - if cpu_offload: - tensor = tensor.cpu() - state_dict[fqn] = tensor - if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] - return state_dict - - hooks = [] - for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append( - module. - _register_state_dict_hook(dtensor_to_tensor_hook), - ) - - state_dict = get_model_state_dict( - state_dict_model, - options=StateDictOptions( - full_state_dict=True, - cpu_offload=cpu_offload, - ), - ) - for hook in hooks: - hook.remove() - else: - state_dict_context = fsdp_state_dict_type_context( - original_model, - state_dict_type='full', - ) if ((not state.is_model_ddp) and - isinstance(state_dict_model, - FSDP)) else contextlib.nullcontext() - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # Convert the state dict to the requested precis - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) + cpu_offload = True + + # Add hook to move tensors to cpu to avoid CUDA OOM + def tensor_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + dtensor_fqns = [] + for fqn in state_dict.keys(): + tensor = state_dict[fqn] + if isinstance(tensor, DTensor): + dtensor_fqns.append(fqn) + tensor = tensor.full_tensor() # type: ignore + if dist.get_global_rank() == 0: + # Offload any DTensors to CPU + if cpu_offload: + tensor = tensor.cpu() + state_dict[fqn] = tensor + else: + state_dict[fqn] = None + + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) + del tensor + if dist.get_global_rank() != 0: + state_dict = {} + return state_dict + + hooks = [] + for _, module in state_dict_model.named_modules(): + hooks.append(module._register_state_dict_hook(tensor_hook),) + + state_dict = get_model_state_dict( + state_dict_model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload, + ), + ) + for hook in hooks: + hook.remove() new_model_instance = None # Need this for pyright because variable could be unbound @@ -480,22 +508,19 @@ def dtensor_to_tensor_hook( log.debug(f'Creating new model instance') - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(new_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter], - ) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): + # First create the model instance on meta device to avoid the + # initialization cost. + with init_empty_weights(): + if self.using_peft: + active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(new_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter], + ) + else: new_model_instance = type(original_model)(new_config) new_model_instance.generation_config.update( **original_model.generation_config.to_dict(), @@ -512,6 +537,16 @@ def dtensor_to_tensor_hook( original_tokenizer, ) + # Ensure that the pretrained model name is correctly set on the saved HF checkpoint. + if self.pretrained_model_name is not None: + new_model_instance.name_or_path = self.pretrained_model_name + if self.using_peft: + new_model_instance.base_model.name_or_path = self.pretrained_model_name + for k in new_model_instance.peft_config.keys(): + new_model_instance.peft_config[ + k + ].base_model_name_or_path = self.pretrained_model_name + log.debug('Saving Hugging Face checkpoint to disk') # This context manager casts the TE extra state in io.BytesIO format to tensor format # Needed for proper hf ckpt saving. @@ -529,7 +564,7 @@ def dtensor_to_tensor_hook( original_tokenizer.save_pretrained(temp_save_dir) # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': + if new_model_instance.config.model_type == 'mpt': log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility( temp_save_dir, @@ -556,6 +591,11 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if self.mlflow_registered_model_name and self._is_last_batch(state): + + new_model_instance = self.transform_model_pre_registration( + new_model_instance, + ) + components = {'model': new_model_instance} if original_tokenizer is not None: components['tokenizer'] = original_tokenizer @@ -575,7 +615,7 @@ def dtensor_to_tensor_hook( model_saving_kwargs: Dict[str, Any] = { 'path': local_save_path, } - if composer_model.using_peft: + if self.using_peft: model_saving_kwargs['flavor'] = 'peft' model_saving_kwargs['save_pretrained_dir' ] = temp_save_dir @@ -591,15 +631,18 @@ def dtensor_to_tensor_hook( ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( ) with context_manager: + # Add the pip requirements directly to avoid mlflow + # attempting to run inference on the model + model_saving_kwargs['pip_requirements'] = [ + 'transformers', + 'torch', + ] mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. license_filename = _maybe_get_license_filename( local_save_path, - self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', - None, - ), + self.pretrained_model_name, ) if license_filename is not None: mlflow_logger._mlflow_client.log_artifact( @@ -607,6 +650,8 @@ def dtensor_to_tensor_hook( os.path.join(local_save_path, license_filename), ) + self.pre_register_edit(local_save_path,) + # Spawn a new process to register the model. process = SpawnProcess( target=_register_model_with_run_id_multiprocess, diff --git a/llmfoundry/cli/data_prep_cli.py b/llmfoundry/cli/data_prep_cli.py index 731a9f06f0..130e0a6585 100644 --- a/llmfoundry/cli/data_prep_cli.py +++ b/llmfoundry/cli/data_prep_cli.py @@ -1,12 +1,18 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import os from typing import Annotated, Optional +import psutil from typer import Option, Typer from llmfoundry.command_utils import ( convert_dataset_hf_from_args, + convert_dataset_json_from_args, + convert_delta_to_json_from_args, + convert_finetuning_dataset_from_args, + convert_text_to_mds_from_args, ) app = Typer(pretty_exceptions_show_locals=False) @@ -59,3 +65,204 @@ def convert_dataset_hf( no_wrap=no_wrap, num_workers=num_workers, ) + + +@app.command(name='convert_dataset_json') +def convert_dataset_json( + path: Annotated[str, Option(..., help='Path to the input data file')], + out_root: Annotated[str, Option(..., help='Output root directory')], + concat_tokens: Annotated[ + int, + Option( + ..., + help='Convert text to tokens and concatenate up to this many tokens', + )], + tokenizer: Annotated[str, Option(..., help='Tokenizer name')], + compression: Annotated[Optional[str], + Option(help='Compression type, if any')] = 'zstd', + split: Annotated[str, Option(help='Dataset split to process')] = 'train', + bos_text: Annotated[ + Optional[str], + Option(help='Text to insert at the beginning of each sequence')] = None, + eos_text: Annotated[ + Optional[str], + Option(help='Text to insert at the end of each sequence')] = None, + no_wrap: Annotated[ + bool, + Option(help='Do not wrap text across max_length boundaries')] = False, + num_workers: Annotated[ + Optional[int], + Option(help='Number of workers for data loading')] = None, +): + """Convert a dataset from JSON to MDS streaming format.""" + convert_dataset_json_from_args( + path=path, + split=split, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) + + +@app.command(name='convert_finetuning_dataset') +def convert_finetuning_dataset_cli( + dataset: Annotated[ + str, + Option( + ..., + help= + 'Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`).', + )], + data_subset: Annotated[ + Optional[str], + Option(help='(Optional) subset of data to use.',)] = None, + splits: Annotated[str, + Option(help='Comma-separated list of dataset splits'), + ] = 'train,validation', + preprocessor: Annotated[ + Optional[str], + Option( + help= + 'Name or import path of function used to preprocess (reformat) the dataset.', + )] = None, + data_files: Annotated[ + str, Option(help='Data file for each split. Comma-separated.')] = '', + skip_preprocessing: Annotated[ + bool, Option(help='Whether to skip preprocessing.')] = False, + out_root: Annotated[ + str, + Option( + ..., + help= + 'Root path of output directory where MDS shards will be stored. Can be a remote URI.', + )] = '', + local: Annotated[ + Optional[str], + Option( + help= + '(Optional) root path of local directory if you want to keep a local copy when out_root is remote.', + )] = None, + compression: Annotated[ + Optional[str], + Option(help='(Optional) name of compression algorithm to use.')] = None, + num_workers: Annotated[Optional[int], + Option(help='Number of workers.')] = None, + tokenizer: Annotated[Optional[str], + Option(help='Tokenizer used for processing.')] = None, + tokenizer_kwargs: Annotated[ + Optional[str], + Option( + help= + 'Keyword arguments for tokenizer initialization in JSON format.', + )] = None, + max_seq_len: Annotated[int, Option(help='Maximum sequence length.')] = 2048, + target_prompts: Annotated[ + str, + Option(help='Policy for when to use prompts as training targets.'), + ] = 'none', + target_responses: Annotated[ + str, + Option(help='Policy for which responses to treat as training targets.'), + ] = 'last', + encoder_decoder: Annotated[ + bool, + Option( + help= + 'Set if the data are intended to be used to train an encoder-decoder model.', + )] = False, +): + """Convert a Finetuning Dataset to MDS streaming format.""" + # Convert comma-separated args + splits_list = splits.split(',') if splits else [] + data_files_list = data_files.split(',') if data_files else [] + convert_finetuning_dataset_from_args( + dataset=dataset, + data_subset=data_subset, + splits=splits_list, + preprocessor=preprocessor, + data_files=data_files_list, + skip_preprocessing=skip_preprocessing, + out_root=out_root, + local=local, + compression=compression, + num_workers=num_workers, + tokenizer=tokenizer, + tokenizer_kwargs=tokenizer_kwargs, + max_seq_len=max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + encoder_decoder=encoder_decoder, + ) + + +@app.command(name='convert_text_to_mds') +def convert_text_to_mds( + output_folder: Annotated[str, Option(..., help='The folder to write output to')], + input_folder: Annotated[str, Option(..., help='The folder with text files to convert to MDS')], + concat_tokens: Annotated[int, Option(..., help='Convert text to tokens and concatenate up to this many tokens')], + tokenizer: Annotated[str, Option(..., help='The name of the tokenizer to use')], + bos_text: Annotated[Optional[str], Option(help='The text to prepend to each example to separate concatenated examples')] = None, + eos_text: Annotated[Optional[str], Option(help='The text to append to each example to separate concatenated examples')] = None, + compression: Annotated[str, Option(help='The compression algorithm to use for MDS writing')] = 'zstd', + use_tokenizer_eos: Annotated[bool, Option(help='Use the EOS text from the tokenizer')] = False, + no_wrap: Annotated[bool, Option(help='Whether to let text examples wrap across multiple training examples')] = False, + processes: Annotated[int, Option( + help='The number of processes to use to download and convert the dataset', + )] = min(max(psutil.cpu_count() - 2, 1), 32), # type: ignore + reprocess: Annotated[bool, Option( + help= + 'If true, reprocess the input_folder to MDS format. Otherwise, only reprocess upon changes to the input folder or dataset creation parameters.', + )] = False, + trust_remote_code: Annotated[bool, Option( + help='If true, allows custom code to be executed to load the tokenizer', + )] = False, + logging_level: Annotated[str, Option( + help='Logging level for the script. Default is INFO.', + )] = 'INFO', + +): + """Convert text files to MDS streaming format.""" + convert_text_to_mds_from_args( + output_folder=output_folder, + input_folder=input_folder, + compression=compression, + concat_tokens=concat_tokens, + tokenizer_name=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + use_tokenizer_eos=use_tokenizer_eos, + no_wrap=no_wrap, + processes=processes, + reprocess=reprocess, + trust_remote_code=trust_remote_code, + logging_level=logging_level, + ) + + +@app.command(name='convert_delta_to_json') +def convert_delta_to_json_cli( + delta_table_name: Annotated[str, Option(..., help='UC table ..')], + json_output_folder: Annotated[str, Option(..., help='Local path to save the converted json')], + http_path: Annotated[Optional[str], Option(help='If set, dbsql method is used')] = None, + batch_size: Annotated[int, Option(help='Row chunks to transmit a time to avoid OOM')] = 1 << 30, + processes: Annotated[int, Option(help='Number of processes allowed to use')] = os.cpu_count(), # type: ignore + cluster_id: Annotated[Optional[str], Option(help='Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.')] = None, + use_serverless: Annotated[bool, Option(help='Use serverless or not. Make sure the workspace is entitled with serverless')] = False, + json_output_filename: Annotated[str, Option(help='The name of the combined final jsonl that combines all partitioned jsonl')] = 'train-00000-of-00001.jsonl', +): + """Convert a Delta table into JSON files.""" + convert_delta_to_json_from_args( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + batch_size=batch_size, + processes=processes, + cluster_id=cluster_id, + use_serverless=use_serverless, + json_output_filename=json_output_filename, + ) diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index adaaf03b6e..0226c4f408 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -4,6 +4,22 @@ convert_dataset_hf, convert_dataset_hf_from_args, ) +from llmfoundry.command_utils.data_prep.convert_dataset_json import ( + convert_dataset_json, + convert_dataset_json_from_args, +) +from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( + convert_delta_to_json_from_args, + fetch_DT, +) +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import ( + convert_finetuning_dataset, + convert_finetuning_dataset_from_args, +) +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( + convert_text_to_mds, + convert_text_to_mds_from_args, +) from llmfoundry.command_utils.eval import ( eval_from_yaml, evaluate, @@ -26,4 +42,12 @@ 'eval_from_yaml', 'convert_dataset_hf', 'convert_dataset_hf_from_args', + 'convert_dataset_json', + 'convert_dataset_json_from_args', + 'convert_finetuning_dataset_from_args', + 'convert_finetuning_dataset', + 'convert_text_to_mds', + 'convert_text_to_mds_from_args', + 'convert_delta_to_json_from_args', + 'fetch_DT', ] diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py new file mode 100644 index 0000000000..35d7e637e6 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -0,0 +1,222 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming dataset conversion scripts for json files.""" +import os +from enum import Enum +from glob import glob +from typing import Optional + +import datasets as hf_datasets +from streaming import MDSWriter +from torch.utils.data import IterableDataset +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.data import ConcatTokensDataset, NoConcatDataset + + +class ConcatMode(Enum): + NO_CONCAT = 'NO_CONCAT' + CONCAT_TOKENS = 'CONCAT_TOKENS' + + +def build_hf_dataset( + path: str, + split: str, + mode: ConcatMode, + max_length: Optional[int] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + tokenizer: PreTrainedTokenizerBase = None, +) -> IterableDataset: + """Build an IterableDataset over the HF C4 or pile source data. + + Args: + path (str): Dataset name + split (str): Split name. + mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS + max_length (int): The length of concatenated tokens + bos_text (str): text to insert at the beginning of each sequence + eos_text (str): text to insert at the end of each sequence + no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries + tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use + data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. + Typically "all" (The Pile) or "en" (c4). + + Returns: + An IterableDataset. + """ + if os.path.isdir(path): + data_files = glob(f'{path}/*') + else: + data_files = path + + hf_dataset = hf_datasets.load_dataset( + 'json', + data_files=data_files, + split=split, + ) + + if mode == ConcatMode.NO_CONCAT: + dataset = NoConcatDataset(hf_dataset) + else: + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) + if max_length is None: + raise ValueError(f'max_length must be set.') + if bos_text + eos_text == '': + test_tokens = tokenizer('test') + if test_tokens['input_ids'][ + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: + tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' + tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' + tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' + tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' + tok_error_msg += '--bos_text=<|endoftext|>.' + raise ValueError(tok_error_msg) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) + return dataset + + +def convert_dataset_json( + path: str, + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + split: str, + tokenizer: Optional[str] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + num_workers: Optional[int] = None, +) -> None: + """Create C4/pile streaming dataset. + + Args: + path (str): Path to the input data file + out_root (str): Output root directory + compression (Optional[str]): Compression type, if any + concat_tokens (Optional[int]): Convert text to tokens and concatenate up to this many tokens + split (str): Dataset split to process + tokenizer (Optional[str]): Tokenizer name + bos_text (str): Text to insert at the beginning of each sequence + eos_text (str): Text to insert at the end of each sequence + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers for data loading + """ + if concat_tokens is not None: + mode = ConcatMode.CONCAT_TOKENS + built_tokenizer = AutoTokenizer.from_pretrained(tokenizer) + # we will enforce length, so suppress warnings about sequences too long for the model + built_tokenizer.model_max_length = int(1e30) + columns = {'tokens': 'ndarray:int32'} + else: + mode = ConcatMode.NO_CONCAT + built_tokenizer = None + columns = {'text': 'str'} + + # Get samples + dataset = build_hf_dataset( + path=path, + split=split, + mode=mode, + max_length=concat_tokens, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + tokenizer=built_tokenizer, + ) + + print('here') + + # Write samples + print(f'Converting to MDS format...') + print( + f'Note that the progress bar is based on the dataset length before tokenization.', + ) + print(f'It will finish at a value below 100% if tokenizing') + with MDSWriter( + columns=columns, + out=os.path.join(out_root), + compression=compression, + ) as out: + for sample in tqdm(dataset): + out.write(sample) + + +def convert_dataset_json_from_args( + path: str, + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + split: str, + tokenizer: Optional[str] = None, + bos_text: Optional[str] = None, + eos_text: Optional[str] = None, + no_wrap: bool = False, + num_workers: Optional[int] = None, +) -> None: + """A wrapper for `convert_dataset_json` that parses arguments. + + Args: + path (str): Path to the input data file + out_root (str): Output root directory + compression (Optional[str]): Compression type, if any + concat_tokens (Optional[int]): Convert text to tokens and concatenate up to this many tokens + split (str): Dataset split to process + tokenizer (Optional[str]): Tokenizer name + bos_text (Optional[str]): Text to insert at the beginning of each sequence + eos_text (Optional[str]): Text to insert at the end of each sequence + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers for data loading + + Raises: + ValueError: If the out_root directory exists and contains files that overlap with the requested splits + ValueError: If concat_tokens is set and a tokenizer is not provided + """ + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(split)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {split}.', + ) + + # Make sure we have needed concat options + if ( + concat_tokens is not None and isinstance(concat_tokens, int) and + tokenizer is None + ): + ValueError( + 'When setting --concat_tokens, you must specify a --tokenizer', + ) + + # now that we have validated them, change BOS/EOS to strings + if bos_text is None: + bos_text = '' + if eos_text is None: + eos_text = '' + + convert_dataset_json( + path=path, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + split=split, + tokenizer=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py new file mode 100644 index 0000000000..635efd54d4 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -0,0 +1,762 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import re +import time +import urllib.parse +from collections import namedtuple +from concurrent.futures import ProcessPoolExecutor +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union +from uuid import uuid4 + +import google.protobuf.any_pb2 as any_pb2 +import pandas as pd +import pyarrow as pa +import requests +from composer.utils import retry +from packaging import version + +from llmfoundry.utils.exceptions import ( + ClusterDoesNotExistError, + FailedToConnectToDatabricksError, + FailedToCreateSQLConnectionError, +) + +if TYPE_CHECKING: + import pyspark.sql.connect.proto as pb2 + from databricks.sql.client import Connection as Connection + from databricks.sql.client import Cursor as Cursor + from pyspark.sql import SparkSession + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.dataframe import DataFrame as SparkDataFrame + from pyspark.sql.types import Row + +try: + from pyspark.sql.connect.client.core import SparkConnectClient + spark_connect_client_installed = True +except ImportError: + spark_connect_client_installed = False + +try: + from pyspark.sql.connect.dataframe import DataFrame + data_frame_installed = True +except ImportError: + data_frame_installed = False + +MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' +MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' + +TABLENAME_PATTERN = re.compile(r'(\S+)\.(\S+)\.(\S+)') + +log = logging.getLogger(__name__) + +Result = namedtuple( + 'Result', + [ + 'url', + 'row_count', + 'compressed_size', + 'uncompressed_size', + ], +) # pyright: ignore + +# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. +# It allows the client to fetch the results in different formats from the server. +# To be able to use the code make sure this module is not overriden by DB Connect classes. + + +def to_cf(self: 'SparkConnectClient', + plan: 'pb2.Plan', + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Executes the query plans and return as presigned URLS for cloud fetch. + + It can handle the current output formats that are supported by the server. + In contrast to the regular API methods of the client, this method does not + return the schema and drops all other responses. + + Args: + self (SparkConnectClient): The SparkConnectClient we are processing. + plan (pb2.Plan): The plan object to be executed by spark. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result has been truncated. + """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + import pyspark.sql.connect.proto as pb2 + import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 + + # Add the request options + if type == 'json': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == 'csv': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == 'arrow': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise ValueError( + f'Only formats json, csv, and arrow are supported. Got invalid type {type}', + ) + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=False, + ), + ) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), + ) + + # Create the iterator + from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + ) + iterator = ExecutePlanResponseReattachableIterator( + req, + self._stub, + self._retry_policy, + self._builder.metadata(), + ) + # Iterate over the response + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField('extension') and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR, + ): + batch = cloud_pb2.CloudResultBatch() + if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): + raise ValueError( + 'Response extension is not of type CloudResultBatch.', + ) + response.extension.Unpack(batch) + result += [ + Result( + b.url, + b.row_count, + b.compressed_size, + b.uncompressed_size, + ) for b in batch.results + ] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +if spark_connect_client_installed: + SparkConnectClient.to_cf = to_cf # pyright: ignore + + +def collect_as_cf(self: 'DataFrame', + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Collects DataFrame execution plan as presigned URLs. + + This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the + execution plan of the current DataFrame, converts it to a protocol buffer format, and then + uses the `to_cf` method to execute the plan and fetch results as presigned URLs. + + Args: + self (pd.DataFrame): The dataframe we are processing. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. + """ + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore + + +if data_frame_installed: + DataFrame.collect_cf = collect_as_cf # pyright: ignore + + +def iterative_combine_jsons(json_directory: str, output_file: str) -> None: + """Combine jsonl files in json_directory into one big jsonl file. + + This function does not work for nested subdirectories. + + Args: + json_directory(str): directory containing the JSONL files + output_file(str): path to the output combined JSONL file + """ + json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] + with open(output_file, 'w') as outfile: + for file_name in json_files: + with open(os.path.join(json_directory, file_name), 'r') as infile: + for line in infile: + outfile.write(line) + log.info('JSON files have been combined into a JSONL file.') + + +def run_query( + query: str, + method: str, + cursor: Optional['Cursor'] = None, + spark: Optional['SparkSession'] = None, + collect: bool = True, +) -> Optional[Union[List['Row'], 'DataFrame', 'SparkDataFrame']]: + """Run SQL query via databricks-connect or databricks-sql. + + Args: + query (str): sql query + method (str): select from dbsql and dbconnect + cursor (Optional[Cursor]): connection.cursor + spark (Optional[SparkSession]): spark session + collect (bool): whether to get the underlying data from spark dataframe + """ + if method == 'dbsql': + if cursor is None: + raise ValueError(f'cursor cannot be None if using method dbsql') + cursor.execute(query) + if collect: + return cursor.fetchall() + elif method == 'dbconnect': + if spark == None: + raise ValueError(f'sparkSession is required for dbconnect') + df = spark.sql(query) + if collect: + return df.collect() + return df + else: + raise ValueError(f'Unrecognized method: {method}') + + +def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: + for i, r in enumerate(signed): + yield (i, r.url, json_output_folder, columns) + + +def download( + ipart: int, + url: str, + json_output_folder: str, + columns: Optional[List] = None, + resp_format: str = 'arrow', + compressed: bool = False, +) -> None: + """Thread download presigned url and save to jsonl locally. + + Args: + ipart (int): presigned url id + url (str): presigned url + json_output_folder (str): directory to save the ipart_th segment of dataframe + columns (list): schema to save to json + resp_format (str): whether to use arrow or json when collect + compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. + """ + resp = requests.get(url) + if resp.status_code == 200: + if resp_format == 'json': + data = resp.json() + pd.DataFrame(data, columns=columns).to_json( + os.path.join( + json_output_folder, + 'part_' + str(ipart) + '.jsonl', + ), + orient='records', + lines=True, + ) + return + + # When resp_format is arrow: + if compressed: + # The data is lz4 compressed arrow format. + # Decompress the data + import lz4.frame + decompressed_data = lz4.frame.decompress(resp.content) + # Convert the decompressed data into a PyArrow table + reader = pa.ipc.open_stream(decompressed_data) + else: + reader = pa.ipc.open_stream(resp.content) + table = reader.read_all() + + # Convert the PyArrow table into a pandas DataFrame + df = table.to_pandas() + df.to_json( + os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True, + force_ascii=False, + ) + + +def download_starargs(args: Tuple) -> None: + return download(*args) + + +def format_tablename(table_name: str) -> str: + """Escape catalog, schema and table names with backticks. + + This needs to be done when running SQL queries/setting spark sessions to prevent invalid identifier errors. + + Args: + table_name (str): catalog.scheme.tablename on UC + """ + match = re.match(TABLENAME_PATTERN, table_name) + + if match is None: + return table_name + + formatted_identifiers = [] + for i in range(1, 4): + identifier = f'`{match.group(i)}`' + formatted_identifiers.append(identifier) + + return '.'.join(formatted_identifiers) + + +def fetch_data( + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], + start: int, + end: int, + order_by: str, + tablename: str, + columns_str: str, + json_output_folder: str, +) -> None: + """Fetches a specified range of rows from a given table to a json file. + + This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, + from a specified table and column set. The fetched data is then exported as a JSON file. + + Args: + method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. + cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. + sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. + start (int): The starting index for row fetching. + end (int): The ending index for row fetching. + order_by (str): The column name to use for ordering the rows. + tablename (str): The name of the table from which to fetch the data. + columns_str (str): The string representation of the columns to select from the table. + json_output_folder (str): The file path where the resulting JSON file will be saved. + + Returns: + None: The function doesn't return any value, but writes the result to a JSONL file. + """ + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT {columns_str} + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + + if method == 'dbconnect': + spark_df = run_query(query, method, cursor, sparkSession, collect=False) + if spark_df is None: + raise RuntimeError( + f'Expect spark dataframe with {query} but got None', + ) + pdf = spark_df.toPandas() # pyright: ignore + else: # method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + if ans is None: + raise RuntimeError(f'Got empty results with {query}') + records = [r.asDict() for r in ans] # pyright: ignore + pdf = pd.DataFrame.from_dict(records) + + pdf.to_json( + os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True, + ) + + +@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) +def get_total_rows( + tablename: str, + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], +): + ans = run_query( + f'SELECT COUNT(*) FROM {tablename}', + method, + cursor, + sparkSession, + ) + nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore + log.info(f'total_rows = {nrows}') + return nrows + + +@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) +def get_columns_info( + tablename: str, + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], +): + ans = run_query( + f'SHOW COLUMNS IN {tablename}', + method, + cursor, + sparkSession, + ) + columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore + order_by = columns[0] + columns_str = ','.join(columns) + log.info(f'order by column {order_by}') + return columns, order_by, columns_str + + +def fetch( + method: str, + tablename: str, + json_output_folder: str, + batch_size: int = 1 << 30, + processes: int = 1, + sparkSession: Optional['SparkSession'] = None, + dbsql: Optional['Connection'] = None, +) -> None: + """Fetch UC delta table with databricks-connect as JSONL. + + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + json_output_folder (str): path to write the result json file to + batch_size (int): number of rows that dbsql fetches each time to avoid OOM + processes (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session + """ + cursor = dbsql.cursor() if dbsql is not None else None + try: + nrows = get_total_rows( + tablename, + method, + cursor, + sparkSession, + ) + except Exception as e: + raise RuntimeError( + f'Error in get rows from {tablename}. Restart sparkSession and try again', + ) from e + + try: + columns, order_by, columns_str = get_columns_info( + tablename, + method, + cursor, + sparkSession, + ) + except Exception as e: + raise RuntimeError( + f'Error in get columns from {tablename}. Restart sparkSession and try again', + ) from e + + if method == 'dbconnect' and sparkSession is not None: + log.info(f'{processes=}') + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_folder, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. + sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data( + method, + cursor, + sparkSession, + start, + end, + order_by, + tablename, + columns_str, + json_output_folder, + ) + + if cursor is not None: + cursor.close() + + +def validate_and_get_cluster_info( + cluster_id: Optional[str], + databricks_host: str, + databricks_token: str, + http_path: Optional[str], + use_serverless: bool = False, +) -> tuple: + """Validate and get cluster info for running the Delta to JSONL conversion. + + Args: + cluster_id (str): cluster id to validate and fetch additional info for + databricks_host (str): databricks host name + databricks_token (str): databricks auth token + http_path (Optional[str]): http path to use for sql connect + use_serverless (bool): whether to use serverless or not + """ + method = 'dbsql' + dbsql = None + sparkSession = None + + if use_serverless: + method = 'dbconnect' + else: + if not cluster_id: + raise ValueError( + 'cluster_id is not set, however use_serverless is False', + ) + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + res = w.clusters.get(cluster_id=cluster_id) + if res is None: + raise ClusterDoesNotExistError(cluster_id) + + assert res.spark_version is not None + stripped_runtime = re.sub( + r'[a-zA-Z]', + '', + res.spark_version.split('-scala') + [0].replace( # type: ignore + 'x-snapshot', '', + ), + ) + runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) + if version.parse( + runtime_version, + ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', + ) + + if http_path is None and version.parse( + runtime_version, + ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): + method = 'dbconnect' + + if method == 'dbconnect': + from databricks.connect import DatabricksSession + try: + if use_serverless: + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host( + databricks_host, + ).token( + databricks_token, + ).header('x-databricks-session-id', session_id).getOrCreate() + + else: + if not cluster_id: + raise ValueError('cluster_id is needed for dbconnect.',) + sparkSession = DatabricksSession.builder.remote( + host=databricks_host, + token=databricks_token, + cluster_id=cluster_id, + ).getOrCreate() + + except Exception as e: + raise FailedToConnectToDatabricksError() from e + else: + try: + from databricks import sql + dbsql = sql.connect( + server_hostname=re.compile(r'^https?://').sub( + '', databricks_host).strip( + ), # sqlconnect hangs if hostname starts with https + http_path=http_path, + access_token=databricks_token, + ) + except Exception as e: + raise FailedToCreateSQLConnectionError() from e + return method, dbsql, sparkSession + + +def fetch_DT( + delta_table_name: str, + json_output_folder: str, + http_path: Optional[str], + cluster_id: Optional[str], + use_serverless: bool, + DATABRICKS_HOST: str, + DATABRICKS_TOKEN: str, + batch_size: int = 1 << 30, + processes: int = os.cpu_count(), # type: ignore + json_output_filename: str = 'train-00000-of-00001.jsonl', +) -> None: + """Fetch UC Delta Table to local as jsonl.""" + log.info(f'Start .... Convert delta to json') + + obj = urllib.parse.urlparse(json_output_folder) + if obj.scheme != '': + raise ValueError( + 'Check the json_output_folder and verify it is a local path!', + ) + + if os.path.exists(json_output_folder): + if not os.path.isdir(json_output_folder) or os.listdir( + json_output_folder, + ): + raise RuntimeError( + f'Output folder {json_output_folder} already exists and is not empty. Please remove it and retry.', + ) + + os.makedirs(json_output_folder, exist_ok=True) + + if not json_output_filename.endswith('.jsonl'): + raise ValueError('json_output_filename needs to be a jsonl file') + + log.info(f'Directory {json_output_folder} created.') + + # validate_and_get_cluster_info allows cluster_id to be None if use_serverless is True + method, dbsql, sparkSession = validate_and_get_cluster_info( + cluster_id=cluster_id, + databricks_host=DATABRICKS_HOST, + databricks_token=DATABRICKS_TOKEN, + http_path=http_path, + use_serverless=use_serverless, + ) + + formatted_delta_table_name = format_tablename(delta_table_name) + + fetch( + method, + formatted_delta_table_name, + json_output_folder, + batch_size, + processes, + sparkSession, + dbsql, + ) + + if dbsql is not None: + dbsql.close() + + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + json_output_folder, + os.path.join(json_output_folder, json_output_filename), + ) + + +def _check_imports(): + try: + import lz4.frame + _ = lz4.frame + except ImportError as e: + raise ImportError('lz4 is not installed.') from e + + try: + from databricks.connect import DatabricksSession + _ = DatabricksSession + except ImportError as e: + raise ImportError( + 'databricks-connect is not installed or improperly configured.', + ) from e + + try: + from databricks import sql + from databricks.sdk import WorkspaceClient + from databricks.sql.client import Connection as Connection + from databricks.sql.client import Cursor as Cursor + _ = WorkspaceClient, Connection, Cursor, sql + except ImportError as e: + raise ImportError( + 'databricks-sdk is not installed or improperly configured.', + ) from e + + try: + import pyspark.sql.connect.proto as pb2 + import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 + from pyspark.sql import SparkSession + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + ) + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.dataframe import DataFrame as SparkDataFrame + from pyspark.sql.types import Row + _ = ( + pb2, + cloud_pb2, + SparkSession, + SparkConnectClient, + ExecutePlanResponseReattachableIterator, + DataFrame, + SparkDataFrame, + Row, + ) + except ImportError as e: + raise ImportError( + 'pyspark is not installed or improperly configured.', + ) from e + + +def convert_delta_to_json_from_args( + delta_table_name: str, + json_output_folder: str, + http_path: Optional[str], + cluster_id: Optional[str], + use_serverless: bool, + batch_size: int, + processes: int, + json_output_filename: str, +) -> None: + """A wrapper for `convert_dataset_json` that parses arguments. + + Args: + delta_table_name (str): UC table ..
+ json_output_folder (str): Local path to save the converted json + http_path (Optional[str]): If set, dbsql method is used + batch_size (int): Row chunks to transmit a time to avoid OOM + processes (int): Number of processes allowed to use + cluster_id (Optional[str]): Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect. + use_serverless (bool): Use serverless or not. Make sure the workspace is entitled with serverless + json_output_filename (str): The name of the combined final jsonl that combines all partitioned jsonl + """ + _check_imports() + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + DATABRICKS_HOST = w.config.host + DATABRICKS_TOKEN = w.config.token + + tik = time.time() + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + batch_size=batch_size, + processes=processes, + cluster_id=cluster_id, + use_serverless=use_serverless, + json_output_filename=json_output_filename, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + ) + log.info(f'Elapsed time {time.time() - tik}') diff --git a/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py new file mode 100644 index 0000000000..94cd79815b --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py @@ -0,0 +1,346 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import platform +import warnings +from typing import Any, Callable, Dict, Iterable, Optional, Union + +import datasets as hf_datasets +import psutil +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict +from streaming import MDSWriter +from torch.utils.data import DataLoader +from tqdm import tqdm + +from llmfoundry.data.finetuning.collator import validate_target_settings +from llmfoundry.data.finetuning.tasks import ( + _get_example_type, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example, +) +from llmfoundry.utils.builders import build_tokenizer + +HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] + + +def build_dataloader( + dataset: HFDataset, + batch_size: int, + num_workers: Optional[int] = None, +) -> DataLoader: + if num_workers is None: + # Multiple workers is only supported on linux machines + if 'linux' in platform.platform().lower(): + num_workers = max(1, psutil.cpu_count()) + else: + num_workers = 0 + + # If using multiple workers, configure each worker to prefetch as many samples as it can, up to + # the aggregate device batch size + # If not using workers, the torch DataLoader expects the default value for prefetch_factor, + # which non-intuitively must be 2. + # If on macOS, PyTorch requires prefetch_factor set to None since num_workers is always zero + if 'macos' in platform.platform().lower() and num_workers == 0: + prefetch_factor = None + else: + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 + + return DataLoader( + dataset=dataset, + sampler=None, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + +def generate_samples( + loader: DataLoader, + truncate_num_samples: Optional[int] = None, +) -> Iterable[Dict[str, bytes]]: + """Generator over samples of a dataloader. + + Args: + loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} + truncate_num_samples (Optional[int]): An optional # of samples to stop at. + + Yields: + Sample dicts. + """ + n_samples = 0 + for batch in loader: + keys = list(batch.keys()) + current_bs = len(batch[keys[0]]) + for idx in range(current_bs): + if truncate_num_samples is not None and n_samples == truncate_num_samples: + return + n_samples += 1 + yield {k: v[idx] for k, v in batch.items()} + + +def get_columns_and_format( + dataset: HFDataset, + tokenizing: bool, + preprocessing_fn: Callable, +): + ex = preprocessing_fn(next(iter(dataset))) + example_type = _get_example_type(ex) + if tokenizing: + return {'turns': 'json'}, example_type + if example_type == 'chat': + # Chat format + return {'messages': 'json'}, example_type + else: + # Prompt-response format + return {'prompt': 'str', 'response': 'str'}, example_type + + +def convert_finetuning_dataset( + dataset: str, + data_subset: Optional[str], + splits: list[str], + preprocessor: Optional[str], + data_files: list[str], + skip_preprocessing: bool, + out_root: str, + local: Optional[str], + compression: Optional[str], + num_workers: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: dict[str, Any], + max_seq_len: int, + target_prompts: str, + target_responses: str, + encoder_decoder: bool, +) -> None: + """Converts Finetuning datasets to MDS format. + + Args: + dataset (str): Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`). + data_subset (Optional[str]): Subset of data to use. + splits (list[str]): Comma-separated list of dataset splits + preprocessor (Optional[str]): Name or import path of function used to preprocess (reformat) the dataset. + data_files (list[str]): Data file for each split. Comma-separated. + skip_preprocessing (bool): Whether to skip preprocessing. + out_root (str): Root path of output directory where MDS shards will be stored. Can be a remote URI. + local (Optional[str]): Root path of local directory if you want to keep a local copy when out_root is remote. + compression (Optional[str]): Name of compression algorithm to use. + num_workers (Optional[int]): Number of workers. + tokenizer (Optional[str]): Tokenizer used for processing. + tokenizer_kwargs (dict[str, Any]): Keyword arguments for tokenizer initialization. + max_seq_len (int): Maximum sequence length. + target_prompts (str): Policy for when to use prompts as training targets. + target_responses (str): Policy for which responses to treat as training targets. + encoder_decoder (bool): Set if the data are intended to be used to train an encoder-decoder model + + Raises: + ValueError: If the target settings are invalid. + """ + if skip_preprocessing: + preprocessing_fn = lambda x: x # Just an identity function + else: + preprocessor_str = preprocessor + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( + preprocessor=preprocessor_str, + dataset_name=dataset, + ) + if preprocessing_fn is None: + raise ValueError( + '`preprocessor` was not set and no preprocessing function ' +\ + 'has been registered for `dataset`. If this was intentional ' +\ + '(e.g., because your dataset is already correctly formatted), ' +\ + 'include the "--skip-preprocessing" flag to avoid this error.', + ) + + # Make sure the target settings are valid + validate_target_settings( + target_prompts=target_prompts, + target_responses=target_responses, + decoder_only_format=not encoder_decoder, + ) + + tokenizer = None + tokenizer_kwargs = tokenizer_kwargs + tokenizer_kwargs.update({'model_max_length': max_seq_len}) + if tokenizer: + tokenizer = build_tokenizer(tokenizer, tokenizer_kwargs) + + for i, split_name in enumerate(splits): + data_file = None + if len(data_files) > 0: + data_file = data_files[i] + loaded_dataset = hf_datasets.load_dataset( + path=dataset, + name=data_subset, + split=split_name, + data_files=data_file, + streaming=True, + ) + # Determine the output columns + columns, example_type = get_columns_and_format( + dataset=loaded_dataset, + tokenizing=tokenizer is not None, + preprocessing_fn=preprocessing_fn, + ) + # Prepare the iterables + if example_type == 'chat': + samples = iter(loaded_dataset) + else: + loader = build_dataloader( + dataset=loaded_dataset, + batch_size=512, + num_workers=num_workers, + ) + samples = generate_samples(loader) + + # Write samples + print(f'Converting {split_name} to MDS format...') + out = os.path.join(out_root, split_name) + if local is not None: + out = (os.path.join(local, split_name), out) + keep_local = True + else: + keep_local = False + with MDSWriter( + columns=columns, + out=out, + compression=compression, + keep_local=keep_local, + ) as out: + examples_removed = 0 + for sample in tqdm(samples, desc=split_name): + formatted_sample = preprocessing_fn(sample) + assert isinstance(formatted_sample, dict) + + # Use the _get_example_type utility to confirm that the formatted sample + # can be interpreted by the tokenization code + try: + example_type = _get_example_type(formatted_sample) + except Exception as e: + raise ValueError( + 'Encountered an error when checking example for proper formatting. ' +\ + f'example={formatted_sample}', + ) from e + if tokenizer is not None: + sample = tokenize_formatted_example( + formatted_sample, + tokenizer=tokenizer, + ) + if not is_valid_ift_example( + max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + decoder_only_format=not encoder_decoder, + example=sample, + ): + examples_removed += 1 + continue + + sample_to_write = {'turns': []} + for turn in sample['turns']: + turn_to_write = {} + for key in ['input_ids', 'labels']: + turn_to_write[key] = list(turn[key]) + sample_to_write['turns'].append(turn_to_write) + out.write(sample_to_write) + else: + if example_type == 'prompt_response': + encoded_sample = {} + for key in ['prompt', 'response']: + value = formatted_sample[key] + assert isinstance(value, str) + encoded_sample[key] = value.encode('utf-8') + out.write(encoded_sample) + else: + out.write(formatted_sample) + + if tokenizer is not None and examples_removed > 0: + warnings.warn( + f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + + + 'the prompt or response was empty, or the response was all padding tokens.', + ) + + +def convert_finetuning_dataset_from_args( + dataset: str, + data_subset: Optional[str], + splits: list[str], + preprocessor: Optional[str], + data_files: list[str], + skip_preprocessing: bool, + out_root: str, + local: Optional[str], + compression: Optional[str], + num_workers: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: Optional[str], + max_seq_len: int, + target_prompts: str, + target_responses: str, + encoder_decoder: bool, +): + """A wrapper for `convert_finetuning_dataset` to parse arguments. + + Args: + dataset (str): Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`). + data_subset (Optional[str]): Subset of data to use. + splits (list[str]): Comma-separated list of dataset splits + preprocessor (Optional[str]): Name or import path of function used to preprocess (reformat) the dataset. + data_files (list[str]): Data file for each split. Comma-separated. + skip_preprocessing (bool): Whether to skip preprocessing. + out_root (str): Root path of output directory where MDS shards will be stored. Can be a remote URI. + local (Optional[str]): Root path of local directory if you want to keep a local copy when out_root is remote. + compression (Optional[str]): Name of compression algorithm to use. + num_workers (Optional[int]): Number of workers. + tokenizer (Optional[str]): Tokenizer used for processing. + tokenizer_kwargs (Optional[str]): Keyword arguments for tokenizer initialization in JSON format. + max_seq_len (int): Maximum sequence length. + target_prompts (str): Policy for when to use prompts as training targets. + target_responses (str): Policy for which responses to treat as training targets. + encoder_decoder (bool): Set if the data are intended to be used to train an encoder-decoder model. + + Raises: + ValueError: If the target settings are invalid. + ValueError: If the output directory already contains the requested splits. + """ + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(splits)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {splits}.', + ) + + if tokenizer_kwargs is not None: + parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs) + else: + parsed_tokenizer_kwargs = {} + + if len(data_files) > 0 and len(data_files,) != len(splits): + raise ValueError( + f'If data_files is set, data_files and splits must have the same length. Got {len(data_files)=} while {len(splits)=}', + ) + convert_finetuning_dataset( + dataset=dataset, + data_subset=data_subset, + splits=splits, + preprocessor=preprocessor, + data_files=data_files, + skip_preprocessing=skip_preprocessing, + out_root=out_root, + local=local, + compression=compression, + num_workers=num_workers, + tokenizer=tokenizer, + tokenizer_kwargs=parsed_tokenizer_kwargs, + max_seq_len=max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + encoder_decoder=encoder_decoder, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py new file mode 100644 index 0000000000..94bdc16526 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -0,0 +1,596 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import math +import os +import tempfile +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from glob import glob +from typing import Dict, Iterable, List, Optional, Tuple, cast + +import numpy as np +from composer.utils import ( + ObjectStore, + maybe_create_object_store_from_uri, + parse_uri, +) +from numpy.typing import NDArray +from streaming import MDSWriter +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.data.data import AbstractConcatTokensDataset +from llmfoundry.utils.data_prep_utils import ( + DownloadingIterable, + download_file, + merge_shard_groups, +) +from llmfoundry.utils.exceptions import ( + DatasetTooSmallError, + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) + +log = logging.getLogger(__name__) + +DONE_FILENAME = '.text_to_mds_conversion_done' + + +class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): + """An IterableDataset that returns token samples for MDSWriter from files. + + Returns dicts of {'tokens': ndarray:int32} + + Each file is considered a sequence. + """ + + def __init__( + self, + files: Iterable[str], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + bos_text: str, + eos_text: str, + no_wrap: bool, + ): + self.files = files + super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) + log.info(f'Initialized ConcatTokensFromFilesDataset.') + + def __iter__(self) -> Iterable[Dict[str, NDArray]]: + log.info( + 'Starting iteration over files in ConcatTokensFromFilesDataset', + ) + buffer = [] + for file in self.files: + log.info(f'Processing file: {file}') + with open(file, 'r') as f: + buffer += self.bos_tokens + first_chunk = True + # Read the file in 1MB chunks to avoid memory issues + for chunk in iter(partial(f.read, 1000000), ''): + # Tokenize the chunk + encoded = self.tokenizer( + chunk, + truncation=False, + padding=False, + ) + iids = encoded['input_ids'] + + # If this is not the first chunk, remove the BOS token + if not first_chunk: + if iids[0] == self.tokenizer.bos_token_id: + iids = iids[1:] + + # Add the tokens to the buffer + buffer += iids + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self. + max_length:] if self.should_wrap else [] + yield { + 'tokens': np.asarray(concat_sample, dtype=np.int32), + } + + first_chunk = False + + # Add the EOS token to the buffer to separate files. + buffer += self.eos_tokens + + # Yield any remaining samples of size max_length. + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self.max_length:] if self.should_wrap else [] + yield {'tokens': np.asarray(concat_sample, dtype=np.int32)} + + log.info( + 'Finished iterating over files in ConcatTokensFromFilesDataset', + ) + + +def get_object_names(input_folder: str) -> List[str]: + """Get object names from a local or remote folder. + + Args: + input_folder (str): local or remote folder path. + """ + object_store = maybe_create_object_store_from_uri(input_folder) + if object_store is not None: + _, _, folder_prefix = parse_uri(input_folder) + names = [ + name for name in object_store.list_objects(folder_prefix) + if name.endswith('.txt') + ] + log.info(f'Found {len(names)} text files in remote storage') + else: + # input_folder is a local folder + names = [ + text_file for dirpath, _, _ in os.walk(input_folder) + for text_file in glob(os.path.join(dirpath, '*.txt')) + ] + # return names, sizes + log.info(f'Found {len(names)} text files at {input_folder}') + + return names + + +def get_task_args( + object_names: List[str], + output_root: str, + input_folder: str, + n_groups: int, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +) -> Iterable: + """Get download_and_convert arguments split across n_groups. + + Each group handles a portion of object_names. + + Args: + object_names (List[str]): Names of objects to process + output_root (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + n_groups (int): Number of groups to split the object names into + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info( + f'Preparing task arguments for {len(object_names)} objects across {n_groups} groups', + ) + num_objects = len(object_names) + objs_per_group = math.ceil(num_objects / n_groups) + for group, i in enumerate(range(0, num_objects, objs_per_group)): + output_subdir = os.path.join(output_root, str(group)) + log.info( + f'Created task for group {group} with {min(objs_per_group, num_objects - i)} objects', + ) + yield ( + object_names[i:min(i + objs_per_group, num_objects)], + output_subdir, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + +def download_and_convert_starargs(args: Tuple): + """Helper function to call download_and_convert with star args. + + This helps us use download_and_convert with multiprocessing. + """ + return download_and_convert(*args) + + +def download_and_convert( + file_names: List[str], + output_folder: str, + input_folder: str, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +): + """Downloads and converts text files to MDS format. + + Args: + file_names (List[str]): Files to process + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info(f'Starting download and conversion for {len(file_names)} files') + + object_store = maybe_create_object_store_from_uri(input_folder) + + # Download file_names + with tempfile.TemporaryDirectory() as tmp_dir: + log.info(f'Created temporary directory: {tmp_dir}') + downloading_iter = DownloadingIterable( + object_names=file_names, + output_folder=tmp_dir, + object_store=object_store, + ) + log.info(f'Initializing tokenizer: {tokenizer_name}') + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace + + # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up + # to the maximum sequence length + dataset = ConcatTokensFromFilesDataset( + files=downloading_iter, + max_length=concat_tokens, + tokenizer=tokenizer, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + ) + + columns = {'tokens': 'ndarray:int32'} + + log.info('Converting to MDS format...') + with MDSWriter( + out=output_folder, + columns=columns, + compression=compression, + ) as out: + for sample in tqdm(dataset): + out.write(sample) + + log.info(f'Completed download and conversion for {len(file_names)} files') + + +def is_remote_path(path: str) -> bool: + """Checks whether a path is a remote path. + + Args: + path (str): path to check + """ + backend, _, _ = parse_uri(path) + return backend != '' + + +def is_already_processed( + output_root: str, + args_str: str, + object_names: List[str], +) -> bool: + """Determines whether a group of text files has already been processed. + + Checks the done fie at output root to determine this. + + Args: + output_root (str): Output folder where a done file may exist + args_str (str): String representation of the arguments + object_names (List[str]): Names of objects to convert to MDS format + """ + log.info( + f'Checking if {len(object_names)} objects have already been processed in {output_root}', + ) + + # Retrieve the done file contents + output_object_store = maybe_create_object_store_from_uri(output_root) + if output_object_store is not None: + # Download and read the done file from the remote object store + _, _, output_folder_prefix = parse_uri(output_root) + try: + with tempfile.TemporaryDirectory() as tmp_dir: + done_file = os.path.join(tmp_dir, DONE_FILENAME) + download_file( + object_store=output_object_store, + object_name=os.path.join( + output_folder_prefix, + DONE_FILENAME, + ), + output_filename=done_file, + ) + with open(done_file) as df: + done_file_contents = df.read().splitlines() + log.info(f'Retrieved done file contents from remote storage') + except FileNotFoundError: + log.info('Done file not found in remote storage') + return False + else: + # Read the local done file + done_file = os.path.join(output_root, DONE_FILENAME) + if not os.path.isfile(done_file): + log.info('Done file not found in local storage') + return False + with open(done_file) as df: + done_file_contents = df.read().splitlines() + log.info(f'Retrieved done file contents from local storage') + + # Compare the arguments + prev_args_str = done_file_contents[0] + if prev_args_str != args_str: + log.info('Arguments have changed, reprocessing required') + return False + + # Compare file names + prev_names = done_file_contents[1:] + if len(prev_names) != len(object_names): + log.info('Number of files has changed, reprocessing required') + return False + for idx, prev_name in enumerate(prev_names): + if object_names[idx] != prev_name: + log.info('File names have changed, reprocessing required') + return False + + log.info('All files have already been processed') + return True + + +def write_done_file(folder: str, args_str: str, object_names: List[str]): + """Write a file to signify completion. + + This the done file includes the arguments to processing and + a list of objects that were processed. + + Args: + folder (str): Folder to write the done file to + args_str (str): String representation of arguments + object_names (List[str]): List of objects to convert to MDS format + """ + with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file: + log.info(f'Writing done file.') + done_file.write('\n'.join([args_str] + object_names) + '\n') + log.info(f'Done file written successfully') + + +def convert_text_to_mds( + tokenizer_name: str, + output_folder: str, + input_folder: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + processes: int, + args_str: str, + reprocess: bool, + trust_remote_code: bool, +): + """Convert a folder of text files to MDS format. + + Args: + tokenizer_name (str): Name of tokenizer to use + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + processes (int): The number of processes to use. + args_str (str): String representation of the arguments + reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + # Load the tokenizer once on the main process so that the files are cached to avoid race conditions + # in the Hugging Face load code + AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + + is_remote_output = is_remote_path(output_folder) + log.info(f'Output is remote: {is_remote_output}') + + object_names = get_object_names(input_folder) + if len(object_names) == 0: + log.error(f'No text files found in input folder: {input_folder}') + raise InputFolderMissingDataError(input_folder) + + # Check if the text files in the bucket have already been processed. + if not reprocess and is_already_processed( + output_folder, + args_str, + object_names, + ): + log.info( + f'Input folder {input_folder} is already processed at {output_folder} and ' + + + 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', + ) + return + + # Use a temporary local directory if the output is remote and there are more than 1 processes + local_output_folder = tempfile.TemporaryDirectory( + ).name if is_remote_output else output_folder + log.info(f'Using local output folder: {local_output_folder}') + + if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: + log.error(f'Output folder is not empty: {output_folder}') + raise OutputFolderNotEmptyError(output_folder) + + if processes > 1: + log.info(f'Using multiprocessing with {processes} processes') + # Download and convert the text files in parallel + args = get_task_args( + object_names, + local_output_folder, + input_folder, + processes, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_and_convert_starargs, args)) + + log.info('Merging MDS shards from each process') + # Merge the mds shards from each of the processes into a single folder + merge_shard_groups(local_output_folder) + else: + log.info('Using single process for download and conversion') + download_and_convert( + object_names, + local_output_folder, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + index_path = os.path.join(local_output_folder, 'index.json') + with open(index_path, 'r') as index_file: + if not json.load(index_file)['shards']: + raise DatasetTooSmallError() + + # Write a done file with the args and object names + write_done_file(local_output_folder, args_str, object_names) + + if is_remote_output: + # Upload the local output to the remote location + output_object_store = cast( + ObjectStore, + maybe_create_object_store_from_uri(output_folder), + ) + _, _, output_folder_prefix = parse_uri(output_folder) + files_to_upload = os.listdir(local_output_folder) + + for file in files_to_upload: + assert not os.path.isdir(file) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object( + remote_path, + os.path.join(local_output_folder, file), + ) + + +def _configure_logging(logging_level: str): + """Configure logging. + + Args: + logging_level (str): Logging level. + """ + logging.basicConfig( + format= + f'%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging_level = logging_level.upper() + logging.getLogger('llmfoundry').setLevel(logging_level) + logging.getLogger(__name__).setLevel(logging_level) + log.info(f'Logging level set to {logging_level}') + + +def convert_text_to_mds_from_args( + output_folder: str, + input_folder: str, + compression: str, + concat_tokens: int, + tokenizer_name: str, + bos_text: Optional[str], + eos_text: Optional[str], + use_tokenizer_eos: bool, + no_wrap: bool, + processes: int, + reprocess: bool, + trust_remote_code: bool, + logging_level: str, +) -> None: + """A wrapper for `convert_text_to_mds` to parse arguments. + + Args: + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + compression (str): The compression algorithm to use for MDS writing + concat_tokens (int): Concatenate up to this many tokens + tokenizer_name (str): The name of the tokenizer to use + bos_text (Optional[str]): The text to prepend to each example to separate concatenated examples + eos_text (Optional[str]): The text to append to each example to separate concatenated examples + use_tokenizer_eos (bool): Use the EOS text from the tokenizer + no_wrap (bool): Whether to let text examples wrap across multiple training examples + processes (int): The number of processes to use to download and convert the dataset + reprocess (bool): If true, reprocess the input_folder to MDS format. Otherwise, only reprocess upon changes to the input folder or dataset creation parameters. + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + logging_level (str): Logging level for the script. Default is INFO. + + Raises: + ValueError: If `use_tokenizer_eos` is True and `eos_text` is not None + """ + if use_tokenizer_eos: + # Ensure that eos text is not specified twice. + if eos_text is not None: + ValueError( + 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.', + ) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + eos_text = tokenizer.eos_token + + # now that we have validated them, change BOS/EOS to strings + if bos_text is None: + bos_text = '' + if eos_text is None: + eos_text = '' + _configure_logging(logging_level) + + # Define args for _args_str + args = { + 'tokenizer': tokenizer_name, + 'output_folder': output_folder, + 'input_folder': input_folder, + 'compression': compression, + 'concat_tokens': concat_tokens, + 'eos_text': eos_text, + 'bos_text': bos_text, + 'no_wrap': no_wrap, + 'processes': processes, + 'reprocess': reprocess, + 'trust_remote_code': trust_remote_code, + } + convert_text_to_mds( + tokenizer_name=tokenizer_name, + output_folder=output_folder, + input_folder=input_folder, + concat_tokens=concat_tokens, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + compression=compression, + processes=processes, + reprocess=reprocess, + trust_remote_code=trust_remote_code, + args_str=str(args), + ) diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py index 7d8306c0a0..bddd592dba 100644 --- a/llmfoundry/command_utils/eval.py +++ b/llmfoundry/command_utils/eval.py @@ -175,6 +175,54 @@ def evaluate_model( return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) +def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]: + """Transform the config to allow top-level keys for model configuration. + + This function allows users to use the 'train.py' syntax in 'eval.py'. + It converts a config with top-level 'model', 'tokenizer', and (optionally) 'load_path' keys + into the nested 'models' list format required by 'eval.py'. + + Input config format (train.py style): + ```yaml + model: + + load_path: /path/to/checkpoint + tokenizer: + + ``` + + Output config format (eval.py style): + ```yaml + models: + - model: + + tokenizer: + + load_path: /path/to/checkpoint + ``` + """ + if 'model' in cfg: + if 'models' in cfg: + raise ValueError( + 'Please specify either model or models in the config, not both', + ) + default_name = cfg.get('model').get('name') # type: ignore + model_cfg = { + 'model': cfg.pop('model'), + 'tokenizer': cfg.pop('tokenizer', None), + 'model_name': cfg.pop('model_name', default_name), + } + if 'tokenizer' not in model_cfg or model_cfg['tokenizer'] is None: + raise ValueError( + 'When specifying model, "tokenizer" must be provided in the config', + ) + if 'load_path' in cfg: + model_cfg['load_path'] = cfg.pop('load_path') + cfg['models'] = [model_cfg] + + return cfg + + def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]: # Run user provided code if specified for code_path in cfg.get('code_paths', []): @@ -184,6 +232,7 @@ def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]: cfg, EvalConfig, EVAL_CONFIG_KEYS, + transforms=[allow_toplevel_keys], icl_tasks_required=True, ) diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index f49fb28801..c925e6e586 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -36,8 +36,10 @@ build_callback, build_composer_model, build_evaluators, + build_load_planner, build_logger, build_optimizer, + build_save_planner, build_scheduler, build_tokenizer, ) @@ -256,6 +258,31 @@ def train(cfg: DictConfig) -> Trainer: # Optional fsdp data, fine-tuning, and eval configs fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config + if fsdp_config is not None: + if 'load_planner' in fsdp_config: + load_planners = list(fsdp_config['load_planner'].items()) + if len(load_planners) > 1: + raise ValueError( + 'Only one load planner can be specified in the config.', + ) + load_planner_name, load_planner_config = load_planners[0] + fsdp_config['load_planner'] = build_load_planner( + load_planner_name, + **load_planner_config, + ) + + if 'save_planner' in fsdp_config: + save_planners = list(fsdp_config['save_planner'].items()) + if len(save_planners) > 1: + raise ValueError( + 'Only one save planner can be specified in the config.', + ) + save_planner_name, save_planner_config = save_planners[0] + fsdp_config['save_planner'] = build_save_planner( + save_planner_name, + **save_planner_config, + ) + eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str @@ -517,6 +544,7 @@ def train(cfg: DictConfig) -> Trainer: dist_timeout=train_cfg.dist_timeout, profiler=profiler, compile_config=compile_config, + spin_dataloaders=train_cfg.spin_dataloaders, ) # Optionally just save an HF checkpoint diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 11104ac706..771033a703 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -1,5 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import inspect import logging import os from typing import Any, Dict, Optional, Tuple, Union @@ -17,6 +18,8 @@ validate_target_settings, ) from llmfoundry.data.finetuning.tasks import ( + DEFAULT_TARGET_PROMPTS, + DEFAULT_TARGET_RESPONSES, DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, dataset_constructor, @@ -39,9 +42,15 @@ # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 -# Default settings to use for target responses and target prompts -_DEFAULT_TARGET_RESPONSES = 'last' -_DEFAULT_TARGET_PROMPTS = 'none' +# Extra keys present in the dataset config dictionary beyond the constructor keys +_ALLOWED_DATASET_KEYS = { + 'shuffle', + 'packing_ratio', + 'allow_pad_trimming', + 'seq_parallel_replication', + 'auto_packing_replication', + 'max_leftover_bins_to_keep', +} def build_finetuning_dataloader( @@ -64,9 +73,12 @@ def build_finetuning_dataloader( on which you intend to use, as explained below. Args: - name (str): The type of dataloader to build. Must = "finetuning". - --- - *** HuggingFace dataset config fields *** + tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to + prepare the data from raw text. Any missing sentinel tokens will + be added by the collator. + device_batch_size (int, float): The size of the batches (number of examples) + that the dataloader will produce. + dataset (Dict[str, Any]): A HuggingFace dataset config which contains the following fields: dataset.hf_name (str, optional): The name of the HuggingFace dataset to use. Can also be a remote http(s) directory or object store bucket containing the file {split}.jsonl in the format (prompt, response), @@ -130,16 +142,32 @@ def build_finetuning_dataloader( The script `scripts/misc/profile_packing.py` can help you choose the best packing_ratio. dataset.shuffle (bool): Whether to shuffle the dataset. - ___ See :class:`StreamingFinetuningDataset` for info on other standard config options within `dataset` that will be passed as kwargs if using the streaming codepath. - --- - tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to - prepare the data from raw text. Any missing sentinel tokens will - be added by the collator. - device_batch_size (int, float): The size of the batches (number of examples) - that the dataloader will produce. + num_workers (int, optional): How many subprocesses to use for data loading. + 0 means that the data will be loaded in the main process. The default is 0. + This argument is passed directly to the pytorch :class:`DataLoader`. + drop_last (bool, optional): If true, drop the last incomplete batch, if the dataset + size is not divisible by the batch size. If False and the size of dataset is + not divisible by the batch size, then the last batch will be smaller. The + default is False. This argument is passed directly to the pytorch :class:`DataLoader`. + pin_memory (bool, optional): If True, the data loader will copy Tensors into device/CUDA + pinned memory before returning them. If your data elements are a custom type, or your + `collate_fn` returns a batch that is a custom type. This argument is passed directly to + the pytorch :class:`DataLoader`. + prefetch_factor (int, optional): Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + (default value depends on the set value for num_workers. If value of num_workers=0 default + is None. Otherwise, if value of num_workers > 0 default is 2). This argument is passed + directly to the pytorch :class:`DataLoader`. + persistent_workers (bool, optional): If True, the data loader will not shut down the worker + processes after a dataset has been consumed once. This allows to maintain the workers + Dataset instances alive. The default is False. This argument is passed directly to the + pytorch :class:`DataLoader`. + timeout (int, optional): If positive, the timeout value for collecting a batch from workers. + Should always be non-negative. The default is 0. This argument is passed directly to the + pytorch :class:`DataLoader`. See :class:`DataLoader` for standard argument options to the pytorch dataloader, such as `drop_last`, `num_workers`, etc. @@ -152,7 +180,26 @@ def build_finetuning_dataloader( given a starting workload YAML. """ dataset_cfg = dataset - _validate_config(**dataset_cfg) + is_streaming = ( + dataset_cfg.get('remote') is not None or + dataset_cfg.get('streams') is not None + ) + if is_streaming: + dataset_constructor_keys = inspect.signature( + dataset_constructor.streaming_dataset_class, + ).parameters.keys() + else: + dataset_constructor_keys = inspect.signature( + dataset_constructor.build_from_hf, + ).parameters.keys() + + allowed_dataset_config_keys = set( + dataset_constructor_keys, + ).union(_ALLOWED_DATASET_KEYS) + _validate_config( + **dataset_cfg, + allowed_dataset_keys=allowed_dataset_config_keys, + ) # Use EOS as the pad token if none exists if tokenizer.pad_token is None: # type: ignore (sometimes it's none and that's ok) @@ -194,9 +241,7 @@ def build_finetuning_dataloader( streaming_dataset = None # for pyright sampler = None - if dataset_cfg.get( - 'remote', - ) is not None or dataset_cfg.get('streams') is not None: + if is_streaming: # Build streaming dataloader streams_cfg = dataset_cfg.get('streams', None) streams_cfg = to_dict_container( @@ -206,34 +251,20 @@ def build_finetuning_dataloader( streams_cfg, ) if streams_cfg is not None else None - # note: we don't need to use ** here because we're setting default values for almost all arguments + # Take the constructor args from above, minus args that have been created separately + dataset_constructor_args = { + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys and + k not in {'streams', 'packing_ratio'} + } streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, - local=dataset_cfg.get('local', None), - remote=dataset_cfg.get('remote', None), - split=dataset_cfg.get('split', None), - download_retry=dataset_cfg.get('download_retry', 2), - download_timeout=dataset_cfg.get('download_timeout', 60), - validate_hash=dataset_cfg.get('validate_hash', None), - keep_zip=dataset_cfg.get('keep_zip', False), - epoch_size=dataset_cfg.get('epoch_size', None), - predownload=dataset_cfg.get('predownload', None), - cache_limit=dataset_cfg.get('cache_limit', None), - partition_algo=dataset_cfg.get('partition_algo', 'relaxed'), - num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None), batch_size=dataloader_batch_size, - shuffle=dataset_cfg.get('shuffle', False), - shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'), - shuffle_seed=dataset_cfg.get('shuffle_seed', 9176), - shuffle_block_size=dataset_cfg.get('shuffle_block_size', None), - sampling_method=dataset_cfg.get('sampling_method', 'balanced'), - sampling_granularity=dataset_cfg.get('sampling_granularity', 1), - batching_method=dataset_cfg.get('batching_method', 'random'), - max_seq_len=dataset_cfg['max_seq_len'], - allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False), replication=replication_factor, packing_ratio=dataloader_batch_size / dataset_batch_size, + **dataset_constructor_args, ) else: @@ -264,24 +295,19 @@ def build_finetuning_dataloader( dataset_name_or_path, ) - # Build dataset from HF. + # Take the constructor args from above, minus args that have been created separately + dataset_constructor_args = { + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys and + k not in {'split', 'preprocessing_fn'} + } streaming_dataset = dataset_constructor.build_from_hf( dataset_name=dataset_name_or_path, split=split, - safe_load=dataset_cfg.get('safe_load', False), - max_seq_len=dataset_cfg['max_seq_len'], preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - target_prompts=dataset_cfg.get( - 'target_prompts', - _DEFAULT_TARGET_PROMPTS, - ), - target_responses=dataset_cfg.get( - 'target_responses', - _DEFAULT_TARGET_RESPONSES, - ), - decoder_only_format=dataset_cfg['decoder_only_format'], - hf_kwargs=dataset_cfg.get('hf_kwargs', {}), + **dataset_constructor_args, ) # Ensure dataset is large enough. @@ -348,6 +374,7 @@ def _validate_config( streams: Optional[Dict[str, Any]] = None, target_prompts: Optional[str] = None, target_responses: Optional[str] = None, + allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS, **kwargs: Dict[str, Any], ) -> None: """Validates the dataset configuration. @@ -357,46 +384,59 @@ def _validate_config( the other. Args: - dataset_cfg (DictConfig): The dataset configuration to be validated. + max_seq_len (int): The maximum length of sequences + in the batch. See :class:`Seq2SeqFinetuningCollator` docstring + for details. + decoder_only_format (bool): Whether to format the + examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` + docstring for details. + hf_name (str, optional): The name of the HuggingFace dataset + to use. Can also be a remote http(s) directory or object store bucket + containing the file {split}.jsonl in the format (prompt, response), + in which case the builder will create a HuggingFace dataset. + local (str, optional): Local path where remote data + will be streamed to. Only valid if `cfg.dataset.remote` has + also been set. + remote (str, optional): Location of a MDS-formatted + streaming dataset to use. Setting this will tell the builder + to create a streaming dataset rather than a HuggingFace dataset. + hf_kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. + preprocessing_fn (str, optional): The name/import path of + the preprocessing function to use for formatting the data examples. + If ``None`` (default), the builder will use the preprocessing function + registered under `hf_name` (see `tasks.py`), if one exists, + otherwise it will skip preprocessing. + If `preprocessing_fn` corresponds to a registered preprocessing + function in `tasks.py`, the builder will use that. + Otherwise, it will interpret `preprocessing_fn` as a + "import.path:function_name" import path; e.g., it will call + `from import.path import function_name` and use the imported + function as the preprocessing function. + safe_load (bool, optional): Whether to enforce safe loading of the dataset. + If `None`, will default to not applying any safe loading. + streams (Dict[str, Any], optional): A dictionary with multiple data streams. + If `None`, will assume no streams. + target_prompts (str): Which prompts are used as training targets. + Defaults to "none", meaning prompts are never used as training targets. + See :class:`Seq2SeqFinetuningCollator` docstring for details. + target_responses (str): Which responses are used as training targets. + Defaults to "last", meaning only the final response in multi-turn examples + will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for + details. + allowed_dataset_keys (set[str], optional): The set of allowed keys for the dataset config. + kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. Raises: ValueError: If the dataset configuration does not meet the requirements. """ - # Check for extraneous keys in the dataset config - allowed_additional_kwargs = { - 'local', - 'remote', - 'split', - 'download_retry', - 'download_timeout', - 'validate_hash', - 'keep_zip', - 'epoch_size', - 'predownload', - 'cache_limit', - 'partition_algo', - 'num_canonical_nodes', - 'batch_size', - 'shuffle', - 'shuffle_algo', - 'shuffle_seed', - 'shuffle_block_size', - 'sampling_method', - 'sampling_granularity', - 'batching_method', - 'max_seq_len', - 'allow_unsafe_types', - 'replication', - 'packing_ratio', - 'allow_pad_trimming', - 'seq_parallel_replication', - 'auto_packing_replication', - 'max_leftover_bins_to_keep', - } - if not set(kwargs.keys()).issubset(allowed_additional_kwargs): + if not set(kwargs.keys()).issubset(allowed_dataset_keys): raise ValueError( 'The dataset config contains the following extraneous keys: ' +\ - ', '.join(set(kwargs.keys()) - allowed_additional_kwargs), + ', '.join(set(kwargs.keys()) - allowed_dataset_keys), ) if hf_name is not None: @@ -480,9 +520,9 @@ def _validate_config( # Raise an error if the target_prompts + target_responses + decoder_only_format settings # are invalid if target_prompts is None: - target_prompts = _DEFAULT_TARGET_PROMPTS + target_prompts = DEFAULT_TARGET_PROMPTS if target_responses is None: - target_responses = _DEFAULT_TARGET_RESPONSES + target_responses = DEFAULT_TARGET_RESPONSES target_prompts, target_responses = target_prompts.lower( ), target_responses.lower() validate_target_settings( @@ -504,7 +544,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: completed, the function removes the signal file. Args: - hf_name (str): The path of the HuggingFace dataset to download. + remote_path (str): The path of the HuggingFace dataset to download. split (str): The dataset split to download (e.g., 'train', 'validation', 'test'). Returns: @@ -584,9 +624,9 @@ def build_collate_fn( dataset_cfg = dataloader_cfg['dataset'] target_responses = dataset_cfg.get( 'target_responses', - _DEFAULT_TARGET_RESPONSES, + DEFAULT_TARGET_RESPONSES, ) - target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS) + target_prompts = dataset_cfg.get('target_prompts', DEFAULT_TARGET_PROMPTS) max_seq_len = dataset_cfg['max_seq_len'] decoder_only_format = dataset_cfg['decoder_only_format'] allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 1b3d257b42..e8175b4446 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -47,6 +47,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -117,6 +118,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ) SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet'] HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata'] +DEFAULT_TARGET_RESPONSES = 'last' +DEFAULT_TARGET_PROMPTS = 'none' PromptResponseDict = Mapping[str, str] ChatFormattedDict = Mapping[str, List[Dict[str, str]]] @@ -163,7 +166,7 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: Args: dirpath (str): Directory path to check. - Returns + Returns: True if directory is empty or non-existent. False otherwise. """ return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 @@ -800,14 +803,14 @@ def build_from_hf( self, dataset_name: str, split: str, - safe_load: bool, - max_seq_len: int, - preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]], - tokenizer: PreTrainedTokenizerBase, - target_prompts: str, - target_responses: str, - decoder_only_format: bool, - hf_kwargs: Dict[str, Any], + safe_load: bool = False, + max_seq_len: int = 2048, + preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + target_prompts: str = DEFAULT_TARGET_PROMPTS, + target_responses: str = DEFAULT_TARGET_RESPONSES, + decoder_only_format: bool = True, + hf_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: """Load a HuggingFace Datasets, preprocess, and tokenize. @@ -815,13 +818,45 @@ def build_from_hf( Note: This function will drop examples where the prompt is longer than the max_seq_len Args: - cfg (DictConfig): The dataset configuration. - max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped. - tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset. + dataset_name (str): The name of the HuggingFace dataset + to use. Can also be a remote http(s) directory or object store bucket + containing the file {split}.jsonl in the format (prompt, response), + in which case the builder will create a HuggingFace dataset. + split (str): The split of the HuggingFace dataset. + safe_load (bool, optional): Whether to enforce safe loading of the dataset. + If `None`, will default to not applying any safe loading. + max_seq_len (int): The maximum length of sequences + in the batch. See :class:`Seq2SeqFinetuningCollator` docstring + for details. + preprocessing_fn (Callable, optional): The preprocessing function to use for + formatting the data examples. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenizing + the HuggingFace dataset. + target_prompts (str): Which prompts are used as training targets. + Defaults to "none", meaning prompts are never used as training targets. + See :class:`Seq2SeqFinetuningCollator` docstring for details. + target_responses (str): Which responses are used as training targets. + Defaults to "last", meaning only the final response in multi-turn examples + will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for + details. + decoder_only_format (bool): Whether to format the + examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` + docstring for details. + hf_kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. Returns: Dataset: The tokenized dataset. """ + if hf_kwargs is None: + hf_kwargs = {} + + # None is checked in the function, because argument defaults were added after the function was written and we want + # to preserve the ordering of the arguments for backwards compatibility. + if tokenizer is None: + raise ValueError('A tokenizer must be provided.') + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. @@ -908,6 +943,8 @@ def dataset_mapper(example: Dict): detected_cpu_count = os.cpu_count() or 1 detected_cpus_with_margin = detected_cpu_count - 8 num_cpus_to_use = max(1, detected_cpus_with_margin) + if len(dataset) < num_cpus_to_use: + num_cpus_to_use = 1 columns_to_remove = list(dataset[0].keys()) tokenized_dataset = dataset.map( @@ -968,12 +1005,16 @@ def dataset_mapper(example: Dict): assert filtered_dataset is not None return filtered_dataset + @property + def streaming_dataset_class(self) -> Type[StreamingFinetuningDataset]: + return StreamingFinetuningDataset + def build_from_streaming( self, *args: Any, **kwargs: Any, ) -> StreamingFinetuningDataset: - return StreamingFinetuningDataset(*args, **kwargs) + return self.streaming_dataset_class(*args, **kwargs) dataset_constructor = DatasetConstructor() diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index a6fdf34953..5579066f89 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -337,7 +337,7 @@ def auto_packing_ratio( dataloader_cfg (DictConfig): The dataloader configuration for profiling. tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. device_batch_size (int): The size of the batches (number of examples) per device. - num_packing_ratio (int): The number of packing ratios to try. + num_packing_ratios (int): The number of packing ratios to try. Returns: A packing ratio that minimizes padding while maintaining zero waste. diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index 8a8b9de551..4e49be3fba 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -251,8 +251,9 @@ def read_dataset( """ from datasets import \ Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] - from datasets import \ - load_dataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import ( # pyright: ignore[reportGeneralTypeIssues] + load_dataset, + ) if 'hf://' in dataset_uri: dataset_uri = dataset_uri.replace('hf://', '') if hf_loading_vars is None: @@ -363,6 +364,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The answer in the example @@ -712,6 +714,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The answer in from the example with chain of thought and delimiter if needed @@ -731,7 +734,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + ctxt (str): The specific example's derived context example (Dict): The example as a dictionary. Returns: @@ -1035,6 +1038,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The full string of the correct answer based on the 'gold' key @@ -1053,7 +1057,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + ctxt (str): The specific example's derived context example (Dict): The example as a dictionary. Returns: @@ -1129,6 +1133,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: since the batch may consist of multiple questions, the choice_groupings indicates which contiguous sequences of elements in the batch correspond to which question gold_indices indicates which of the [0, N-1] choices is the correct one for each question. + Args: data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) @@ -1168,6 +1173,7 @@ def split_batch(self, batch: Any, and real example, which refers to one possible continuation. As example count and microbatch_size are tracked in logical example, we split logical attributes by microbatch_size and real attributes by microbatch_size * num_choices. + Args: batch (Dict): Batch of data microbatch_size (int | float): Size of microbatches @@ -1419,7 +1425,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + context_options (str): A list of contexts for this specific example. example (Dict): The example as a dictionary. Returns: @@ -1548,6 +1554,10 @@ def partition_dataset_by_category( Args: dataset_uri (str): Location of dataset. destination_path (str): Base destination path, we will write a separate partition off this URI for each category. + hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. + Raises: MissingConditionalImportError: If datasets not installed raise exception. @@ -1643,8 +1653,7 @@ def get_icl_task_dataloader( # At this point, hf_model is randomly initialized composer_model = HuggingFaceModel(hf_model, hf_tokenizer) - Example: - + Example: .. testcode:: @@ -1685,8 +1694,8 @@ def get_icl_task_dataloader( hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. - kwargs (Dict[str, Any], default=None): Dictionary containing a mapping - from ICL dataset constructor's parameter names and their desired values. + destination_path: Where the dataloader will be saved. + kwargs (Dict[str, Any], default=None): Dictionary containing a mapping from ICL dataset constructor's parameter names and their desired values. Returns: DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py index 1ce249437d..c19ae15dd9 100644 --- a/llmfoundry/eval/datasets/utils.py +++ b/llmfoundry/eval/datasets/utils.py @@ -130,7 +130,7 @@ def make_padded_input( Args: context_enc (List): The encoded input to the model continuation_enc (List): The encoded desired output for the example - max_seq_list (int): Maximum length sequences can be + max_seq_len (int): Maximum length sequences can be pad_tok_id (int): The token id we pad with padding_side (str): Which side to pad the context on. Can be 'right' or 'left diff --git a/llmfoundry/eval/metrics/nlp.py b/llmfoundry/eval/metrics/nlp.py index 3ee30ebf5e..f0fbba3ece 100644 --- a/llmfoundry/eval/metrics/nlp.py +++ b/llmfoundry/eval/metrics/nlp.py @@ -80,7 +80,7 @@ def update( Args: batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed to compute the metric. - output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids` + outputs (torch.Tensor): The model outputs evaluated on the batch `input_ids`. labels (torch.Tensor): The correct outputs. Raises: diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 536cd0257d..f1f38e2f7d 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -11,7 +11,6 @@ Any, Dict, List, - Mapping, Optional, Tuple, Union, @@ -23,7 +22,6 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, - PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, ) @@ -36,7 +34,7 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import get_hf_config_value +from llmfoundry.utils.config_utils import set_config_overrides if TYPE_CHECKING: from peft import PeftConfig, PeftModel @@ -105,9 +103,13 @@ def __init__( config_overrides=config_overrides, load_in_8bit=load_in_8bit, pretrained=pretrained, - prepare_for_fsdp=True, + prepare_for_fsdp=False, ) + model = self.transform_model(model) + + ComposerHFCausalLM.prepare_inner_model(model, init_device) + train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( use_train_metrics=use_train_metrics, additional_train_metrics=additional_train_metrics, @@ -121,7 +123,7 @@ def __init__( peft_config_object = None if peft_config is not None: - peft_config_object = self._get_peft_config(peft_config) + peft_config_object = self.get_peft_config(peft_config) # Set up config args for the model construction and base classes super().__init__( @@ -135,6 +137,17 @@ def __init__( should_save_peft_only=should_save_peft_only, ) + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + """Transforms the model after initialization. + + Args: + model (PreTrainedModel): The model to transform. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + @staticmethod def build_metrics( use_train_metrics: bool, @@ -192,6 +205,7 @@ def build_inner_model( use_auth_token (bool): Whether to use an authentication token. config_overrides (Dict[str, Any]): The configuration overrides. load_in_8bit (bool): Whether to load in 8-bit. + pretrained (bool): Whether the model is pretrained. prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False. Returns: @@ -216,6 +230,22 @@ def build_inner_model( + 'Please `pip install llm-foundry[gpu]`.', ) + # Hugging Face copies the modules into the + # transformers modules cache. On particular systems, this operation seems to cause contention between + # the different processes. To avoid this contention, we first create the config on local rank + # zero. This will set up the transformers module cache and avoid the future contention. + if dist.get_local_rank() == 0: + AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation=requested_attention_implementation, + use_cache= + False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 + ) + + dist.barrier() + # Construct the Hugging Face config to use config = AutoConfig.from_pretrained( pretrained_model_name_or_path, @@ -226,67 +256,7 @@ def build_inner_model( False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 ) - # This is not ideal, however Hugging Face's _autoset_attn_implementation function - # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading - # the model and then casting it back to fp32, we are monkeypatching their check. - # https://github.com/huggingface/transformers/issues/28052 - def _autoset_attn_implementation_monkeypatch( - cls, # type: ignore - config, # type: ignore - *args, # type: ignore - **kwargs, # type: ignore - ): # type: ignore - config._attn_implementation = requested_attention_implementation - return config - - PreTrainedModel._autoset_attn_implementation = classmethod( - _autoset_attn_implementation_monkeypatch, - ) - - # set config overrides - for k, v in config_overrides.items(): - if not hasattr(config, k): - raise ValueError( - f'config does not have attribute "{k}" to override ({k}: {v}).', - ) - - attr = getattr(config, k) - # attempt to disallow typos in nested configs - if isinstance(attr, Mapping): - extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] - if extra_keys: - raise ValueError( - f'Config dict override got unknown keys. ' + - f'Extra keys: {extra_keys}. ' + - f'Expected (a subset of) keys: {list(attr.keys())}.', - ) - getattr(config, k).update(v) - # necessary case to allow for rope_scaling to be overriden in llama config - elif attr is None and isinstance(v, Mapping): - setattr(config, k, {}) - getattr(config, k).update(v) - elif isinstance(attr, PretrainedConfig): - if not isinstance(v, Mapping): - raise ValueError( - f'Expected a dictionary for config override {k}, but got {v}.', - ) - - for _k, _v in v.items(): - if not hasattr(attr, _k): - raise ValueError( - f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', - ) - setattr(attr, _k, _v) - else: - setattr(config, k, v) - - if hasattr(config, 'attn_config') and get_hf_config_value( - config.attn_config, - 'seq_parallel_world_size', - ) is not None: - raise NotImplementedError( - 'Sequence Parallelism is not supported for HuggingFace models.', - ) + set_config_overrides(config, config_overrides) # We need to have all non-zero local ranks be not-pretrained # Rank 0 will still be pretrained, and distribute the weights appropriately @@ -298,7 +268,7 @@ def _autoset_attn_implementation_monkeypatch( # the different processes. To avoid this contention, we first create the model (on meta device) on local rank # zero. This will set up the transformers model cache and avoid the future contention. if dist.get_local_rank() == 0: - if os.path.isdir(pretrained_model_name_or_path): + if pretrained and os.path.isdir(pretrained_model_name_or_path): with init_empty_weights(include_buffers=False): with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) @@ -306,6 +276,8 @@ def _autoset_attn_implementation_monkeypatch( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, use_auth_token=use_auth_token, + attn_implementation= + requested_attention_implementation, config=config, ) else: @@ -313,6 +285,7 @@ def _autoset_attn_implementation_monkeypatch( AutoModelForCausalLM.from_config( config, trust_remote_code=trust_remote_code, + attn_implementation=requested_attention_implementation, ) dist.barrier() @@ -325,12 +298,14 @@ def _autoset_attn_implementation_monkeypatch( trust_remote_code=trust_remote_code, use_auth_token=use_auth_token, load_in_8bit=load_in_8bit, + attn_implementation=requested_attention_implementation, config=config, ) else: model = AutoModelForCausalLM.from_config( config, trust_remote_code=trust_remote_code, + attn_implementation=requested_attention_implementation, ) elif resolved_init_device == 'meta': if pretrained: @@ -341,6 +316,7 @@ def _autoset_attn_implementation_monkeypatch( model = AutoModelForCausalLM.from_config( config, trust_remote_code=trust_remote_code, + attn_implementation=requested_attention_implementation, ) else: raise ValueError( @@ -379,10 +355,10 @@ def _autoset_attn_implementation_monkeypatch( if prepare_for_fsdp: ComposerHFCausalLM.prepare_inner_model(model, init_device) + return model - @staticmethod - def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig': + def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig': if peft_installed: from peft import LoraConfig peft_type = peft_config_dict.get('peft_type', '') diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index dde7d64cd7..3e365edc47 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -411,9 +411,11 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -426,6 +428,7 @@ def __init__( self.clip_qkv = clip_qkv self.qk_ln = qk_ln self.qk_gn = qk_gn + self.fused_qkv = fused_qkv self.d_model = d_model self.n_heads = n_heads @@ -462,7 +465,17 @@ def __init__( self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop - if self.reuse_kv_layer_idx is None: + if self.reuse_kv_layer_idx is not None: + self.Wq = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + self.Wq._fused = (0, fuse_splits) + elif self.fused_qkv: self.Wqkv = build_fc( name=fc_type_name, in_features=self.d_model, @@ -482,15 +495,33 @@ def __init__( out_features=self.d_model, fc_kwargs=fc_type, ) + self.Wk = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + self.Wv = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) # for param init fn; enables shape based init of fused layers - fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] - self.Wq._fused = (0, fuse_splits) + q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + kv_fuse_splits = [ + i * self.head_dim for i in range(1, self.kv_n_heads) + ] + self.Wq._fused = (0, q_fuse_splits) + self.Wk._fused = (0, kv_fuse_splits) + self.Wv._fused = (0, kv_fuse_splits) if self.qk_ln or self.qk_gn: norm_size = self.head_dim if qk_gn else d_model self.q_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) if self.reuse_kv_layer_idx is None: @@ -499,6 +530,7 @@ def __init__( self.k_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) @@ -574,6 +606,7 @@ def get_qkv( Args: x (torch.Tensor): The input tensor. + prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer. Returns: query (torch.Tensor): The query tensor. @@ -601,19 +634,29 @@ def get_qkv( query = self.q_ln(query).to(dtype).view(q_shape) return query, key, value - qkv = self.Wqkv(x) + if self.fused_qkv: + qkv = self.Wqkv(x) - if self.clip_qkv: - qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + if self.clip_qkv: + qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query, key, value = qkv.split( + [ + self.d_model, + self.kv_n_heads * self.head_dim, + self.kv_n_heads * self.head_dim, + ], + dim=2, + ) + else: + query = self.Wq(x) + key = self.Wk(x) + value = self.Wv(x) - query, key, value = qkv.split( - [ - self.d_model, - self.kv_n_heads * self.head_dim, - self.kv_n_heads * self.head_dim, - ], - dim=2, - ) + if self.clip_qkv: + query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv) + key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv) + value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv) if self.qk_ln or self.qk_gn: # Applying layernorm to qk @@ -753,9 +796,11 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -770,9 +815,11 @@ def __init__( clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, + fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, @@ -796,9 +843,11 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -813,9 +862,11 @@ def __init__( clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, + fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c6988b7bd7..92735cc489 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,6 +42,7 @@ def __init__( ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, no_bias: bool = False, @@ -84,6 +85,7 @@ def __init__( fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, + norm_eps=norm_eps, device=device, no_bias=no_bias, ) @@ -99,6 +101,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -117,6 +120,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) @@ -260,6 +264,7 @@ def __init__( fc_type: Optional[dict[str, Any]] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, device: Optional[str] = None, no_bias: bool = False, **kwargs: Any, @@ -283,6 +288,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -302,6 +308,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index a28725ee0f..f5d6d67040 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -53,6 +53,19 @@ } +def quickgelu_activation(input: torch.Tensor) -> torch.Tensor: + """Applies GELU approximation that is fast but somewhat inaccurate. + + Args: + input (torch.Tensor): Input tensor of shape(*), where * means any + number of dimensions + + Returns: + torch.Tensor: Tensor with same shape as input tensor + """ + return input * torch.sigmoid(1.702 * input) + + def resolve_ffn_act_fn( config: Optional[dict] = None, ) -> Callable[[torch.Tensor], torch.Tensor]: @@ -70,10 +83,13 @@ def resolve_ffn_act_fn( config = _FFN_ACT_FN_DEFAULT config = deepcopy(config) name = config.pop('name') - if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognized activation function name ({name}).') - act = getattr(torch.nn.functional, name) - return partial(act, **config) + if name == 'quick_gelu': + return quickgelu_activation + else: + if not hasattr(torch.nn.functional, name): + raise ValueError(f'Unrecognized activation function name ({name}).') + act = getattr(torch.nn.functional, name) + return partial(act, **config) _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT) @@ -413,6 +429,7 @@ def set_ffn_device_mesh( ffn (nn.Module): The FFN module. moe_world_size (int): The MoE world size. device_mesh (DeviceMesh): The full device mesh. + get_fsdp_submesh (Callable[[DeviceMesh], DeviceMesh]): A function to get the fsdp submesh. Raises: RuntimeError: If the device mesh is 3D. diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 69d2059bad..d5fd1d37d4 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -26,10 +26,12 @@ def build_norm( name: str, normalized_shape: Union[int, List[int], torch.Size], + eps: Optional[float] = 1e-5, device: Optional[str] = None, ): kwargs = { 'normalized_shape': normalized_shape, + 'eps': eps, 'device': device, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index a1fdc25f50..9671eb6ed5 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -44,6 +44,7 @@ def __init__( no_bias: bool = False, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, use_cache: bool = False, init_config: Optional[Dict] = None, fc_type: Union[str, Dict] = 'torch', @@ -70,6 +71,8 @@ def __init__( attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'. qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer. + fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single + Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True. clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to this value. softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, @@ -99,6 +102,7 @@ def __init__( no_bias (bool): Whether to use bias in all layers. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. norm_type (str): choose type of norm to use + norm_eps (float): epsilon value for norm layer use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', @@ -143,6 +147,7 @@ def __init__( reuse_kv_layer: attn_config: reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse + kwargs (Any): Other relevant keyword arguments. """ self.d_model = d_model self.n_heads = n_heads @@ -166,6 +171,7 @@ def __init__( self.no_bias = no_bias self.embedding_fraction = embedding_fraction self.norm_type = norm_type + self.norm_eps = norm_eps self.use_cache = use_cache self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, @@ -304,6 +310,7 @@ def _validate_config(self) -> None: 'no_scaling', 'linear', 'dynamic', + 'llama3', ]: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".', diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3b2744f867..6f9b6bf806 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -49,12 +49,10 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import ( + LlamaConfig, + LlamaRotaryEmbedding, +) from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import ( @@ -88,14 +86,62 @@ log = logging.getLogger(__name__) +class InvalidConfigAccessError(KeyError): + pass + + +_ALLOWED_LLAMA_CONFIG_KEYS = { + # These are the only config keys that are set and are safe to read from + 'rope_scaling', + 'rope_theta', + 'max_position_embeddings', + 'hidden_size', + 'num_attention_heads', + + # Not set but llama modeling code tries to read this attribute + 'partial_rotary_factor', + + # Benign transformers attributes needed for __init__ + '_get_generation_defaults', + 'label2id', + 'id2label', + 'torch_dtype', + 'problem_type', + '__class__', +} + + +class PartialLlamaConfig(LlamaConfig): + """Holds the rope config for Llama models and throws. + + an `InvalidConfigAccessError` if any other config elements are read. This + class is necessary because the `LlamaRotaryEmbedding` class takes a full + `LlamaConfig` now instead of the old keyword arguments. + """ + + def __getattribute__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getattribute__(key) + + def __getitem__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getitem__(key) + + def gen_rotary_embedding( - rope_head_dim: int, rope_impl: str, rope_theta: int, rope_dail_config: dict, rope_hf_config: dict, max_seq_len: int, + d_model: int, + n_heads: int, ): + rope_head_dim = d_model // n_heads if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -108,32 +154,21 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_impl == 'hf': + llama_rope_config = {**rope_hf_config} + llama_rope_config['rope_type'] = llama_rope_config.pop('type') + if llama_rope_config['rope_type'] == 'no_scaling': + llama_rope_config['rope_type'] = 'default' + partial_llama_config = PartialLlamaConfig( + rope_scaling=llama_rope_config, + rope_theta=rope_theta, + max_position_embeddings=max_seq_len, + hidden_size=d_model, + num_attention_heads=n_heads, + ) if rope_hf_config['type'] == 'no_scaling': - return HFRotaryEmbeddingFoundry( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) + return LlamaRotaryEmbeddingFoundry(config=partial_llama_config) + elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: + return LlamaRotaryEmbedding(config=partial_llama_config) raise ValueError('rope_impl needs to be either dail or hf') @@ -306,7 +341,7 @@ def apply_sequence_id( return attn_bias -class HFRotaryEmbeddingFoundry(HFRotaryEmbedding): +class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding): @torch.no_grad() def forward( @@ -391,6 +426,7 @@ def __init__(self, config: MPTConfig): self.norm_f = build_norm( name=config.norm_type.lower(), normalized_shape=config.d_model, + eps=config.norm_eps, device=config.init_device, ) @@ -399,12 +435,13 @@ def __init__(self, config: MPTConfig): if self.rope: self.rope_impl = config.attn_config['rope_impl'] self.rotary_embedding = gen_rotary_embedding( - rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len, + d_model=config.d_model, + n_heads=config.n_heads, ) if config.init_device != 'meta': @@ -1285,6 +1322,40 @@ def _reorder_cache( return reordered_past +def get_targets(labels: torch.Tensor) -> torch.Tensor: + targets = torch.roll(labels, shifts=-1) + targets[:, -1] = -100 + return targets + + +def compute_loss_from_logits( + outputs: CausalLMOutputWithPast, + shift_labels: bool, + labels: torch.Tensor, + loss_fn: nn.Module, + sample_weighing_factor: Optional[torch.Tensor] = None, +) -> torch.Tensor: + targets = get_targets(labels) if shift_labels else labels + + losses = loss_fn( + outputs.logits.view(-1, outputs.logits.size(-1)), + targets.view(-1), + ) + + if torch.all(targets == loss_fn.ignore_index): + loss = losses.sum() + else: + loss = losses.sum() / (targets != loss_fn.ignore_index).sum() + if sample_weighing_factor is not None: + if sample_weighing_factor.shape[0] > 1: + raise ValueError( + 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', + ) + loss = loss * sample_weighing_factor[0].item() + + return loss + + class ComposerMPTCausalLM(HuggingFaceModel): def __init__( @@ -1362,9 +1433,7 @@ def config_class(self) -> Type[MPTConfig]: return MPTConfig def get_targets(self, batch: Mapping) -> torch.Tensor: - targets = torch.roll(batch['labels'], shifts=-1) - targets[:, -1] = -100 - return targets + return get_targets(batch['labels']) def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: @@ -1385,27 +1454,14 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - if self.shift_labels: - targets = self.get_targets(batch) - else: - targets = batch['labels'] - - losses = self.loss_fn( - outputs.logits.view(-1, outputs.logits.size(-1)), - targets.view(-1), + loss = compute_loss_from_logits( + outputs, + self.shift_labels, + batch['labels'], + self.loss_fn, + batch.get('sample_weighing_factor', None), ) - if torch.all(targets == self.loss_fn.ignore_index): - loss = losses.sum() - else: - loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if 'sample_weighing_factor' in batch: - if batch['sample_weighing_factor'].shape[0] > 1: - raise ValueError( - 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', - ) - loss = loss * batch['sample_weighing_factor'][0].item() - if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors @@ -1420,7 +1476,6 @@ def loss(self, outputs: CausalLMOutputWithPast, 'loss': loss, 'lbl': lbl, } - return loss @cached_property diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 2b6fc2f7c7..c272a52dd4 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -15,6 +15,7 @@ 'attn_impl': 'flash', 'qk_ln': False, 'qk_gn': False, + 'fused_qkv': True, 'clip_qkv': None, 'softmax_scale': None, 'attn_uses_sequence_id': False, diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 50481211ac..3f0163ff01 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -6,6 +6,7 @@ from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim import ComposerScheduler +from torch.distributed.checkpoint import LoadPlanner, SavePlanner from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader from torch.utils.data import Dataset @@ -154,6 +155,19 @@ description=_schedulers_description, ) +_tokenizers_description = ( + 'The tokenizers registry is used to register tokenizers that implement the transformers.PreTrainedTokenizerBase interface. ' + + + 'The tokenizer will be passed to the build_dataloader() and build_composer_model() methods in train.py.' +) +tokenizers = create_registry( + 'llmfoundry', + 'tokenizers', + generic_type=Type[PreTrainedTokenizerBase], + entry_points=True, + description=_tokenizers_description, +) + _models_description = ( """The models registry is used to register classes that implement the ComposerModel interface. @@ -339,6 +353,42 @@ description=_config_transforms_description, ) +_load_planners_description = ( + """The load_planners registry is used to register classes that implement the LoadPlanner interface. + + The LoadPlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to load distributed checkpoints. + + Returns: + LoadPlanner: The load planner. + """ +) + +load_planners = create_registry( + 'llmfoundry', + 'load_planners', + generic_type=Type[LoadPlanner], + entry_points=True, + description=_load_planners_description, +) + +_save_planners_description = ( + """The save_planners registry is used to register classes that implement the SavePlanner interface. + + The savePlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to save distributed checkpoints. + + Returns: + SavePlanner: The save planner. + """ +) + +save_planners = create_registry( + 'llmfoundry', + 'save_planners', + generic_type=Type[SavePlanner], + entry_points=True, + description=_save_planners_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -346,6 +396,7 @@ 'optimizers', 'algorithms', 'schedulers', + 'tokenizers', 'models', 'dataset_replication_validators', 'collators', @@ -363,4 +414,6 @@ 'fcs', 'icl_datasets', 'config_transforms', + 'load_planners', + 'save_planners', ] diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py index 1703ed8862..d37c12a555 100644 --- a/llmfoundry/tokenizers/__init__.py +++ b/llmfoundry/tokenizers/__init__.py @@ -1,8 +1,11 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.registry import tokenizers from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper) + __all__ = [ 'TiktokenTokenizerWrapper', ] diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index f087664344..fd0fc5948a 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -90,6 +90,7 @@ def __init__( errors (str, optional): Paradigm to follow when decoding bytes to UTF-8. See [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. Defaults to `"replace"`. + kwargs (Any): Other relevant keyword arguments. """ try: import tiktoken diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 012a0b704f..a1d84601b3 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -27,6 +27,7 @@ from composer.utils import dist from omegaconf import DictConfig from omegaconf import OmegaConf as om +from torch.distributed.checkpoint import LoadPlanner, SavePlanner from torch.optim.optimizer import Optimizer from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -34,9 +35,9 @@ from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.eval.datasets.in_context_learning_evaluation import \ - get_icl_task_dataloader -from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +from llmfoundry.eval.datasets.in_context_learning_evaluation import ( + get_icl_task_dataloader, +) from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry @@ -126,8 +127,7 @@ def build_eval_loaders( # Load the eval data to fail fast. metrics will get added # later in add_metrics_to_eval_loaders, after the model is loaded metric_names=[], - # TODO: Fix type in Composer - device_eval_microbatch_size=device_eval_batch_size, # type: ignore + device_eval_microbatch_size=device_eval_batch_size, ) evaluators.append(eval_loader) return evaluators @@ -187,6 +187,46 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb +def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner: + """Builds a load planner from the registry. + + Args: + name (str): Name of the load planner to build. + kwargs (Any): Other relevant keyword arguments. + + Returns: + LoadPlanner: The load planner. + """ + return construct_from_registry( + name=name, + registry=registry.load_planners, + partial_function=True, + pre_validation_function=LoadPlanner, + post_validation_function=None, + kwargs=kwargs, + ) + + +def build_save_planner(name: str, **kwargs: Any) -> SavePlanner: + """Builds a save planner from the registry. + + Args: + name (str): Name of the save planner to build. + kwargs (Any): Other relevant keyword arguments. + + Returns: + savePlanner: The save planner. + """ + return construct_from_registry( + name=name, + registry=registry.save_planners, + partial_function=True, + pre_validation_function=SavePlanner, + post_validation_function=None, + kwargs=kwargs, + ) + + def build_composer_model( name: str, cfg: Dict[str, Any], @@ -467,8 +507,15 @@ def build_tokenizer( with dist.local_rank_zero_download_and_wait(signal_file_path): pass - if tokenizer_name.startswith('tiktoken'): - tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) + if tokenizer_name in registry.tokenizers: + tokenizer = construct_from_registry( + name=tokenizer_name, + registry=registry.tokenizers, + partial_function=True, + pre_validation_function=PreTrainedTokenizerBase, + post_validation_function=None, + kwargs=tokenizer_kwargs, + ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 905afd6edb..5c65a7475e 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -177,6 +177,7 @@ def _convert_weight_to_ft_each( tensor_name (str): Name of the weight tensor. Used in naming the output file. config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters. data (np.ndarray): Tensor data in np.ndarray format. + np_weight_data_type (np.dtype): Data type of the numpy array `data`. Returns: None: Writes to a file in `save_dir`. File name is based on the `tensor_name` diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 4b86de99b8..eb54fabc3d 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -167,6 +167,7 @@ class TrainConfig: # Dataloader device_train_microbatch_size: Union[str, int, float] = 'auto' global_train_batch_size: Optional[int] = None + spin_dataloaders: bool = True # Eval dataloader eval_subset_num_batches: int = -1 @@ -331,16 +332,14 @@ def make_dataclass_and_log_config( 'icl_tasks must be specified in the config', ) - # Create copy of config for logging - logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config) - # Apply transforms to the unstructured config before constructing dataclass unstructured_config = apply_transforms_to_config( unstructured_config, transforms, ) - logged_cfg.update(unstructured_config, merge=True) + # Create copy of config for logging + logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config) arg_config_keys = set(unstructured_config.keys()) extraneous_keys = set.difference(arg_config_keys, dataclass_fields) @@ -467,6 +466,9 @@ def update_config_with_batch_size_info( Args: cfg (Dict[str, Any]): The config to update. + device_train_batch_size (Union[int, float]): The batch size of the training dataset for each device. + device_train_microbatch_size (Union[int, float, Literal['auto']]): The microbatch size of the training dataset for each device. + device_train_grad_accum (Union[int, Literal['auto']]): The gradient accumulation settings for each device. Returns: Dict[str, Any]: The updated config. @@ -531,7 +533,6 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): fsdp_config['sync_module_states'] = True # Set defaults for mixed initialization - fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) # Set ffn_config.device_mesh to fsdp_config.device_mesh @@ -812,3 +813,45 @@ def _verify_uc_path(path: str) -> bool: f'but your `UCVolumeDatasetSource` might be invalid.', ) return False + + +def set_config_overrides( + config: PretrainedConfig, + config_overrides: Dict[str, Any], +): + # set config overrides + for k, v in config_overrides.items(): + if not hasattr(config, k): + raise ValueError( + f'config does not have attribute "{k}" to override ({k}: {v}).', + ) + + attr = getattr(config, k) + # attempt to disallow typos in nested configs + if isinstance(attr, Mapping): + extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] + if extra_keys: + raise ValueError( + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + + f'Expected (a subset of) keys: {list(attr.keys())}.', + ) + getattr(config, k).update(v) + # necessary case to allow for rope_scaling to be overriden in llama config + elif attr is None and isinstance(v, Mapping): + setattr(config, k, {}) + getattr(config, k).update(v) + elif isinstance(attr, PretrainedConfig): + if not isinstance(v, Mapping): + raise ValueError( + f'Expected a dictionary for config override {k}, but got {v}.', + ) + + for _k, _v in v.items(): + if not hasattr(attr, _k): + raise ValueError( + f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', + ) + setattr(attr, _k, _v) + else: + setattr(config, k, v) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 251bc86631..c6a667697d 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -28,6 +28,7 @@ 'InputFolderMissingDataError', 'OutputFolderNotEmptyError', 'MisconfiguredHfDatasetError', + 'DatasetTooSmallError', 'RunTimeoutError', ] @@ -364,6 +365,14 @@ def __init__(self, dataset_name: str, split: str) -> None: super().__init__(message, dataset_name=dataset_name, split=split) +class DatasetTooSmallError(UserError): + """Error thrown when the dataset is too small to be processed.""" + + def __init__(self) -> None: + message = f'Your dataset is too small and produced no complete samples during preprocessing. Please provide more data.' + super().__init__(message) + + class RunTimeoutError(InternalError): """Error thrown when a run times out.""" diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index dde8240d8b..9609982fda 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -69,7 +69,7 @@ def download_from_hf_hub( Safetensors weights will be downloaded unless `prefer_safetensors` is set to False. Args: - repo_id (str): The Hugging Face Hub repo ID. + model (str): The Hugging Face Hub repo ID. save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. @@ -157,7 +157,7 @@ def _recursive_download( Args: session: A requests.Session through which to make requests to the remote server. - url (str): The base URL where the files are located. + base_url (str): The base URL where the files are located. path (str): The path from the base URL to the files to download. The full URL for the download is equal to '/'. save_dir (str): The directory to save downloaded files to. diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 3ea7cc58a7..f96e72b3a2 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -127,6 +127,7 @@ def construct_from_registry( before constructing the item to return. This should throw an exception if validation fails. Defaults to None. post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after constructing the item to return. This should throw an exception if validation fails. Defaults to None. + kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. Raises: ValueError: If the validation functions failed or the registered item is invalid @@ -176,6 +177,7 @@ def import_file(loc: Union[str, Path]) -> ModuleType: """Import module from a file. Used to run arbitrary python code. + Args: name (str): Name of module to load. loc (str / Path): Path to the file. diff --git a/mcli/mcli-1b-eval.yaml b/mcli/mcli-1b-eval.yaml index d8ef42d5d5..4bfa301f8e 100644 --- a/mcli/mcli-1b-eval.yaml +++ b/mcli/mcli-1b-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index 4b8eb601b2..2dc83d36a9 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-1b.yaml b/mcli/mcli-1b.yaml index 20d41fb2cc..69b2295011 100644 --- a/mcli/mcli-1b.yaml +++ b/mcli/mcli-1b.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-benchmark-mpt.yaml b/mcli/mcli-benchmark-mpt.yaml index a7d44239b5..7a3ea2cbe9 100644 --- a/mcli/mcli-benchmark-mpt.yaml +++ b/mcli/mcli-benchmark-mpt.yaml @@ -11,7 +11,7 @@ image: mosaicml/llm-foundry:2.3.1_cu121-latest integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] diff --git a/mcli/mcli-convert-composer-to-hf.yaml b/mcli/mcli-convert-composer-to-hf.yaml index 60c708e2f6..fefaf8e1a3 100644 --- a/mcli/mcli-convert-composer-to-hf.yaml +++ b/mcli/mcli-convert-composer-to-hf.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: . ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml index e69e6dadda..e58d42483a 100644 --- a/mcli/mcli-hf-eval.yaml +++ b/mcli/mcli-hf-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-hf-generate.yaml b/mcli/mcli-hf-generate.yaml index 8b382c41f0..02c49d84c3 100644 --- a/mcli/mcli-hf-generate.yaml +++ b/mcli/mcli-hf-generate.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index 443429ca0a..47c163faf8 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml index 38b02a6019..c372014165 100644 --- a/mcli/mcli-openai-eval.yaml +++ b/mcli/mcli-openai-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: .[gpu,openai] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml index 0749dcc86e..a4496503cd 100644 --- a/mcli/mcli-pretokenize-oci-upload.yaml +++ b/mcli/mcli-pretokenize-oci-upload.yaml @@ -14,7 +14,7 @@ integrations: - oci-cli==3.23.2 - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.8.0 + git_branch: v0.10.0 # git_commit: # OR use your commit hash pip_install: . ssh_clone: false # Should be true if using a private repo diff --git a/pyproject.toml b/pyproject.toml index 53007cafaf..fdbabfff96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,23 +11,44 @@ skip = [ "env", "wandb", "runs", "build", "node_modules" ] include_trailing_comma = true split_on_trailing_comma = true +# Ruff global +[tool.ruff] +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] + +# Ruff linter [tool.ruff.lint] select = [ "C4", - # TODO port pydocstyle - # "D", # pydocstyle "LOG", "PERF", "PLE", "COM812", + "D", # pydocstyle ] -[tool.ruff] -exclude = [ - "build/**", - "docs/**", - "node_modules/**", + +extend-select = ["D404"] # pydocstyle + +ignore = [ + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", + "D400", + "D401", + "D415", ] +[tool.ruff.lint.pydocstyle] +convention = "google" + + # Coverage [tool.coverage.run] parallel = true @@ -79,7 +100,7 @@ reportMissingImports = "none" # Pytest [tool.pytest.ini_options] # By default, skip gpu tests -addopts = "--tb=short -m 'not gpu'" +addopts = "--tb=short -m 'not gpu' --color=yes" markers = [ # For distributed testing @@ -506,8 +527,3 @@ ignore_patterns = [ "wandb/**/*.py", "build/**/*.py", ] - -[tool.pydocstyle] -convention="google" -add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" -add_select="D404" diff --git a/scripts/data_prep/convert_dataset_json.py b/scripts/data_prep/convert_dataset_json.py index 37b0465692..5a6927ac75 100644 --- a/scripts/data_prep/convert_dataset_json.py +++ b/scripts/data_prep/convert_dataset_json.py @@ -2,24 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """Streaming dataset conversion scripts for json files.""" -import os from argparse import ArgumentParser, Namespace -from enum import Enum -from glob import glob -from typing import Optional -import datasets as hf_datasets -from streaming import MDSWriter -from torch.utils.data import IterableDataset -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from llmfoundry.data import ConcatTokensDataset, NoConcatDataset - - -class ConcatMode(Enum): - NO_CONCAT = 'NO_CONCAT' - CONCAT_TOKENS = 'CONCAT_TOKENS' +from llmfoundry.command_utils import convert_dataset_json_from_args def parse_args() -> Namespace: @@ -46,145 +31,19 @@ def parse_args() -> Namespace: parser.add_argument('--no_wrap', default=False, action='store_true') parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.split)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - # Make sure we have needed concat options - if ( - parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None - ): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer', - ) - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -def build_hf_dataset( - path: str, - split: str, - mode: ConcatMode, - max_length: Optional[int] = None, - bos_text: str = '', - eos_text: str = '', - no_wrap: bool = False, - tokenizer: PreTrainedTokenizerBase = None, -) -> IterableDataset: - """Build an IterableDataset over the HF C4 or pile source data. - - Args: - dataset_name (str): Dataset name - split (str): Split name. - mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS - max_length (int): The length of concatenated tokens - bos_text (str): text to insert at the beginning of each sequence - eos_text (str): text to insert at the end of each sequence - no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries - tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use - data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. - Typically "all" (The Pile) or "en" (c4). - - Returns: - An IterableDataset. - """ - if os.path.isdir(path): - data_files = glob(f'{path}/*') - else: - data_files = path - - hf_dataset = hf_datasets.load_dataset( - 'json', - data_files=data_files, - split=split, - ) - - if mode == ConcatMode.NO_CONCAT: - dataset = NoConcatDataset(hf_dataset) - else: - if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase', - ) - if max_length is None: - raise ValueError(f'max_length must be set.') - if bos_text + eos_text == '': - test_tokens = tokenizer('test') - if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: - tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' - tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' - tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' - tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' - tok_error_msg += '--bos_text=<|endoftext|>.' - raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset( - hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap, - ) - return dataset - - -def main(args: Namespace) -> None: - """Main: create C4/pile streaming dataset. - - Args: - args (Namespace): Commandline arguments. - """ - if args.concat_tokens is not None: - mode = ConcatMode.CONCAT_TOKENS - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) - # we will enforce length, so suppress warnings about sequences too long for the model - tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'ndarray:int32'} - else: - mode = ConcatMode.NO_CONCAT - tokenizer = None - columns = {'text': 'str'} - - # Get samples - dataset = build_hf_dataset( +if __name__ == '__main__': + args = parse_args() + convert_dataset_json_from_args( path=args.path, + out_root=args.out_root, + compression=args.compression, + concat_tokens=args.concat_tokens, split=args.split, - mode=mode, - max_length=args.concat_tokens, + tokenizer=args.tokenizer, bos_text=args.bos_text, eos_text=args.eos_text, no_wrap=args.no_wrap, - tokenizer=tokenizer, - ) - - print('here') - - # Write samples - print(f'Converting to MDS format...') - print( - f'Note that the progress bar is based on the dataset length before tokenization.', ) - print(f'It will finish at a value below 100% if tokenizing') - with MDSWriter( - columns=columns, - out=os.path.join(args.out_root), - compression=args.compression, - ) as out: - for sample in tqdm(dataset): - out.write(sample) - - -if __name__ == '__main__': - main(parse_args()) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 3b88ba668f..277a8c1ffc 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -4,41 +4,12 @@ import logging import os import re -import time -import urllib.parse -from argparse import ArgumentParser, Namespace -from collections import namedtuple -from concurrent.futures import ProcessPoolExecutor -from typing import Iterable, List, Optional, Tuple, Union -from uuid import uuid4 +from argparse import ArgumentParser -import google.protobuf.any_pb2 as any_pb2 -import lz4.frame -import pandas as pd -import pyarrow as pa -import pyspark.sql.connect.proto as pb2 -import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 -import requests -from composer.utils import retry -from databricks import sql -from databricks.connect import DatabricksSession -from databricks.sdk import WorkspaceClient from databricks.sql.client import Connection as Connection from databricks.sql.client import Cursor as Cursor -from packaging import version -from pyspark.sql import SparkSession -from pyspark.sql.connect.client.core import SparkConnectClient -from pyspark.sql.connect.client.reattach import \ - ExecutePlanResponseReattachableIterator -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.dataframe import DataFrame as SparkDataFrame -from pyspark.sql.types import Row -from llmfoundry.utils.exceptions import ( - ClusterDoesNotExistError, - FailedToConnectToDatabricksError, - FailedToCreateSQLConnectionError, -) +from llmfoundry.command_utils import convert_delta_to_json_from_args MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' @@ -47,617 +18,6 @@ log = logging.getLogger(__name__) -Result = namedtuple( - 'Result', - [ - 'url', - 'row_count', - 'compressed_size', - 'uncompressed_size', - ], -) # pyright: ignore - -# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. -# It allows the client to fetch the results in different formats from the server. -# To be able to use the code make sure this module is not overriden by DB Connect classes. - - -def to_cf(self: SparkConnectClient, - plan: pb2.Plan, - type: str = 'json') -> Tuple[List[Result], int, bool]: - """Executes the query plans and return as presigned URLS for cloud fetch. - - It can handle the current output formats that are supported by the server. - In contrast to the regular API methods of the client, this method does not - return the schema and drops all other responses. - - Args: - plan (pb2.Plan): The plan object to be executed by spark. - type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. - - Returns: - Tuple[List[Result], int, bool]: A tuple containing: - - A list of Result namedtuples, each containing a URL, row count, compressed size, - and uncompressed size of the part of the result. - - Total row count of all parts of the result. - - A boolean indicating whether the result has been truncated. - """ - log.info(f'Executing query plan with format: {type}') - - req = self._execute_plan_request_with_metadata() - req.plan.CopyFrom(plan) - - # Add the request options - if type == 'json': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON - elif type == 'csv': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV - elif type == 'arrow': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW - else: - raise ValueError( - f'Only formats json, csv, and arrow are supported. Got invalid type {type}', - ) - - ro = cloud_pb2.ResultOptions( - type=cloud_pb2.ResultOptions.TYPE_CLOUD, - cloudOptions=cloud_pb2.ResultOptions.CloudOptions( - format=format, - useCompression=False, - ), - ) - cloud_option = any_pb2.Any() - cloud_option.Pack(ro) - req.request_options.append( - pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), - ) - - # Create the iterator - iterator = ExecutePlanResponseReattachableIterator( - req, - self._stub, - self._retry_policy, - self._builder.metadata(), - ) - # Iterate over the response - result = [] - row_count = 0 - is_overflow = False - - for response in iterator: - if response.HasField('extension') and response.extension.Is( - cloud_pb2.CloudResultBatch.DESCRIPTOR, - ): - batch = cloud_pb2.CloudResultBatch() - if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): - raise ValueError( - 'Response extension is not of type CloudResultBatch.', - ) - response.extension.Unpack(batch) - result += [ - Result( - b.url, - b.row_count, - b.compressed_size, - b.uncompressed_size, - ) for b in batch.results - ] - row_count += sum(result.row_count for result in batch.results) - is_overflow |= batch.truncated - return result, row_count, is_overflow - - -SparkConnectClient.to_cf = to_cf # pyright: ignore - - -def collect_as_cf(self: DataFrame, - type: str = 'json') -> Tuple[List[Result], int, bool]: - """Collects DataFrame execution plan as presigned URLs. - - This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the - execution plan of the current DataFrame, converts it to a protocol buffer format, and then - uses the `to_cf` method to execute the plan and fetch results as presigned URLs. - - Args: - type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. - - Returns: - Tuple[List[Result], int, bool]: A tuple containing: - - A list of Result namedtuples, each containing a URL, row count, compressed size, - and uncompressed size of the part of the result. - - Total row count of all parts of the result. - - A boolean indicating whether the result is truncated or overflowed. - """ - log.info(f'Collecting DataFrame as cloud fetch with format: {type}') - query = self._plan.to_proto(self._session.client) # pyright: ignore - return self._session.client.to_cf(query, type) # pyright: ignore - - -DataFrame.collect_cf = collect_as_cf # pyright: ignore - - -def iterative_combine_jsons(json_directory: str, output_file: str) -> None: - """Combine jsonl files in json_directory into one big jsonl file. - - This function does not work for nested subdirectories. - - Args: - json_directory(str): directory containing the JSONL files - output_file(str): path to the output combined JSONL file - """ - log.info( - f'Starting to combine JSON files from {json_directory} into {output_file}', - ) - json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] - log.info(f'Found {len(json_files)} JSON files to combine') - with open(output_file, 'w') as outfile: - for file_name in json_files: - log.debug(f'Processing file: {file_name}') - with open(os.path.join(json_directory, file_name), 'r') as infile: - for line in infile: - outfile.write(line) - log.info('JSON files have been successfully combined into a JSONL file.') - - -def run_query( - query: str, - method: str, - cursor: Optional[Cursor] = None, - spark: Optional[SparkSession] = None, - collect: bool = True, -) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: - """Run SQL query via databricks-connect or databricks-sql. - - Args: - query (str): sql query - method (str): select from dbsql and dbconnect - cursor (Optional[Cursor]): connection.cursor - spark (Optional[SparkSession]): spark session - collect (bool): whether to get the underlying data from spark dataframe - """ - log.info(f'Executing query using method: {method}') - log.debug(f'Query: {query}') - - if method == 'dbsql': - if cursor is None: - raise ValueError(f'cursor cannot be None if using method dbsql') - cursor.execute(query) - if collect: - return cursor.fetchall() - elif method == 'dbconnect': - if spark == None: - raise ValueError(f'sparkSession is required for dbconnect') - df = spark.sql(query) - if collect: - return df.collect() - return df - else: - raise ValueError(f'Unrecognized method: {method}') - - -def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: - for i, r in enumerate(signed): - yield (i, r.url, json_output_folder, columns) - - -def download( - ipart: int, - url: str, - json_output_folder: str, - columns: Optional[List] = None, - resp_format: str = 'arrow', - compressed: bool = False, -) -> None: - """Thread download presigned url and save to jsonl locally. - - Args: - ipart (int): presigned url id - url (str): presigned url - json_output_folder (str): directory to save the ipart_th segment of dataframe - columns (list): schema to save to json - resp_format (str): whether to use arrow or json when collect - compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. - """ - log.info(f'Downloading part {ipart} from URL: {url}') - - resp = requests.get(url) - if resp.status_code == 200: - if resp_format == 'json': - data = resp.json() - pd.DataFrame(data, columns=columns).to_json( - os.path.join( - json_output_folder, - 'part_' + str(ipart) + '.jsonl', - ), - orient='records', - lines=True, - ) - return - - # When resp_format is arrow: - if compressed: - # The data is lz4 compressed arrow format. - # Decompress the data - decompressed_data = lz4.frame.decompress(resp.content) - # Convert the decompressed data into a PyArrow table - reader = pa.ipc.open_stream(decompressed_data) - else: - reader = pa.ipc.open_stream(resp.content) - table = reader.read_all() - - # Convert the PyArrow table into a pandas DataFrame - df = table.to_pandas() - df.to_json( - os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), - orient='records', - lines=True, - force_ascii=False, - ) - - -def download_starargs(args: Tuple) -> None: - return download(*args) - - -def format_tablename(table_name: str) -> str: - """Escape catalog, schema and table names with backticks. - - This needs to be done when running SQL queries/setting spark sessions to prevent invalid identifier errors. - - Args: - table_name (str): catalog.scheme.tablename on UC - """ - log.debug(f'Formatting table name: {table_name}') - match = re.match(TABLENAME_PATTERN, table_name) - - if match is None: - return table_name - - formatted_identifiers = [] - for i in range(1, 4): - identifier = f'`{match.group(i)}`' - formatted_identifiers.append(identifier) - - return '.'.join(formatted_identifiers) - - -def fetch_data( - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], - start: int, - end: int, - order_by: str, - tablename: str, - columns_str: str, - json_output_folder: str, -) -> None: - """Fetches a specified range of rows from a given table to a json file. - - This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, - from a specified table and column set. The fetched data is then exported as a JSON file. - - Args: - method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. - cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. - sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. - start (int): The starting index for row fetching. - end (int): The ending index for row fetching. - order_by (str): The column name to use for ordering the rows. - tablename (str): The name of the table from which to fetch the data. - columns_str (str): The string representation of the columns to select from the table. - json_output_folder (str): The file path where the resulting JSON file will be saved. - - Returns: - None: The function doesn't return any value, but writes the result to a JSONL file. - """ - log.info(f'Fetching data from {start} to {end} using method: {method}') - query = f""" - WITH NumberedRows AS ( - SELECT - *, - ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn - FROM - {tablename} - ) - SELECT {columns_str} - FROM NumberedRows - WHERE rn BETWEEN {start+1} AND {end}""" - - if method == 'dbconnect': - spark_df = run_query(query, method, cursor, sparkSession, collect=False) - if spark_df is None: - raise RuntimeError( - f'Expect spark dataframe with {query} but got None', - ) - pdf = spark_df.toPandas() # pyright: ignore - else: # method == 'dbsql': - ans = run_query(query, method, cursor, sparkSession, collect=True) - if ans is None: - raise RuntimeError(f'Got empty results with {query}') - records = [r.asDict() for r in ans] # pyright: ignore - pdf = pd.DataFrame.from_dict(records) - - pdf.to_json( - os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), - orient='records', - lines=True, - ) - - -@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) -def get_total_rows( - tablename: str, - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], -): - ans = run_query( - f'SELECT COUNT(*) FROM {tablename}', - method, - cursor, - sparkSession, - ) - nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore - log.info(f'total_rows = {nrows}') - return nrows - - -@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) -def get_columns_info( - tablename: str, - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], -): - ans = run_query( - f'SHOW COLUMNS IN {tablename}', - method, - cursor, - sparkSession, - ) - columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore - order_by = columns[0] - columns_str = ','.join(columns) - log.info(f'order by column {order_by}') - return columns, order_by, columns_str - - -def fetch( - method: str, - tablename: str, - json_output_folder: str, - batch_size: int = 1 << 30, - processes: int = 1, - sparkSession: Optional[SparkSession] = None, - dbsql: Optional[Connection] = None, -) -> None: - """Fetch UC delta table with databricks-connect as JSONL. - - Args: - method (str): dbconnect or dbsql - tablename (str): catalog.scheme.tablename on UC - json_output_folder (str): path to write the result json file to - batch_size (int): number of rows that dbsql fetches each time to avoid OOM - processes (int): max number of processes to use to parallelize the fetch - sparkSession (pyspark.sql.sparksession): spark session - dbsql (databricks.sql.connect): dbsql session - """ - log.info(f'Starting data fetch for table: {tablename}') - log.info( - f'Method: {method}, Batch size: {batch_size}, Processes: {processes}', - ) - - cursor = dbsql.cursor() if dbsql is not None else None - try: - nrows = get_total_rows( - tablename, - method, - cursor, - sparkSession, - ) - except Exception as e: - raise RuntimeError( - f'Error in get rows from {tablename}. Restart sparkSession and try again', - ) from e - - try: - columns, order_by, columns_str = get_columns_info( - tablename, - method, - cursor, - sparkSession, - ) - except Exception as e: - raise RuntimeError( - f'Error in get columns from {tablename}. Restart sparkSession and try again', - ) from e - - if method == 'dbconnect' and sparkSession is not None: - log.info(f'{processes=}') - df = sparkSession.table(tablename) - - # Running the query and collecting the data as arrow or json. - signed, _, _ = df.collect_cf('arrow') # pyright: ignore - log.info(f'len(signed) = {len(signed)}') - - args = get_args(signed, json_output_folder, columns) - - # Stopping the SparkSession to avoid spilling connection state into the subprocesses. - sparkSession.stop() - - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_starargs, args)) - - elif method == 'dbsql' and cursor is not None: - for start in range(0, nrows, batch_size): - log.warning(f'batch {start}') - end = min(start + batch_size, nrows) - fetch_data( - method, - cursor, - sparkSession, - start, - end, - order_by, - tablename, - columns_str, - json_output_folder, - ) - - if cursor is not None: - cursor.close() - - -def validate_and_get_cluster_info( - cluster_id: str, - databricks_host: str, - databricks_token: str, - http_path: Optional[str], - use_serverless: bool = False, -) -> tuple: - """Validate and get cluster info for running the Delta to JSONL conversion. - - Args: - cluster_id (str): cluster id to validate and fetch additional info for - databricks_host (str): databricks host name - databricks_token (str): databricks auth token - http_path (Optional[str]): http path to use for sql connect - use_serverless (bool): whether to use serverless or not - """ - log.info('Validating cluster information and getting connection details') - log.debug( - f'Cluster ID: {cluster_id}, Host: {databricks_host}, Use Serverless: {use_serverless}', - ) - - method = 'dbsql' - dbsql = None - sparkSession = None - - if use_serverless: - method = 'dbconnect' - else: - w = WorkspaceClient() - res = w.clusters.get(cluster_id=cluster_id) - if res is None: - raise ClusterDoesNotExistError(cluster_id) - - assert res.spark_version is not None - stripped_runtime = re.sub( - r'[a-zA-Z]', - '', - res.spark_version.split('-scala') - [0].replace( # type: ignore - 'x-snapshot', '', - ), - ) - runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) - if version.parse( - runtime_version, - ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): - raise ValueError( - f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', - ) - - if http_path is None and version.parse( - runtime_version, - ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): - method = 'dbconnect' - - if method == 'dbconnect': - try: - if use_serverless: - session_id = str(uuid4()) - sparkSession = DatabricksSession.builder.host( - databricks_host, - ).token( - databricks_token, - ).header('x-databricks-session-id', session_id).getOrCreate() - - else: - sparkSession = DatabricksSession.builder.remote( - host=databricks_host, - token=databricks_token, - cluster_id=cluster_id, - ).getOrCreate() - - except Exception as e: - raise FailedToConnectToDatabricksError() from e - else: - try: - dbsql = sql.connect( - server_hostname=re.compile(r'^https?://').sub( - '', databricks_host).strip( - ), # sqlconnect hangs if hostname starts with https - http_path=http_path, - access_token=databricks_token, - ) - except Exception as e: - raise FailedToCreateSQLConnectionError() from e - return method, dbsql, sparkSession - - -def fetch_DT(args: Namespace) -> None: - """Fetch UC Delta Table to local as jsonl.""" - log.info(f'Start .... Convert delta to json') - log.info('Starting Delta Table to JSON conversion process') - log.info(f'Delta Table: {args.delta_table_name}') - log.info(f'Output Folder: {args.json_output_folder}') - log.info(f'Output Filename: {args.json_output_filename}') - - obj = urllib.parse.urlparse(args.json_output_folder) - if obj.scheme != '': - raise ValueError( - 'Check the json_output_folder and verify it is a local path!', - ) - - if os.path.exists(args.json_output_folder): - if not os.path.isdir(args.json_output_folder) or os.listdir( - args.json_output_folder, - ): - raise RuntimeError( - f'Output folder {args.json_output_folder} already exists and is not empty. Please remove it and retry.', - ) - - os.makedirs(args.json_output_folder, exist_ok=True) - - if not args.json_output_filename.endswith('.jsonl'): - raise ValueError('json_output_filename needs to be a jsonl file') - - log.info(f'Directory {args.json_output_folder} created.') - - method, dbsql, sparkSession = validate_and_get_cluster_info( - cluster_id=args.cluster_id, - databricks_host=args.DATABRICKS_HOST, - databricks_token=args.DATABRICKS_TOKEN, - http_path=args.http_path, - use_serverless=args.use_serverless, - ) - - args.delta_table_name = format_tablename(args.delta_table_name) - - fetch( - method, - args.delta_table_name, - args.json_output_folder, - args.batch_size, - args.processes, - sparkSession, - dbsql, - ) - - if dbsql is not None: - dbsql.close() - - # combine downloaded jsonl into one big jsonl for IFT - iterative_combine_jsons( - args.json_output_folder, - os.path.join(args.json_output_folder, args.json_output_filename), - ) - - log.info('Delta Table to JSON conversion completed successfully') - - if __name__ == '__main__': parser = ArgumentParser( description= @@ -719,11 +79,13 @@ def fetch_DT(args: Namespace) -> None: 'The name of the combined final jsonl that combines all partitioned jsonl', ) args = parser.parse_args() - w = WorkspaceClient() - args.DATABRICKS_HOST = w.config.host - args.DATABRICKS_TOKEN = w.config.token - - tik = time.time() - fetch_DT(args) - log.info(f'Elapsed time {time.time() - tik}') - log.info('Delta Table to JSON conversion script completed') + convert_delta_to_json_from_args( + delta_table_name=args.delta_table_name, + json_output_folder=args.json_output_folder, + http_path=args.http_path, + batch_size=args.batch_size, + processes=args.processes, + cluster_id=args.cluster_id, + use_serverless=args.use_serverless, + json_output_filename=args.json_output_filename, + ) diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index 523d45093d..b28e25786b 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -1,28 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json -import os -import platform -import warnings from argparse import ArgumentParser, Namespace -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Union -import datasets as hf_datasets -import psutil from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from streaming import MDSWriter -from torch.utils.data import DataLoader -from tqdm import tqdm -from llmfoundry.data.finetuning.collator import validate_target_settings -from llmfoundry.data.finetuning.tasks import ( - _get_example_type, - dataset_constructor, - is_valid_ift_example, - tokenize_formatted_example, -) -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.command_utils import convert_finetuning_dataset_from_args HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] @@ -116,236 +100,9 @@ def parse_args() -> Namespace: ) parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - if parsed.tokenizer_kwargs is not None: - parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs) - else: - parsed.tokenizer_kwargs = {} - - if len(parsed.data_files) > 0 and len( - parsed.data_files, - ) != len(parsed.splits): - raise ValueError( - f'If data_files is set, data_files and splits must have the same length. Got {len(parsed.data_files)=} while {len(parsed.splits)=}', - ) - return parsed -def build_dataloader( - dataset: HFDataset, - batch_size: int, - num_workers: Optional[int] = None, -) -> DataLoader: - if num_workers is None: - # Multiple workers is only supported on linux machines - if 'linux' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) - else: - num_workers = 0 - - # If using multiple workers, configure each worker to prefetch as many samples as it can, up to - # the aggregate device batch size - # If not using workers, the torch DataLoader expects the default value for prefetch_factor, - # which non-intuitively must be 2. - # If on macOS, PyTorch requires prefetch_factor set to None since num_workers is always zero - if 'macos' in platform.platform().lower() and num_workers == 0: - prefetch_factor = None - else: - prefetch_factor = max( - 1, - 2 * batch_size // num_workers, - ) if num_workers > 0 else 2 - - return DataLoader( - dataset=dataset, - sampler=None, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) - - -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. - - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - - -def get_columns_and_format( - dataset: HFDataset, - tokenizing: bool, - preprocessing_fn: Callable, -): - ex = preprocessing_fn(next(iter(dataset))) - example_type = _get_example_type(ex) - if tokenizing: - return {'turns': 'json'}, example_type - if example_type == 'chat': - # Chat format - return {'messages': 'json'}, example_type - else: - # Prompt-response format - return {'prompt': 'str', 'response': 'str'}, example_type - - -def main(args: Namespace) -> None: - """Main: create a streaming dataset. - - Args: - args (Namespace): Commandline arguments. - """ - if args.skip_preprocessing: - preprocessing_fn = lambda x: x # Just an identity function - else: - preprocessor_str = args.preprocessor - preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( - preprocessor=preprocessor_str, - dataset_name=args.dataset, - ) - if preprocessing_fn is None: - raise ValueError( - '`args.preprocessor` was not set and no preprocessing function ' +\ - 'has been registered for `args.dataset`. If this was intentional ' +\ - '(e.g., because your dataset is already correctly formatted), ' +\ - 'include the "--skip-preprocessing" flag to avoid this error.', - ) - - # Make sure the target settings are valid - validate_target_settings( - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - ) - - tokenizer = None - tokenizer_kwargs = args.tokenizer_kwargs - tokenizer_kwargs.update({'model_max_length': args.max_seq_len}) - if args.tokenizer: - tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs) - - for i, split_name in enumerate(args.splits): - data_file = None - if len(args.data_files) > 0: - data_file = args.data_files[i] - dataset = hf_datasets.load_dataset( - path=args.dataset, - name=args.data_subset, - split=split_name, - data_files=data_file, - streaming=True, - ) - # Determine the output columns - columns, example_type = get_columns_and_format( - dataset=dataset, - tokenizing=tokenizer is not None, - preprocessing_fn=preprocessing_fn, - ) - # Prepare the iterables - if example_type == 'chat': - samples = iter(dataset) - else: - loader = build_dataloader( - dataset=dataset, - batch_size=512, - num_workers=args.num_workers, - ) - samples = generate_samples(loader) - - # Write samples - print(f'Converting {split_name} to MDS format...') - out = os.path.join(args.out_root, split_name) - if args.local is not None: - out = (os.path.join(args.local, split_name), out) - keep_local = True - else: - keep_local = False - with MDSWriter( - columns=columns, - out=out, - compression=args.compression, - keep_local=keep_local, - ) as out: - examples_removed = 0 - for sample in tqdm(samples, desc=split_name): - formatted_sample = preprocessing_fn(sample) - assert isinstance(formatted_sample, dict) - - # Use the _get_example_type utility to confirm that the formatted sample - # can be interpreted by the tokenization code - try: - example_type = _get_example_type(formatted_sample) - except Exception as e: - raise ValueError( - 'Encountered an error when checking example for proper formatting. ' +\ - f'example={formatted_sample}', - ) from e - if tokenizer is not None: - sample = tokenize_formatted_example( - formatted_sample, - tokenizer=tokenizer, - ) - if not is_valid_ift_example( - args.max_seq_len, - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - example=sample, - ): - examples_removed += 1 - continue - - sample_to_write = {'turns': []} - for turn in sample['turns']: - turn_to_write = {} - for key in ['input_ids', 'labels']: - turn_to_write[key] = list(turn[key]) - sample_to_write['turns'].append(turn_to_write) - out.write(sample_to_write) - else: - if example_type == 'prompt_response': - encoded_sample = {} - for key in ['prompt', 'response']: - value = formatted_sample[key] - assert isinstance(value, str) - encoded_sample[key] = value.encode('utf-8') - out.write(encoded_sample) - else: - out.write(formatted_sample) - - if tokenizer is not None and examples_removed > 0: - warnings.warn( - f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, ' - + - 'the prompt or response was empty, or the response was all padding tokens.', - ) - - if __name__ == '__main__': """Example for converting Muennighoff/P3: @@ -355,4 +112,22 @@ def main(args: Namespace) -> None: >>> --preprocessor llmfoundry.data.finetuning.tasks:p3_preprocessing_function \ >>> --out_root s3:///muennighoff-p3 """ - main(parse_args()) + args = parse_args() + convert_finetuning_dataset_from_args( + dataset=args.dataset, + data_subset=args.data_subset, + splits=args.splits, + preprocessor=args.preprocessor, + data_files=args.data_files, + skip_preprocessing=args.skip_preprocessing, + out_root=args.out_root, + local=args.local, + compression=args.compression, + num_workers=args.num_workers, + tokenizer=args.tokenizer, + tokenizer_kwargs=args.tokenizer_kwargs, + max_seq_len=args.max_seq_len, + target_prompts=args.target_prompts, + target_responses=args.target_responses, + encoder_decoder=args.encoder_decoder, + ) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 8af8280465..c808fa871f 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -2,115 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import math -import os -import tempfile from argparse import ArgumentParser, Namespace -from concurrent.futures import ProcessPoolExecutor -from functools import partial -from glob import glob -from typing import Dict, Iterable, List, Tuple, cast -import numpy as np import psutil -from composer.utils import ( - ObjectStore, - maybe_create_object_store_from_uri, - parse_uri, -) -from numpy.typing import NDArray -from streaming import MDSWriter -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.data.data import AbstractConcatTokensDataset -from llmfoundry.utils.data_prep_utils import ( - DownloadingIterable, - download_file, - merge_shard_groups, -) -from llmfoundry.utils.exceptions import ( - InputFolderMissingDataError, - OutputFolderNotEmptyError, -) +from llmfoundry.command_utils import convert_text_to_mds_from_args log = logging.getLogger(__name__) DONE_FILENAME = '.text_to_mds_conversion_done' -class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): - """An IterableDataset that returns token samples for MDSWriter from files. - - Returns dicts of {'tokens': ndarray:int32} - - Each file is considered a sequence. - """ - - def __init__( - self, - files: Iterable[str], - tokenizer: PreTrainedTokenizerBase, - max_length: int, - bos_text: str, - eos_text: str, - no_wrap: bool, - ): - self.files = files - super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) - log.info(f'Initialized ConcatTokensFromFilesDataset.') - - def __iter__(self) -> Iterable[Dict[str, NDArray]]: - log.info( - 'Starting iteration over files in ConcatTokensFromFilesDataset', - ) - buffer = [] - for file in self.files: - log.info(f'Processing file: {file}') - with open(file, 'r') as f: - buffer += self.bos_tokens - first_chunk = True - # Read the file in 1MB chunks to avoid memory issues - for chunk in iter(partial(f.read, 1000000), ''): - # Tokenize the chunk - encoded = self.tokenizer( - chunk, - truncation=False, - padding=False, - ) - iids = encoded['input_ids'] - - # If this is not the first chunk, remove the BOS token - if not first_chunk: - if iids[0] == self.tokenizer.bos_token_id: - iids = iids[1:] - - # Add the tokens to the buffer - buffer += iids - while len(buffer) >= self.max_length: - concat_sample = buffer[:self.max_length] - buffer = buffer[self. - max_length:] if self.should_wrap else [] - yield { - 'tokens': np.asarray(concat_sample, dtype=np.int32), - } - - first_chunk = False - - # Add the EOS token to the buffer to separate files. - buffer += self.eos_tokens - - # Yield any remaining samples of size max_length. - while len(buffer) >= self.max_length: - concat_sample = buffer[:self.max_length] - buffer = buffer[self.max_length:] if self.should_wrap else [] - yield {'tokens': np.asarray(concat_sample, dtype=np.int32)} - - log.info( - 'Finished iterating over files in ConcatTokensFromFilesDataset', - ) - - def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( @@ -211,454 +113,23 @@ def parse_args() -> Namespace: help='Logging level for the script. Default is INFO.', ) parsed = parser.parse_args() - - # Set eos token. - if parsed.use_tokenizer_eos: - # Ensure that eos text is not specified twice. - if parsed.eos_text is not None: - parser.error( - 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.', - ) - tokenizer = AutoTokenizer.from_pretrained( - parsed.tokenizer, - trust_remote_code=parsed.trust_remote_code, - ) - parsed.eos_text = tokenizer.eos_token - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -def get_object_names(input_folder: str) -> List[str]: - """Get object names from a local or remote folder. - - Args: - input_folder (str): local or remote folder path. - """ - object_store = maybe_create_object_store_from_uri(input_folder) - if object_store is not None: - _, _, folder_prefix = parse_uri(input_folder) - names = [ - name for name in object_store.list_objects(folder_prefix) - if name.endswith('.txt') - ] - log.info(f'Found {len(names)} text files in remote storage') - else: - # input_folder is a local folder - names = [ - text_file for dirpath, _, _ in os.walk(input_folder) - for text_file in glob(os.path.join(dirpath, '*.txt')) - ] - # return names, sizes - log.info(f'Found {len(names)} text files at {input_folder}') - - return names - - -def get_task_args( - object_names: List[str], - output_root: str, - input_folder: str, - n_groups: int, - tokenizer_name: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - trust_remote_code: bool, -) -> Iterable: - """Get download_and_convert arguments split across n_groups. - - Each group handles a portion of object_names. - - Args: - object_names (List[str]): Names of objects to process - output_root (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - n_groups (int): Number of groups to split the object names into - tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - log.info( - f'Preparing task arguments for {len(object_names)} objects across {n_groups} groups', - ) - num_objects = len(object_names) - objs_per_group = math.ceil(num_objects / n_groups) - for group, i in enumerate(range(0, num_objects, objs_per_group)): - output_subdir = os.path.join(output_root, str(group)) - log.info( - f'Created task for group {group} with {min(objs_per_group, num_objects - i)} objects', - ) - yield ( - object_names[i:min(i + objs_per_group, num_objects)], - output_subdir, - input_folder, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - - -def download_and_convert_starargs(args: Tuple): - """Helper function to call download_and_convert with star args. - - This helps us use download_and_convert with multiprocessing. - """ - return download_and_convert(*args) - - -def download_and_convert( - file_names: List[str], - output_folder: str, - input_folder: str, - tokenizer_name: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - trust_remote_code: bool, -): - """Downloads and converts text files to MDS format. - - Args: - file_names (List[str]): Files to process - output_folder (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - log.info(f'Starting download and conversion for {len(file_names)} files') - - object_store = maybe_create_object_store_from_uri(input_folder) - - # Download file_names - with tempfile.TemporaryDirectory() as tmp_dir: - log.info(f'Created temporary directory: {tmp_dir}') - downloading_iter = DownloadingIterable( - object_names=file_names, - output_folder=tmp_dir, - object_store=object_store, - ) - log.info(f'Initializing tokenizer: {tokenizer_name}') - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - trust_remote_code=trust_remote_code, - ) - tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace - - # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up - # to the maximum sequence length - dataset = ConcatTokensFromFilesDataset( - files=downloading_iter, - max_length=concat_tokens, - tokenizer=tokenizer, - eos_text=eos_text, - bos_text=bos_text, - no_wrap=no_wrap, - ) - - columns = {'tokens': 'ndarray:int32'} - - log.info('Converting to MDS format...') - with MDSWriter( - out=output_folder, - columns=columns, - compression=compression, - ) as out: - for sample in tqdm(dataset): - out.write(sample) - - log.info(f'Completed download and conversion for {len(file_names)} files') - - -def is_remote_path(path: str) -> bool: - """Checks whether a path is a remote path. - - Args: - path (str): path to check - """ - backend, _, _ = parse_uri(path) - return backend != '' - - -def is_already_processed( - output_root: str, - args_str: str, - object_names: List[str], -) -> bool: - """Determines whether a group of text files has already been processed. - - Checks the done fie at output root to determine this. - - Args: - output_root (str): Output folder where a done file may exist - args_str (str): String representation of the arguments - object_names (List[str]): Names of objects to convert to MDS format - """ - log.info( - f'Checking if {len(object_names)} objects have already been processed in {output_root}', - ) - - # Retrieve the done file contents - output_object_store = maybe_create_object_store_from_uri(output_root) - if output_object_store is not None: - # Download and read the done file from the remote object store - _, _, output_folder_prefix = parse_uri(output_root) - try: - with tempfile.TemporaryDirectory() as tmp_dir: - done_file = os.path.join(tmp_dir, DONE_FILENAME) - download_file( - object_store=output_object_store, - object_name=os.path.join( - output_folder_prefix, - DONE_FILENAME, - ), - output_filename=done_file, - ) - with open(done_file) as df: - done_file_contents = df.read().splitlines() - log.info(f'Retrieved done file contents from remote storage') - except FileNotFoundError: - log.info('Done file not found in remote storage') - return False - else: - # Read the local done file - done_file = os.path.join(output_root, DONE_FILENAME) - if not os.path.isfile(done_file): - log.info('Done file not found in local storage') - return False - with open(done_file) as df: - done_file_contents = df.read().splitlines() - log.info(f'Retrieved done file contents from local storage') - - # Compare the arguments - prev_args_str = done_file_contents[0] - if prev_args_str != args_str: - log.info('Arguments have changed, reprocessing required') - return False - - # Compare file names - prev_names = done_file_contents[1:] - if len(prev_names) != len(object_names): - log.info('Number of files has changed, reprocessing required') - return False - for idx, prev_name in enumerate(prev_names): - if object_names[idx] != prev_name: - log.info('File names have changed, reprocessing required') - return False - - log.info('All files have already been processed') - return True - - -def write_done_file(folder: str, args_str: str, object_names: List[str]): - """Write a file to signify completion. - - This the done file includes the arguments to processing and - a list of objects that were processed. - - Args: - folder (str): Folder to write the done file to - args_str (str): String representation of arguments - object_names (List[str]): List of objects to convert to MDS format - """ - with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file: - log.info(f'Writing done file.') - done_file.write('\n'.join([args_str] + object_names) + '\n') - log.info(f'Done file written successfully') - - -def convert_text_to_mds( - tokenizer_name: str, - output_folder: str, - input_folder: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - processes: int, - args_str: str, - reprocess: bool, - trust_remote_code: bool, -): - """Convert a folder of text files to MDS format. - - Args: - tokenizer_name (str): Name of tokenizer to use - output_folder (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - processes (int): The number of processes to use. - args_str (str): String representation of the arguments - reprocess (bool): Whether to always reprocess the given folder of text files - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - is_remote_output = is_remote_path(output_folder) - log.info(f'Output is remote: {is_remote_output}') - - object_names = get_object_names(input_folder) - if len(object_names) == 0: - log.error(f'No text files found in input folder: {input_folder}') - raise InputFolderMissingDataError(input_folder) - - # Check if the text files in the bucket have already been processed. - if not reprocess and is_already_processed( - output_folder, - args_str, - object_names, - ): - log.info( - f'Input folder {input_folder} is already processed at {output_folder} and ' - + - 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', - ) - return - - # Use a temporary local directory if the output is remote and there are more than 1 processes - local_output_folder = tempfile.TemporaryDirectory( - ).name if is_remote_output else output_folder - log.info(f'Using local output folder: {local_output_folder}') - - if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: - log.error(f'Output folder is not empty: {output_folder}') - raise OutputFolderNotEmptyError(output_folder) - - if processes > 1: - log.info(f'Using multiprocessing with {processes} processes') - # Download and convert the text files in parallel - args = get_task_args( - object_names, - local_output_folder, - input_folder, - processes, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_and_convert_starargs, args)) - - log.info('Merging MDS shards from each process') - # Merge the mds shards from each of the processes into a single folder - merge_shard_groups(local_output_folder) - else: - log.info('Using single process for download and conversion') - download_and_convert( - object_names, - local_output_folder, - input_folder, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - - # Write a done file with the args and object names - write_done_file(local_output_folder, args_str, object_names) - - if is_remote_output: - # Upload the local output to the remote location - output_object_store = cast( - ObjectStore, - maybe_create_object_store_from_uri(output_folder), - ) - _, _, output_folder_prefix = parse_uri(output_folder) - files_to_upload = os.listdir(local_output_folder) - - for file in files_to_upload: - assert not os.path.isdir(file) - remote_path = os.path.join(output_folder_prefix, file) - output_object_store.upload_object( - remote_path, - os.path.join(local_output_folder, file), - ) - - -def _args_str(original_args: Namespace) -> str: - """Create a string from the args to determine whether to reprocess. - - Args: - original_args (Namespace): Arguments to main function. - """ - # Take the arguments that influence the final result. - # reprocess and max_mds_writer_workers are not taken. - args = Namespace( - tokenizer_name=original_args.tokenizer, - output_folder=original_args.output_folder, - input_folder=original_args.input_folder, - concat_tokens=original_args.concat_tokens, - eos_text=original_args.eos_text, - bos_text=original_args.bos_text, - no_wrap=original_args.no_wrap, - compression=original_args.compression, - processes=original_args.processes, - ) - - return str(args) - - -def _configure_logging(logging_level: str): - """Configure logging. - - Args: - logging_level (str): Logging level. - """ - logging.basicConfig( - format= - f'%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging_level = logging_level.upper() - logging.getLogger('llmfoundry').setLevel(logging_level) - logging.getLogger(__name__).setLevel(logging_level) - log.info(f'Logging level set to {logging_level}') - - if __name__ == '__main__': args = parse_args() - _configure_logging(args.logging_level) - convert_text_to_mds( - tokenizer_name=args.tokenizer, + convert_text_to_mds_from_args( output_folder=args.output_folder, input_folder=args.input_folder, + compression=args.compression, concat_tokens=args.concat_tokens, - eos_text=args.eos_text, + tokenizer_name=args.tokenizer, bos_text=args.bos_text, + eos_text=args.eos_text, + use_tokenizer_eos=args.use_tokenizer_eos, no_wrap=args.no_wrap, - compression=args.compression, processes=args.processes, reprocess=args.reprocess, trust_remote_code=args.trust_remote_code, - args_str=_args_str(args), + logging_level=args.logging_level, ) diff --git a/scripts/eval/yamls/long_context_tasks.yaml b/scripts/eval/yamls/long_context_tasks.yaml index 153e3b9df6..221635da87 100644 --- a/scripts/eval/yamls/long_context_tasks.yaml +++ b/scripts/eval/yamls/long_context_tasks.yaml @@ -105,7 +105,7 @@ icl_tasks: icl_task_type: generation_task_with_answers hf_loading_vars: name: wikiqa - context_length: 2048 + context_length: 4096 split: test - label: wikiqa_8k @@ -114,7 +114,7 @@ icl_tasks: icl_task_type: generation_task_with_answers hf_loading_vars: name: wikiqa - context_length: 2048 + context_length: 8192 split: test - label: hotpotqa_beginning_2k diff --git a/setup.py b/setup.py index 309d7d3372..04c28d8f70 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import copy import os -import re +from typing import Any, Dict, Mapping import setuptools from setuptools import setup @@ -15,17 +15,15 @@ _REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) _PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) -# Read the repo version +# Read the llm-foundry version # We can't use `.__version__` from the library since it's not installed yet -with open(os.path.join(_PACKAGE_REAL_PATH, '__init__.py')) as f: +version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') +with open(version_path, encoding='utf-8') as f: + version_globals: Dict[str, Any] = {} + version_locals: Mapping[str, object] = {} content = f.read() -# regex: '__version__', whitespace?, '=', whitespace, quote, version, quote -# we put parens around the version so that it becomes elem 1 of the match -expr = re.compile( - r"""^__version__\s*=\s*['"]([0-9]+\.[0-9]+\.[0-9]+(?:\.\w+)?)['"]""", - re.MULTILINE, -) -repo_version = expr.findall(content)[0] + exec(content, version_globals, version_locals) + repo_version = str(version_locals['__version__']) # Use repo README for PyPi description with open('README.md', 'r', encoding='utf-8') as fh: @@ -55,10 +53,10 @@ install_requires = [ 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24', - 'mlflow>=2.14.1,<2.15', - 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.42.3,<4.43', - 'mosaicml-streaming>=0.7.6,<0.8', + 'mlflow>=2.14.1,<2.16', + 'accelerate>=0.25,<0.34', # for HF inference `device_map` + 'transformers>=4.43.2,<4.44', + 'mosaicml-streaming>=0.8.0,<0.9', 'torch>=2.3.0,<2.4', 'datasets>=2.19,<2.20', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data @@ -67,10 +65,10 @@ 'omegaconf>=2.2.3,<3', 'slack-sdk<4', 'mosaicml-cli>=0.6.10,<1', - 'onnx==1.16.1', + 'onnx==1.16.2', 'onnxruntime==1.18.1', 'boto3>=1.21.45,<2', - 'huggingface-hub>=0.19.0,<0.24', + 'huggingface-hub>=0.19.0,<0.25', 'beautifulsoup4>=4.12.2,<5', # required for model download utils 'tenacity>=8.2.3,<9', 'catalogue>=2,<3', @@ -84,7 +82,7 @@ 'pre-commit>=3.4.0,<4', 'pytest>=7.2.1,<8', 'pytest_codeblocks>=0.16.1,<0.17', - 'pytest-cov>=4,<5', + 'pytest-cov>=4,<6', 'pyright==1.1.256', 'toml>=0.10.2,<0.11', 'packaging>=21,<23', @@ -104,7 +102,7 @@ # Flash 2 group kept for backwards compatibility extra_deps['gpu-flash2'] = [ - 'flash-attn==2.5.8', + 'flash-attn>=2.5.8,<3', ] extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2']) diff --git a/tests/a_scripts/data_prep/test_convert_dataset_json.py b/tests/a_scripts/data_prep/test_convert_dataset_json.py index 912e44cd0c..4f70a35637 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_json.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_json.py @@ -2,28 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 import os -from argparse import Namespace from pathlib import Path -from scripts.data_prep.convert_dataset_json import main as main_json +from llmfoundry.command_utils import convert_dataset_json def test_json_script_from_api(tmp_path: Path): # test calling it directly path = os.path.join(tmp_path, 'my-copy-arxiv-1') - main_json( - Namespace( - **{ - 'path': 'scripts/data_prep/example_data/arxiv.jsonl', - 'out_root': path, - 'compression': None, - 'split': 'train', - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_json( + path='scripts/data_prep/example_data/arxiv.jsonl', + out_root=path, + compression=None, + split='train', + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 83d6edeca2..e623467bf7 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -1,15 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -# copyright 2022 mosaicml llm foundry authors -# spdx-license-identifier: apache-2.0 - import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, mock_open, patch -from scripts.data_prep.convert_delta_to_json import ( +from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( download, fetch_DT, format_tablename, @@ -20,11 +17,19 @@ class TestConvertDeltaToJsonl(unittest.TestCase): - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sdk.WorkspaceClient', + ) def test_stream_delta_to_json( self, mock_workspace_client: Any, @@ -33,19 +38,15 @@ def test_stream_delta_to_json( mock_makedirs: Any, mock_sql_connect: Any, ): - - args = MagicMock() - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' - args.DATABRICKS_HOST = 'test_host' - args.DATABRICKS_TOKEN = 'test_token' - args.http_path = 'test_path' - args.batch_size = 1000 - args.partitions = 1 - args.cluster_id = '1234' - args.debug = False - args.use_serverless = False - args.json_output_filename = 'combined.jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' + DATABRICKS_HOST = 'test_host' + DATABRICKS_TOKEN = 'test_token' + http_path = 'test_path' + batch_size = 1000 + cluster_id = '1234' + use_serverless = False + json_output_filename = 'combined.jsonl' mock_cluster_get = MagicMock() mock_cluster_get.return_value = MagicMock( @@ -53,7 +54,17 @@ def test_stream_delta_to_json( ) mock_workspace_client.return_value.clusters.get = mock_cluster_get - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + batch_size=batch_size, + json_output_filename=json_output_filename, + ) mock_sql_connect.assert_called_once_with( server_hostname='test_host', http_path='test_path', @@ -66,7 +77,9 @@ def test_stream_delta_to_json( '/path/to/jsonl/combined.jsonl', ) - @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.listdir', + ) @patch( 'builtins.open', new_callable=mock_open, @@ -102,7 +115,9 @@ def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): """ self.assertEqual(mock_file().write.call_count, 2) - @patch('scripts.data_prep.convert_delta_to_json.SparkSession') + @patch( + 'pyspark.sql.SparkSession', + ) def test_run_query_dbconnect(self, mock_spark: Any): method = 'dbconnect' mock_cursor = None @@ -118,7 +133,9 @@ def test_run_query_dbconnect(self, mock_spark: Any): mock_spark.sql.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') - @patch('scripts.data_prep.convert_delta_to_json.Cursor') + @patch( + 'databricks.sql.client.Cursor', + ) def test_run_query_dbsql(self, mock_cursor: Any): method = 'dbsql' mock_cursor.fetchall.return_value = 'result' @@ -134,14 +151,18 @@ def test_run_query_dbsql(self, mock_cursor: Any): mock_cursor.execute.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') - @patch('scripts.data_prep.convert_delta_to_json.requests.get') - @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') @patch( - 'scripts.data_prep.convert_delta_to_json.os.path.join', + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.requests.get', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.pd.DataFrame.to_json', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.path.join', return_value='/fake/path/part_1.jsonl', ) @patch( - 'scripts.data_prep.convert_delta_to_json.time.sleep', + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.time.sleep', ) # Mock sleep to speed up the test def test_download_success( self, @@ -174,12 +195,22 @@ def test_download_success( mock_get.assert_called_once_with('http://fakeurl.com/data') - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_dbconnect_called( self, mock_fetch: Any, @@ -189,17 +220,14 @@ def test_dbconnect_called( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = None - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = None + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -209,19 +237,37 @@ def test_dbconnect_called( ) # Mock return value for getOrCreate mock_databricks_session.builder.remote.return_value = mock_remote - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_databricks_session.builder.remote.assert_called_once_with( - host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id, + host=DATABRICKS_HOST, + token=DATABRICKS_TOKEN, + cluster_id=cluster_id, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_dbr13( self, mock_fetch: Any, @@ -231,34 +277,49 @@ def test_sqlconnect_called_dbr13( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( - server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + server_hostname=DATABRICKS_HOST, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_dbr14( self, mock_fetch: Any, @@ -268,34 +329,49 @@ def test_sqlconnect_called_dbr14( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( - server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + server_hostname=DATABRICKS_HOST, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_https( self, mock_fetch: Any, @@ -305,34 +381,49 @@ def test_sqlconnect_called_https( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'https://test-host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( server_hostname='test-host', - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_serverless( self, mock_fetch: Any, @@ -342,22 +433,27 @@ def test_serverless( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'https://test-host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = True + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'token' + use_serverless = True mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) assert not mock_sql_connect.called assert not mock_databricks_session.builder.remote.called diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index 8dac151f55..6ba14d62e4 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -13,11 +13,7 @@ from streaming import StreamingDataset from transformers import AutoTokenizer -from llmfoundry.utils.exceptions import ( - InputFolderMissingDataError, - OutputFolderNotEmptyError, -) -from scripts.data_prep.convert_text_to_mds import ( +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( DONE_FILENAME, convert_text_to_mds, download_and_convert, @@ -25,6 +21,11 @@ merge_shard_groups, write_done_file, ) +from llmfoundry.utils.exceptions import ( + DatasetTooSmallError, + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) class MockObjectStore(): @@ -83,15 +84,15 @@ def _assert_files_exist(prefix: str, files: List[str]): @pytest.mark.parametrize('processes', [1, 2, 3]) @patch.object(ProcessPoolExecutor, 'map', new=Mock(wraps=_mock_map)) @patch( - 'scripts.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', ) -@patch('scripts.data_prep.convert_text_to_mds.parse_uri') +@patch('llmfoundry.command_utils.data_prep.convert_text_to_mds.parse_uri') @patch( - 'scripts.data_prep.convert_text_to_mds.download_and_convert', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.download_and_convert', wraps=download_and_convert, ) @patch( - 'scripts.data_prep.convert_text_to_mds.merge_shard_groups', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.merge_shard_groups', wraps=merge_shard_groups, ) def test_single_and_multi_process( @@ -267,6 +268,28 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path): ) +def test_dataset_too_small(tmp_path: pathlib.Path): + input_folder = tmp_path / 'input' + os.makedirs(input_folder, exist_ok=True) + with open(input_folder / 'test.txt', 'w') as f: + f.write('a') + with pytest.raises(DatasetTooSmallError): + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(tmp_path / 'output'), + input_folder=str(input_folder), + concat_tokens=2048, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=False, + trust_remote_code=False, + ) + + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) args_str = 'Namespace(x = 5)' diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 2ef458fece..cd47b2df7c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -383,6 +383,14 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' + mlflow_logger_mock._enabled = True + mlflow_logger_mock.run_url = 'fake-url' + checkpointer_callback.transform_model_pre_registration = MagicMock( + wraps=checkpointer_callback.transform_model_pre_registration, + ) + checkpointer_callback.pre_register_edit = MagicMock( + wraps=checkpointer_callback.pre_register_edit, + ) trainer = Trainer( model=original_model, device='gpu', @@ -406,9 +414,14 @@ def test_huggingface_conversion_callback_interval( task='llm/v1/completions', input_example=ANY, metadata={}, + pip_requirements=ANY, ) + assert checkpointer_callback.transform_model_pre_registration.call_count == 1 + assert checkpointer_callback.pre_register_edit.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: + assert checkpointer_callback.transform_model_pre_registration.call_count == 0 + assert checkpointer_callback.pre_register_edit.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 @@ -582,6 +595,7 @@ def _assert_mlflow_logger_calls( 'task': 'llm/v1/completions', 'input_example': default_input_example, 'metadata': {}, + 'pip_requirements': ANY, } mlflow_logger_mock.save_model.assert_called_with(**expectation) assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 21d73c0d34..1a43e12536 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -7,7 +7,7 @@ import shutil from contextlib import nullcontext as does_not_raise from pathlib import Path -from typing import ContextManager, Literal, Optional, Union +from typing import Any, Callable, ContextManager, Dict, Literal, Optional, Union from unittest.mock import MagicMock, patch import catalogue @@ -22,6 +22,8 @@ from streaming.base.util import clean_stale_shared_memory from llmfoundry.command_utils import convert_dataset_hf +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \ + get_columns_and_format from llmfoundry.data import build_dataloader, build_finetuning_dataloader from llmfoundry.data.finetuning.collator import ( _HF_IGNORE_INDEX, @@ -55,8 +57,6 @@ NotEnoughDatasetSamplesError, UnknownExampleTypeError, ) -# yapf: enable -from scripts.data_prep.convert_finetuning_dataset import get_columns_and_format from tests.data_utils import ( make_tiny_conversation_ft_dataset, make_tiny_ft_dataset, @@ -1220,6 +1220,21 @@ def test_token_counting_func_dataloader_setting( 'timeout': 0, } + def build_from_hf( + self, # type: ignore + dataset_name: str, + split: str, + safe_load: bool = False, + max_seq_len: int = 2048, + preprocessing_fn: Optional[Callable] = None, + tokenizer: transformers.PreTrainedTokenizerBase = None, + target_prompts: str = 'last', + target_responses: str = 'none', + decoder_only_format: bool = True, + hf_kwargs: Optional[Dict[str, Any]] = None, + ): + return [] + if dataloader_type == 'finetuning-hf': cfg = DictConfig({ 'dataset': { @@ -1235,8 +1250,7 @@ def test_token_counting_func_dataloader_setting( }) monkeypatch.setattr( 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', - lambda *args, - **kwargs: [], + build_from_hf, ) dl = build_finetuning_dataloader( tokenizer=gptt, diff --git a/tests/data_utils.py b/tests/data_utils.py index 35e11db531..ea64943735 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -4,16 +4,16 @@ import json import os import shutil -from argparse import Namespace from pathlib import Path from typing import Dict, List, Optional from omegaconf import DictConfig from omegaconf import OmegaConf as om -from llmfoundry.command_utils import convert_dataset_hf -from scripts.data_prep.convert_dataset_json import \ - main as main_json # noqa: E402 +from llmfoundry.command_utils import ( + convert_dataset_hf, + convert_dataset_json, +) def make_tiny_ft_dataset( @@ -265,20 +265,16 @@ def create_arxiv_dataset(path: Path) -> str: if not os.getcwd().endswith('scripts'): arxiv_path = os.path.join('scripts', arxiv_path) - main_json( - Namespace( - **{ - 'path': arxiv_path, - 'out_root': arxiv_dir, - 'compression': None, - 'split': downloaded_split, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_json( + path=arxiv_path, + out_root=arxiv_dir, + compression=None, + split=downloaded_split, + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, ) return arxiv_dir diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py index d0ec544de8..844ccd7fe5 100644 --- a/tests/models/hf/test_hf_config.py +++ b/tests/models/hf/test_hf_config.py @@ -7,9 +7,11 @@ from unittest.mock import Mock, patch import pytest +import torch from omegaconf import OmegaConf as om from transformers import PretrainedConfig +from llmfoundry.models.hf.hf_fsdp import rgetattr from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model @@ -235,3 +237,45 @@ def test_nested_override(): assert isinstance(model.config.ffn_config, PretrainedConfig) # Ensure the other values still exist and are not set back to their defaults assert model.config.ffn_config.moe_num_experts == 16 + + +@pytest.mark.gpu +def test_use_flash(): + model_cfg = { + 'name': 'hf_causal_lm', + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', + 'config_overrides': { + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + 'torch_dtype': 'bfloat16', + }, + 'pretrained': False, + 'init_device': 'cpu', + 'use_flash_attention_2': True, + } + + name = model_cfg.pop('name') + model = build_composer_model( + name=name, + cfg=model_cfg, + tokenizer=None, # type: ignore + ) + + from transformers.models.llama.modeling_llama import ( + LlamaFlashAttention2, + ) + flash_attn_class = LlamaFlashAttention2 + attention_layers_attr = 'model.model.layers' + attention_attr = 'self_attn' + + # check that it actually used flash attention 2 + assert model.model.config._attn_implementation == ('flash_attention_2') + attention_layer = rgetattr( + rgetattr(model, attention_layers_attr)[0], + attention_attr, + ) + assert isinstance(attention_layer, flash_attn_class) + + # Make sure that HF has not cast the parameters to bf16 + assert next(model.parameters()).dtype == torch.float32 diff --git a/tests/models/hf/test_hf_transform.py b/tests/models/hf/test_hf_transform.py new file mode 100644 index 0000000000..f479b50f73 --- /dev/null +++ b/tests/models/hf/test_hf_transform.py @@ -0,0 +1,76 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +import pytest +from composer.models.huggingface import maybe_get_underlying_model +from peft import PeftConfig, PeftModel +from transformers import LlamaForCausalLM, PreTrainedModel + +from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM +from llmfoundry.models.utils import init_empty_weights + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'peft_config', + [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'r': 2, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }, + ], +) +def test_hf_transform(peft_config: Optional[dict]): + model_cfg = { + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', + 'config_overrides': { + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + }, + 'pretrained': False, + 'peft_config': peft_config, + 'init_device': 'meta', + 'tokenizer': 'codellama/CodeLlama-7b-hf', + } + + class TransformedHFCausalLM(ComposerHFCausalLM): + + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + assert isinstance(model, LlamaForCausalLM) + with init_empty_weights(): + model.config.num_hidden_layers = 1 + new_model = type(model)(model.config) + return new_model + + def get_peft_config( + self, + peft_config_dict: Dict[str, Any], + ) -> PeftConfig: + peft_config_dict['target_modules'] = ['o_proj'] + return super().get_peft_config(peft_config_dict) + + composer_model = TransformedHFCausalLM(**model_cfg) + model = composer_model.model + inner_model = maybe_get_underlying_model(model) + + if peft_config: + peft_model = composer_model.model + assert isinstance(peft_model, PeftModel) + + target_modules = peft_model.peft_config[peft_model.active_adapter + ].target_modules + assert list(target_modules) == ['o_proj'] + + assert isinstance(inner_model, LlamaForCausalLM) + assert inner_model.config.num_hidden_layers == 1 diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py new file mode 100644 index 0000000000..bdffe2b49f --- /dev/null +++ b/tests/models/layers/test_attention.py @@ -0,0 +1,160 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from llmfoundry.models.layers.layer_builders import build_attention_layer + + +@pytest.mark.parametrize( + 'attn_name', + ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'], +) +@pytest.mark.parametrize('dim', [1024]) +def test_unfused_wqkv(attn_name: str, dim: int): + d_head = 128 + n_heads = dim // d_head + + generic_attn_kwargs = { + 'd_model': dim, + 'n_heads': n_heads, + 'fc_type': { + 'name': 'torch', + }, + 'device': 'cpu', + 'attn_pdrop': 0.0, + 'attn_impl': 'torch', + 'qk_ln': False, + 'qk_gn': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'sliding_window_size': -1, + } + + if attn_name == 'grouped_query_attention': + kv_n_heads = 2 + generic_attn_kwargs['kv_n_heads'] = kv_n_heads + elif attn_name == 'multiquery_attention': + kv_n_heads = 1 + elif attn_name == 'multihead_attention': + kv_n_heads = n_heads + else: + raise ValueError(f'Unknown attention name: {attn_name}') + + attn_config_fused = generic_attn_kwargs.copy() + attn_config_fused['fused_qkv'] = True + + attn_config_unfused = generic_attn_kwargs.copy() + attn_config_unfused['fused_qkv'] = False + + attn_fused = build_attention_layer( + name=attn_name, + attn_kwargs=attn_config_fused, + ) + attn_unfused = build_attention_layer( + name=attn_name, + attn_kwargs=attn_config_unfused, + ) + + # Make sure unfused attention has the same params as the fused one. + fused_wqkv = attn_fused.Wqkv.weight.detach().clone() + kv_heads_len = (fused_wqkv.shape[0] - dim) // 2 + Wq_shape_before = (attn_unfused.Wq.weight.shape, attn_unfused.Wq.bias.shape) + Wk_shape_before = (attn_unfused.Wk.weight.shape, attn_unfused.Wk.bias.shape) + Wv_shape_before = (attn_unfused.Wv.weight.shape, attn_unfused.Wv.bias.shape) + + attn_unfused.Wq.weight.data = fused_wqkv[:dim, :] + attn_unfused.Wk.weight.data = fused_wqkv[dim:dim + kv_heads_len, :] + attn_unfused.Wv.weight.data = fused_wqkv[dim + kv_heads_len:, :] + attn_unfused.out_proj.weight.data = attn_fused.out_proj.weight + attn_unfused.Wq.bias.data = attn_fused.Wqkv.bias[:dim] + attn_unfused.Wk.bias.data = attn_fused.Wqkv.bias[dim:dim + kv_heads_len] + attn_unfused.Wv.bias.data = attn_fused.Wqkv.bias[dim + kv_heads_len:] + attn_unfused.out_proj.bias.data = attn_fused.out_proj.bias + + # Make sure initialization fuse splits are as expected. + all_fuse_splits = ( + 0, + [i * d_head for i in range(1, n_heads + 2 * kv_n_heads)], + ) + q_fuse_splits = (0, [i * d_head for i in range(1, n_heads)]) + kv_fuse_splits = (0, [i * d_head for i in range(1, kv_n_heads)]) + + assert attn_fused.Wqkv._fused == all_fuse_splits + assert attn_unfused.Wq._fused == q_fuse_splits + assert attn_unfused.Wk._fused == kv_fuse_splits + assert attn_unfused.Wv._fused == kv_fuse_splits + + assert torch.allclose( + attn_fused.Wqkv.weight, + torch.cat( + [ + attn_unfused.Wq.weight, + attn_unfused.Wk.weight, + attn_unfused.Wv.weight, + ], + dim=0, + ), + ) + assert torch.allclose( + attn_fused.Wqkv.bias, + torch.cat( + [ + attn_unfused.Wq.bias, + attn_unfused.Wk.bias, + attn_unfused.Wv.bias, + ], + dim=0, + ), + ) + assert torch.allclose( + attn_fused.out_proj.weight, + attn_unfused.out_proj.weight, + ) + assert torch.allclose(attn_fused.out_proj.bias, attn_unfused.out_proj.bias) + + assert Wq_shape_before == ( + attn_unfused.Wq.weight.shape, + attn_unfused.Wq.bias.shape, + ) + assert Wk_shape_before == ( + attn_unfused.Wk.weight.shape, + attn_unfused.Wk.bias.shape, + ) + assert Wv_shape_before == ( + attn_unfused.Wv.weight.shape, + attn_unfused.Wv.bias.shape, + ) + + x1 = torch.randn(1, 1, dim) + x2 = x1.detach().clone() + x1.requires_grad = True + x2.requires_grad = True + + out_fused, _, _ = attn_fused(x1) + out_unfused, _, _ = attn_unfused(x2) + + assert torch.allclose(out_fused, out_unfused) + + # Dummy loss function is simply the sum. + loss_fused = out_fused.sum() + loss_fused.backward() + + loss_unfused = out_unfused.sum() + loss_unfused.backward() + + assert isinstance(x1.grad, torch.Tensor) + assert isinstance(x2.grad, torch.Tensor) + assert torch.allclose(x1.grad, x2.grad) + combined_grad = torch.concat( + [ + attn_unfused.Wq.weight.grad, + attn_unfused.Wk.weight.grad, + attn_unfused.Wv.weight.grad, + ], + dim=0, + ) + assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor) + assert isinstance(combined_grad, torch.Tensor) + assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad) diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py new file mode 100644 index 0000000000..bb78763f58 --- /dev/null +++ b/tests/models/layers/test_ffn.py @@ -0,0 +1,73 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +from llmfoundry.models.layers.ffn import quickgelu_activation +from llmfoundry.models.layers.layer_builders import build_ffn + + +@pytest.mark.gpu +def test_quickgelu_activation(): + d_model = 32 + expansion_ratio = 1 + no_bias = True + ffn_config = { + 'ffn_act_fn': { + 'name': 'quick_gelu', + }, + 'ffn_type': 'mptmlp', + } + rank: int = dist.get_rank() + device_str = f'cuda:{rank}' + device: torch.device = torch.device(device_str) + + ffn1 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + assert ( + ffn1.act == quickgelu_activation + ), f'Expected quick_gelu activation function, got {ffn1.act}' + + ffn_config = { + 'ffn_act_fn': { + 'name': 'gelu', + }, + 'ffn_type': 'mptmlp', + } + ffn2 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + + def num_params(model: nn.Module) -> int: + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([p.numel() for p in model_parameters]) + + ffn1_numparams = num_params(ffn1) + ffn2_numparams = num_params(ffn2) + assert ( + ffn1_numparams == ffn2_numparams + ), 'Only activation paths should have changed, re-check modeling!' + + input_ = torch.rand(1, d_model, device=device) + output1 = ffn1(input_) + output2 = ffn2(input_) + assert ( + output1.numel() == output2.numel() + ), 'Only activation paths should have changed, re-check modeling!' + assert ( + not torch.allclose(output1, output2) + ), 'Functions are different, outputs should not match!' diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 01d982052f..4bfdfb84dc 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -251,12 +251,13 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens @@ -664,12 +665,13 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg['d_model'] // cfg['n_heads'], rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg['d_model'], + n_heads=cfg['n_heads'], ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 45378e42bd..ed40e7a88a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -35,8 +35,7 @@ ) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from llmfoundry import ComposerHFCausalLM from llmfoundry.layers_registry import norms @@ -48,7 +47,7 @@ ) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel -from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry +from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -2924,7 +2923,7 @@ def test_hf_rotary_child_class_builds(): list(range(max_seq_len)), ] * bsz) - rot_emb_mp = HFRotaryEmbeddingFoundry( + rot_emb_mp = LlamaRotaryEmbeddingFoundry( rope_head_dim, max_seq_len, rope_theta, @@ -2932,7 +2931,7 @@ def test_hf_rotary_child_class_builds(): ) cos_mp, sin_mp = rot_emb_mp(value, position_ids) - rot_emb = HFRotaryEmbedding( + rot_emb = LlamaRotaryEmbedding( rope_head_dim, max_seq_len, rope_theta, diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 6a41e64f48..34fb23f670 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): } dail_rope = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=dail_rope_config['rope_impl'], rope_theta=dail_rope_config['rope_theta'], rope_dail_config=dail_rope_config['rope_dail_config'], rope_hf_config={}, max_seq_len=seq_len, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to('cuda') dail_rope_w_meta_info = { 'impl': 'dail', @@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): } hf_rope = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=hf_rope_config['rope_impl'], rope_theta=hf_rope_config['rope_theta'], rope_dail_config={}, rope_hf_config=hf_rope_config['rope_hf_config'], max_seq_len=seq_len, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to('cuda') pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') # adjust the position indices to account for padding tokens diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py new file mode 100644 index 0000000000..484ac2b23a --- /dev/null +++ b/tests/models/test_rope_scaling.py @@ -0,0 +1,35 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + +rope_config = { + 'rope_theta': 500000.0, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'factor': 8.0, + 'low_freq_factor': 1.0, + 'high_freq_factor': 4.0, + 'original_max_position_embeddings': 8192, + 'type': 'llama3', + }, +} + +rope_dail_config = {} + + +def test_rope_scaling(): + d_model = 128 + n_heads = 32 + max_seq_len = 65536 + + embedding = gen_rotary_embedding( + d_model=d_model, + n_heads=n_heads, + rope_dail_config=rope_dail_config, + max_seq_len=max_seq_len, + **rope_config, + ) + + assert isinstance(embedding, LlamaRotaryEmbedding) diff --git a/tests/test_registry.py b/tests/test_registry.py index 7ee95442c8..c4d1a1bcd5 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -24,6 +24,7 @@ def test_expected_registries_exist(): 'loggers', 'optimizers', 'schedulers', + 'tokenizers', 'callbacks', 'algorithms', 'callbacks_with_config', @@ -44,6 +45,8 @@ def test_expected_registries_exist(): 'fcs', 'icl_datasets', 'config_transforms', + 'load_planners', + 'save_planners', } assert existing_registries == expected_registry_names diff --git a/tests/tokenizers/test_registry.py b/tests/tokenizers/test_registry.py new file mode 100644 index 0000000000..920c207a64 --- /dev/null +++ b/tests/tokenizers/test_registry.py @@ -0,0 +1,35 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +from transformers import PreTrainedTokenizer + +from llmfoundry.registry import tokenizers +from llmfoundry.utils import build_tokenizer + + +class DummyTokenizer(PreTrainedTokenizer): + """A dummy tokenizer that inherits from ``PreTrainedTokenizer``.""" + + def __init__( + self, + model_name: Optional[str] = 'dummy', + **kwargs: Optional[Dict[str, Any]], + ): + """Dummy constructor that has no real purpose.""" + super().__init__( + model_name=model_name, + eos_token='0', + pad_token='1', + **kwargs, + ) + + def get_vocab(self) -> Dict[str, int]: + return {} + + +def test_tokenizer_registry(): + tokenizers.register('dummy', func=DummyTokenizer) + tokenizer = build_tokenizer(tokenizer_name='dummy', tokenizer_kwargs={}) + assert type(tokenizer) == DummyTokenizer diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index dfcb5b327c..fb6cb0c5df 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -13,17 +13,24 @@ from composer.callbacks import Generate from composer.core import Evaluator from composer.loggers import WandBLogger +from torch.distributed.checkpoint.default_planner import ( + DefaultLoadPlanner, + DefaultSavePlanner, +) from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer +from llmfoundry.registry import load_planners, save_planners from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import ( add_metrics_to_eval_loaders, build_callback, build_eval_loaders, build_evaluators, + build_load_planner, build_logger, build_optimizer, + build_save_planner, build_tokenizer, ) @@ -345,6 +352,34 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): assert eval_loaders2[1].metric_names == [] +def test_build_load_planner(): + # Dummy LoadPlanner for testing + class DummyLoadPlanner(DefaultLoadPlanner): + + def __init__(self, is_test: bool): + self.is_test = is_test + + load_planners.register('dummy', func=DummyLoadPlanner) + load_planner = build_load_planner('dummy', is_test=True) + + assert isinstance(load_planner, DummyLoadPlanner) + assert load_planner.is_test is True + + +def test_build_save_planner(): + # Dummy SavePlanner for testing + class DummySavePlanner(DefaultSavePlanner): + + def __init__(self, is_test: bool): + self.is_test = is_test + + save_planners.register('dummy', func=DummySavePlanner) + save_planner = build_save_planner('dummy', is_test=True) + + assert isinstance(save_planner, DummySavePlanner) + assert save_planner.is_test is True + + def test_add_metrics_to_eval_loaders(): evaluators = [ Evaluator(