From 6448e4ef290448d011447fd78e463edbe3e60f64 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Mon, 21 Oct 2024 16:08:19 -0400 Subject: [PATCH] Remove TE from dockerfile and instead add as optional dependency (#1605) --- .github/workflows/docker.yaml | 3 --- .github/workflows/release.yaml | 2 -- Dockerfile | 4 ---- llmfoundry/models/mpt/configuration_mpt.py | 7 ++----- setup.py | 8 ++++++-- 5 files changed, 8 insertions(+), 16 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 976b4241ab..c3fc9168ee 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -20,11 +20,9 @@ jobs: - name: "2.4.0_cu124" base_image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 dep_groups: "[all]" - te_commit: 901e5d2 - name: "2.4.0_cu124_aws" base_image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04-aws dep_groups: "[all]" - te_commit: 901e5d2 steps: - name: Checkout @@ -91,4 +89,3 @@ jobs: BRANCH_NAME=${{ github.head_ref || github.ref_name }} BASE_IMAGE=${{ matrix.base_image }} DEP_GROUPS=${{ matrix.dep_groups }} - TE_COMMIT=${{ matrix.te_commit }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 3617732c8f..15c83035e0 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -95,7 +95,6 @@ jobs: build-args: | BASE_IMAGE=mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04-aws BRANCH_NAME=${{ env.BRANCH_NAME }} - TE_COMMIT=901e5d2 DEP_GROUPS=[all] KEEP_FOUNDRY=true @@ -111,6 +110,5 @@ jobs: build-args: | BASE_IMAGE=mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 BRANCH_NAME=${{ env.BRANCH_NAME }} - TE_COMMIT=901e5d2 DEP_GROUPS=[all] KEEP_FOUNDRY=true diff --git a/Dockerfile b/Dockerfile index a9d44bfa27..f2566cd3cc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,6 @@ FROM $BASE_IMAGE ARG BRANCH_NAME ARG DEP_GROUPS -ARG TE_COMMIT ARG KEEP_FOUNDRY=false ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" @@ -16,9 +15,6 @@ ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py setup.py RUN rm setup.py -# Install TransformerEngine -RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@$TE_COMMIT - # Install and uninstall foundry to cache foundry requirements RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}" diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index dbcabdf5f9..1adb64dc21 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -371,11 +371,8 @@ def _validate_config(self) -> None: del te # unused except: raise ImportError( - 'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. ' - + - 'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n' - + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + - 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156', + 'TransformerEngine import failed. `fc_type: te` requires TransformerEngine be installed, ', + 'e.g. pip install transformer-engine[pytorch]', ) self.ffn_config['fc_type'] = self.fc_type diff --git a/setup.py b/setup.py index a05715a83f..ae98a36f5d 100644 --- a/setup.py +++ b/setup.py @@ -123,14 +123,18 @@ 'grouped-gemm==0.1.6', ] +extra_deps['te'] = [ + 'transformer-engine[pytorch]>=1.11.0,<1.12', +] + extra_deps['databricks-serverless'] = { dep for key, deps in extra_deps.items() for dep in deps - if 'gpu' not in key and 'megablocks' not in key and + if 'gpu' not in key and 'megablocks' not in key and 'te' not in key and 'databricks-connect' not in dep } extra_deps['all-cpu'] = { dep for key, deps in extra_deps.items() for dep in deps - if 'gpu' not in key and 'megablocks' not in key + if 'gpu' not in key and 'megablocks' not in key and 'te' not in key } extra_deps['all'] = { dep for key, deps in extra_deps.items() for dep in deps