Support building GPU docker image for MaxDiffusion Model #121

Merged (9 commits) on Oct 29, 2024

49 changes: 32 additions & 17 deletions docker_build_dependency_image.sh
@@ -46,6 +46,12 @@ if [[ -z ${MODE} ]]; then
echo "Default MODE=${MODE}"
fi

if [[ -z ${DEVICE} ]]; then
export DEVICE=tpu
echo "Default DEVICE=${DEVICE}"
Collaborator: Can we move this echo outside of the if statement so that it always prints the default device type?

Collaborator (Author): Wouldn't it be confusing if we set DEVICE=gpu but echo "Default DEVICE=tpu"? Or do you suggest we always echo the device information we are currently using?

Collaborator: We won't get the device as TPU if we set it to GPU. This piece of code only sets the device type to TPU if no device type is assigned. So if you assign the device type as GPU, you will get GPU.

Collaborator (Author): I see what you are suggesting. I added an additional line to always print the current DEVICE.

fi
echo "DEVICE=${DEVICE}"

if [[ -z ${JAX_VERSION+x} ]] ; then
export JAX_VERSION=NONE
echo "Default JAX_VERSION=${JAX_VERSION}"
@@ -55,22 +61,31 @@ COMMIT_HASH=$(git rev-parse --short HEAD)

echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ."

if [[ "${MODE}" == "stable_stack" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
if [[ ${DEVICE} == "gpu" ]]; then
Collaborator: Can you add a "pinned" mode here?

Collaborator (Author): Done.

if [[ ${MODE} == "pinned" ]]; then
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
if [[ "${MODE}" == "stable_stack" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
fi
docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
fi
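
For context, here is a hedged usage sketch (not part of this PR) of driving the new GPU path in docker_build_dependency_image.sh. It assumes the script picks up DEVICE, MODE, and LOCAL_IMAGE_NAME from the environment, and that LOCAL_IMAGE_NAME matches the maxdiffusion_base_image default used by maxdiffusion_runner.Dockerfile below; the exact invocation may differ.

# Sketch only: build the GPU dependency image against the date-pinned NVIDIA JAX base.
DEVICE=gpu MODE=pinned LOCAL_IMAGE_NAME=maxdiffusion_base_image \
  bash docker_build_dependency_image.sh

# Sketch only: unpinned variant, which falls back to BASEIMAGE=ghcr.io/nvidia/jax:base.
DEVICE=gpu LOCAL_IMAGE_NAME=maxdiffusion_base_image \
  bash docker_build_dependency_image.sh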
156 changes: 156 additions & 0 deletions gpu_multi_process_run.sh
@@ -0,0 +1,156 @@
#! /bin/bash
set -e
set -u
set -o pipefail

: "${NNODES:?Must set NNODES}"
: "${NODE_RANK:?Must set NODE_RANK}"
: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}"
: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}"
: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}"
: "${COMMAND:?Must set COMMAND}"


export GPUS_PER_NODE=$GPUS_PER_NODE
export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT
export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS

set_nccl_gpudirect_tcpx_specific_configuration() {
if [[ "$USE_GPUDIRECT" == "tcpx" ]] || [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_CROSS_NIC=0
export NCCL_DEBUG=INFO
export NCCL_DYNAMIC_CHUNK_SIZE=524288
export NCCL_NET_GDR_LEVEL=PIX
export NCCL_NVLS_ENABLE=0
export NCCL_P2P_NET_CHUNKSIZE=524288
export NCCL_P2P_NVL_CHUNKSIZE=1048576
export NCCL_P2P_PCI_CHUNKSIZE=524288
export NCCL_PROTO=Simple
export NCCL_SOCKET_IFNAME=eth0
export NVTE_FUSED_ATTN=1
export TF_CPP_MAX_LOG_LEVEL=100
export TF_CPP_VMODULE=profile_guided_latency_estimator=10
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
shopt -s globstar nullglob
IFS=:$IFS
set -- /usr/local/cuda-*/compat
export LD_LIBRARY_PATH="${1+:"$*"}:${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64"
IFS=${IFS#?}
shopt -u globstar nullglob

if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then
echo "Using GPUDirect-TCPX"
export NCCL_ALGO=Ring
export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
export NCCL_MAX_NCHANNELS=12
export NCCL_MIN_NCHANNELS=12
export NCCL_NSOCKS_PERTHREAD=4
export NCCL_P2P_PXN_LEVEL=0
export NCCL_SOCKET_NTHREADS=1
elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
echo "Using GPUDirect-TCPFasTrak"
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export NCCL_ALGO=Ring,Tree
export NCCL_BUFFSIZE=8388608
export NCCL_FASTRAK_CTRL_DEV=eth0
export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
export NCCL_FASTRAK_NUM_FLOWS=2
export NCCL_FASTRAK_USE_LLCM=1
export NCCL_FASTRAK_USE_SNAP=1
export NCCL_MIN_NCHANNELS=4
export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
export NCCL_TUNER_PLUGIN=libnccl-tuner.so
fi
else
echo "NOT using GPUDirect"
fi
}
Comment on lines +18 to +78
Collaborator: @sizhit2 Currently are we deploying the workload using XPK? I am trying to understand which component determines the source of truth for these flags.

Collaborator: @parambole, yes, we are deploying through xpk. It will execute this file and run this function. Can you elaborate on "which component determines the source of truth for these flags"? What does it mean?

Collaborator: @wang2yn84 Oh I see. In the past there was a conversation around XPK setting the required env variables. But as you mentioned, I still see XPK calling this file for that.

ref: https://github.com/AI-Hypercomputer/xpk/blob/d88b092dc71a9d3f6d06fc8984370de0fbcbbe51/src/xpk/core/core.py#L2133

"which component determines the source of truth for these flags"?

Here the two components are XPK and the workload. I thought XPK was taking care of setting the env vars and all the other setup related to running a workload on the A3 series, like installing the NCCL plugin and running the side-car container.

Collaborator: I see. Do you have a pointer to that? What's the conclusion?

Collaborator (quoting "I see. Do you have a pointer to that? What's the conclusion?"): I am not sure what the conclusion was. Probably @gobbleturk might have more context.


echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH}"

set_nccl_gpudirect_tcpx_specific_configuration

wait_all_success_or_exit() {
# https://www.baeldung.com/linux/background-process-get-exit-code
local pids=("$@")
while [[ ${#pids[@]} -ne 0 ]]; do
all_success="true"
for pid in "${pids[@]}"; do
code=$(non_blocking_wait "$pid")
if [[ $code -ne 127 ]]; then
if [[ $code -ne 0 ]]; then
echo "PID $pid failed with exit code $code"
exit "$code"
fi
else
all_success="false"
fi
done
if [[ $all_success == "true" ]]; then
echo "All pids succeeded"
break
fi
sleep 5
done
}
non_blocking_wait() {
# https://www.baeldung.com/linux/background-process-get-exit-code
local pid=$1
local code=127 # special code to indicate not-finished
if [[ ! -d "/proc/$pid" ]]; then
wait "$pid"
code=$?
fi
echo $code
}

resolve_coordinator_ip() {
local lookup_attempt=1
local max_coordinator_lookups=500
local coordinator_found=false
local coordinator_ip_address=""

echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"

while [[ "$coordinator_found" = false && $lookup_attempt -le $max_coordinator_lookups ]]; do
coordinator_ip_address=$(nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null | awk '/^Address: / { print $2 }' | head -n 1)
if [[ -n "$coordinator_ip_address" ]]; then
coordinator_found=true
echo "Coordinator IP address: $coordinator_ip_address"
export JAX_COORDINATOR_IP=$coordinator_ip_address
return 0
else
echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..."
((lookup_attempt++))
sleep 1
fi
done

if [[ "$coordinator_found" = false ]]; then
echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts."
return 1
fi
}

# Resolving coordinator IP
set +e
resolve_coordinator_ip
set -e

PIDS=()
eval ${COMMAND} &
PID=$!
PIDS+=($PID)

wait_all_success_or_exit "${PIDS[@]}"
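
A hedged launch sketch for gpu_multi_process_run.sh follows; every value is an illustrative placeholder (in practice XPK is expected to set these variables and invoke the script, per the review thread above). Note that because the script runs with set -u, USE_GPUDIRECT must also be set even though it is not in the required-variable checks.

# Illustrative only; values are placeholders, not taken from the PR.
export NNODES=2
export NODE_RANK=0
export GPUS_PER_NODE=8
export JAX_COORDINATOR_ADDRESS=workload-coordinator   # assumed resolvable DNS name of the rank-0 pod
export JAX_COORDINATOR_PORT=1234
export USE_GPUDIRECT=fastrak                           # or "tcpx"; any other value skips the GPUDirect tuning
export COMMAND="python train.py"                       # placeholder for the actual MaxDiffusion training command
bash gpu_multi_process_run.sh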
2 changes: 1 addition & 1 deletion maxdiffusion_dependencies.Dockerfile
@@ -45,7 +45,7 @@ ARG JAX_VERSION
ENV ENV_JAX_VERSION=$JAX_VERSION

# Set the working directory in the container
WORKDIR /app
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
50 changes: 50 additions & 0 deletions maxdiffusion_gpu_dependencies.Dockerfile
@@ -0,0 +1,50 @@
# syntax=docker/dockerfile:experimental
# Note: This pulls in the latest jax:base
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
FROM $BASEIMAGE

# Stopgap measure to circumvent the GPG key setup issue.
RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list

# Install dependencies for adjusting network rto
RUN apt-get update && apt-get install -y iproute2 ethtool lsof

# Install DNS util dependencies
RUN apt-get install -y dnsutils

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk

# Set environment variables for Google Cloud SDK
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"

# Upgrade libcusparse to work with JAX
RUN apt-get update && apt-get install -y libcusparse-12-3

ARG MODE
ENV ENV_MODE=$MODE

ARG JAX_VERSION
ENV ENV_JAX_VERSION=$JAX_VERSION

ARG DEVICE
ENV ENV_DEVICE=$DEVICE

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
RUN ls .

RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}


WORKDIR /deps
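
A minimal sketch of building this image directly, mirroring the flags that docker_build_dependency_image.sh passes above; the tag and the pinned base image date are assumptions taken from that script, not additional requirements of this Dockerfile.

# Sketch only; equivalent to the GPU branch of the build script with MODE=pinned.
docker build --network host \
  --build-arg MODE=pinned \
  --build-arg JAX_VERSION=NONE \
  --build-arg DEVICE=gpu \
  --build-arg BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17 \
  -f ./maxdiffusion_gpu_dependencies.Dockerfile \
  -t maxdiffusion_base_image .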
4 changes: 2 additions & 2 deletions maxdiffusion_jax_stable_stack_tpu.Dockerfile
@@ -7,10 +7,10 @@ ARG COMMIT_HASH

ENV COMMIT_HASH=$COMMIT_HASH

RUN mkdir -p /app
RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /app
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
4 changes: 2 additions & 2 deletions maxdiffusion_runner.Dockerfile
@@ -2,9 +2,9 @@ ARG BASEIMAGE=maxdiffusion_base_image
FROM $BASEIMAGE

# Set the working directory in the container
WORKDIR /app
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .

WORKDIR /app
WORKDIR /deps
1 change: 1 addition & 0 deletions requirements.txt
@@ -25,3 +25,4 @@ git+https://github.com/mlperf/logging.git
opencv-python==4.10.0.84
orbax-checkpoint>=0.5.20
tokenizers==0.20.0
huggingface_hub==0.24.7