-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support building GPU docker image for MaxDiffusion Model #121
Changes from all commits
42e8d10
bb182c4
af65f68
c6030e9
44b4801
0a15164
0e475bd
e5f76c1
a89d21c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,12 @@ if [[ -z ${MODE} ]]; then | |
echo "Default MODE=${MODE}" | ||
fi | ||
|
||
# Select the accelerator target for the image. DEVICE falls back to "tpu" when
# it is unset or empty; "${DEVICE:-}" keeps the check from aborting if the
# script is ever run with `set -u`.
if [[ -z "${DEVICE:-}" ]]; then
  export DEVICE=tpu
  echo "Default DEVICE=${DEVICE}"
fi
# Always report the device that is actually in effect for this build.
echo "DEVICE=${DEVICE}"
|
||
if [[ -z ${JAX_VERSION+x} ]] ; then | ||
export JAX_VERSION=NONE | ||
echo "Default JAX_VERSION=${JAX_VERSION}" | ||
|
@@ -55,22 +61,31 @@ COMMIT_HASH=$(git rev-parse --short HEAD) | |
|
||
echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ." | ||
|
||
if [[ "${MODE}" == "stable_stack" ]]; then | ||
if [[ ! -v BASEIMAGE ]]; then | ||
echo "Erroring out because BASEIMAGE is unset, please set it!" | ||
exit 1 | ||
if [[ ${DEVICE} == "gpu" ]]; then | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a "pinned" mode here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
if [[ ${MODE} == "pinned" ]]; then | ||
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17 | ||
else | ||
export BASEIMAGE=ghcr.io/nvidia/jax:base | ||
fi | ||
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . | ||
else | ||
if [[ "${MODE}" == "stable_stack" ]]; then | ||
if [[ ! -v BASEIMAGE ]]; then | ||
echo "Erroring out because BASEIMAGE is unset, please set it!" | ||
exit 1 | ||
fi | ||
docker build --no-cache \ | ||
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ | ||
--build-arg COMMIT_HASH=${COMMIT_HASH} \ | ||
--network=host \ | ||
-t ${LOCAL_IMAGE_NAME} \ | ||
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile . | ||
else | ||
docker build --no-cache \ | ||
--network=host \ | ||
--build-arg MODE=${MODE} \ | ||
--build-arg JAX_VERSION=${JAX_VERSION} \ | ||
-t ${LOCAL_IMAGE_NAME} \ | ||
-f maxdiffusion_dependencies.Dockerfile . | ||
fi | ||
docker build --no-cache \ | ||
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ | ||
--build-arg COMMIT_HASH=${COMMIT_HASH} \ | ||
--network=host \ | ||
-t ${LOCAL_IMAGE_NAME} \ | ||
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile . | ||
else | ||
docker build --no-cache \ | ||
--network=host \ | ||
--build-arg MODE=${MODE} \ | ||
--build-arg JAX_VERSION=${JAX_VERSION} \ | ||
-t ${LOCAL_IMAGE_NAME} \ | ||
-f maxdiffusion_dependencies.Dockerfile . | ||
fi |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
#! /bin/bash | ||
# Strict-mode setup and validation of the environment this script requires.
# Abort on any failing command (-e), on expansion of an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -e | ||
set -u | ||
set -o pipefail | ||
|
||
# Required inputs. `${VAR:?message}` stops the script with the given message
# when the variable is unset or empty; `:` is a no-op that only forces the
# expansion to happen.
: "${NNODES:?Must set NNODES}" | ||
: "${NODE_RANK:?Must set NODE_RANK}" | ||
: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}" | ||
: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}" | ||
: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}" | ||
: "${COMMAND:?Must set COMMAND}" | ||
|
||
|
||
# Mark these values as exported so processes started by this script inherit
# them even if they arrived as plain (non-exported) shell variables.
export GPUS_PER_NODE=$GPUS_PER_NODE | ||
export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT | ||
export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS | ||
|
||
#######################################
# Exports the NCCL / XLA environment used with GPUDirect networking.
# Globals:
#   USE_GPUDIRECT (read, optional): "tcpx" or "fastrak" selects the matching
#     settings; any other value, or leaving it unset, exports nothing.
#   LD_LIBRARY_PATH (read, exported): rebuilt as
#     <CUDA compat dirs>:<previous value>:/usr/local/tcpx/lib64
#   Many NCCL_*, CUDA_*, TF_CPP_*, NVTE_* and XLA_* variables (exported).
# Arguments: none
# Outputs:   one status line on stdout naming the selected mode
# Returns:   0
#######################################
set_nccl_gpudirect_tcpx_specific_configuration() {
  # USE_GPUDIRECT is optional and the script runs under `set -u`, so expand it
  # with an empty default instead of aborting on an unset variable.
  local gpudirect_mode="${USE_GPUDIRECT:-}"

  case "${gpudirect_mode}" in
    tcpx | fastrak) ;;
    *)
      echo "NOT using GPUDirect"
      return 0
      ;;
  esac

  # Settings shared by both GPUDirect flavours.
  export CUDA_DEVICE_MAX_CONNECTIONS=1
  export NCCL_CROSS_NIC=0
  export NCCL_DEBUG=INFO
  export NCCL_DYNAMIC_CHUNK_SIZE=524288
  export NCCL_NET_GDR_LEVEL=PIX
  export NCCL_NVLS_ENABLE=0
  export NCCL_P2P_NET_CHUNKSIZE=524288
  export NCCL_P2P_NVL_CHUNKSIZE=1048576
  export NCCL_P2P_PCI_CHUNKSIZE=524288
  export NCCL_PROTO=Simple
  export NCCL_SOCKET_IFNAME=eth0
  export NVTE_FUSED_ATTN=1
  export TF_CPP_MAX_LOG_LEVEL=100
  export TF_CPP_VMODULE=profile_guided_latency_estimator=10
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85

  # Assemble LD_LIBRARY_PATH from non-empty components only. An empty entry
  # (leading, trailing or doubled ':') makes the dynamic loader search the
  # current working directory, which is never intended here.
  local -a library_dirs=()
  local compat_dir
  for compat_dir in /usr/local/cuda-*/compat; do
    # When nothing matches, the glob stays literal; the -d test drops it.
    if [[ -d "${compat_dir}" ]]; then
      library_dirs+=("${compat_dir}")
    fi
  done
  if [[ -n "${LD_LIBRARY_PATH:-}" ]]; then
    library_dirs+=("${LD_LIBRARY_PATH}")
  fi
  library_dirs+=("/usr/local/tcpx/lib64")

  local joined_library_path
  # Join with ':' inside a subshell so the caller's IFS is left untouched.
  joined_library_path=$(
    IFS=:
    printf '%s' "${library_dirs[*]}"
  )
  export LD_LIBRARY_PATH="${joined_library_path}"

  case "${gpudirect_mode}" in
    tcpx)
      echo "Using GPUDirect-TCPX"
      export NCCL_ALGO=Ring
      export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
      export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
      export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
      export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
      export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
      export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
      export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
      export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
      export NCCL_MAX_NCHANNELS=12
      export NCCL_MIN_NCHANNELS=12
      export NCCL_NSOCKS_PERTHREAD=4
      export NCCL_P2P_PXN_LEVEL=0
      export NCCL_SOCKET_NTHREADS=1
      ;;
    fastrak)
      echo "Using GPUDirect-TCPFasTrak"
      export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
      export NCCL_ALGO=Ring,Tree
      export NCCL_BUFFSIZE=8388608
      export NCCL_FASTRAK_CTRL_DEV=eth0
      export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
      export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
      export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
      export NCCL_FASTRAK_NUM_FLOWS=2
      export NCCL_FASTRAK_USE_LLCM=1
      export NCCL_FASTRAK_USE_SNAP=1
      export NCCL_MIN_NCHANNELS=4
      export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
      export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
      export NCCL_TUNER_PLUGIN=libnccl-tuner.so
      ;;
  esac
}
Comment on lines
+18
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sizhit2 Currently are we deploying the workload using XPK ? I am trying to understand which components determines the source of through for these flags. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @parambole , yes we are deploying through xpk. It will execute this file and execute this function. Can you elaborate on "which components determines the source of through for these flags"? What does it mean? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @wang2yn84 Oh I see in the past there was a conversation around XPK setting the required env variables. But as you mentioned I still see XPK calling this file for that.
Here the two components are XPK and the workload I thought XPK was taking care of setting the envvar and all the other setup related to running a workload on A3 series like installing the NCCL plugin and running the side-car container. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Do you have a pointer to that? What's the conclusion? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I am not sure what the conclusion was. Probably @gobbleturk might have more context. |
||
|
||
# Log the inherited library path before the GPUDirect setup rewrites it.
# The script runs under `set -u` and LD_LIBRARY_PATH is not guaranteed to be
# set, so expand it with an empty default rather than aborting.
echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH:-}"

set_nccl_gpudirect_tcpx_specific_configuration
|
||
#######################################
# Polls background processes until all have finished, failing fast on the
# first one that exits non-zero.
# Arguments: the PIDs to watch; they must be children of the current shell
# Outputs:   "All pids succeeded" on success, or
#            "PID <pid> failed with exit code <code>" on failure (stdout)
# Returns:   0 when every process exited with status 0 (or no PID was given);
#            otherwise the whole script exits with the failing status.
#######################################
wait_all_success_or_exit() {
  # https://www.baeldung.com/linux/background-process-get-exit-code
  local -a pending=("$@")
  local -a still_running
  local pid
  local code

  while ((${#pending[@]} > 0)); do
    still_running=()
    for pid in "${pending[@]}"; do
      # Liveness is checked separately from the exit status so that a real
      # exit status of 127 (e.g. "command not found") is reported as a failure
      # instead of being mistaken for a "not finished yet" marker.
      if kill -0 "${pid}" 2> /dev/null; then
        still_running+=("${pid}")
        continue
      fi
      # Collect the status in this shell (not in a subshell, which does not own
      # the child). `|| code=$?` keeps errexit from aborting before the report.
      code=0
      wait "${pid}" || code=$?
      if ((code != 0)); then
        echo "PID $pid failed with exit code $code"
        exit "$code"
      fi
    done

    if ((${#still_running[@]} == 0)); then
      echo "All pids succeeded"
      break
    fi
    # Only keep polling the processes that have not been collected yet.
    pending=("${still_running[@]}")
    sleep 5
  done
}
#######################################
# Reports the exit status of a background process without blocking on it.
# Arguments: $1 - PID to inspect
# Outputs:   the exit status on stdout, or 127 while /proc/<pid> still exists
# Returns:   0
# NOTE: 127 doubles as the "not finished" marker, so a process that really
# exited with 127 cannot be told apart from one that is still running.
#######################################
non_blocking_wait() {
  # https://www.baeldung.com/linux/background-process-get-exit-code
  local pid=$1
  local code=127 # special code to indicate not-finished
  if [[ ! -d "/proc/$pid" ]]; then
    # `|| code=$?` captures a non-zero status without tripping errexit, so the
    # status is still printed when this runs outside a command substitution.
    code=0
    wait "$pid" || code=$?
  fi
  echo "$code"
}
|
||
#######################################
# Resolves JAX_COORDINATOR_ADDRESS to an IP address with nslookup, retrying
# until a record is returned or the attempt budget is used up.
# Globals:
#   JAX_COORDINATOR_ADDRESS (read)
#   JAX_COORDINATOR_IP (exported on success)
# Arguments:
#   $1 - maximum number of lookup attempts (optional, default 500)
#   $2 - seconds to sleep after a failed attempt (optional, default 1)
# Outputs: progress messages on stdout
# Returns: 0 once an address was found, 1 if every attempt failed
#######################################
resolve_coordinator_ip() {
  local -r max_coordinator_lookups="${1:-500}"
  local -r retry_delay_seconds="${2:-1}"
  local lookup_attempt
  local coordinator_ip_address

  echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"

  for ((lookup_attempt = 1; lookup_attempt <= max_coordinator_lookups; lookup_attempt++)); do
    # nslookup exits non-zero while the name cannot be resolved. With errexit
    # and pipefail active that status would abort the script instead of
    # retrying, so it is neutralised inside the pipeline. awk prints the first
    # answer ("Address: <ip>") and stops.
    coordinator_ip_address=$(
      { nslookup "$JAX_COORDINATOR_ADDRESS" 2> /dev/null || true; } |
        awk '/^Address: / { print $2; exit }'
    ) || true

    if [[ -n "$coordinator_ip_address" ]]; then
      echo "Coordinator IP address: $coordinator_ip_address"
      export JAX_COORDINATOR_IP=$coordinator_ip_address
      return 0
    fi

    echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..."
    sleep "$retry_delay_seconds"
  done

  echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts."
  return 1
}
|
||
# Resolving coordinator IP
# Best-effort: a failed lookup is reported by the function itself and must not
# stop the launch, so its status is deliberately ignored here.
resolve_coordinator_ip || true

# COMMAND is operator-supplied shell text that eval executes as-is; never
# populate it from untrusted input. The expansion is quoted so eval receives
# the string untouched: unquoted, it would first be word-split and
# glob-expanded, collapsing whitespace and expanding wildcards.
PIDS=()
eval "${COMMAND}" &
PID=$!
PIDS+=("${PID}")

wait_all_success_or_exit "${PIDS[@]}"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# syntax=docker/dockerfile:experimental
# Note: This pulls in the latest of jax:base unless BASEIMAGE is overridden
# with --build-arg.
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
FROM $BASEIMAGE

# Stopgap measure to circumvent a gpg key setup issue.
RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list

# Install dependencies for adjusting network rto
RUN apt-get update && apt-get install -y iproute2 ethtool lsof

# Install DNS util dependencies.
# `apt-get update` runs in the same layer as the install so a cached earlier
# layer can never pair this install with a stale package index.
RUN apt-get update && apt-get install -y dnsutils

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
# NOTE(review): apt-key is deprecated upstream; confirm it is still available
# in the chosen base image before bumping BASEIMAGE.
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk

# Set environment variables for Google Cloud SDK
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"

# Upgrade libcusparse to work with Jax
RUN apt-get update && apt-get install -y libcusparse-12-3

# Build arguments are mirrored into ENV_* variables so later RUN steps (and the
# final image) can read them.
ARG MODE
ENV ENV_MODE=$MODE

ARG JAX_VERSION
ENV ENV_JAX_VERSION=$JAX_VERSION

ARG DEVICE
ENV ENV_DEVICE=$DEVICE

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
RUN ls .

# Echo the exact setup invocation into the build log, then run it with the pip
# download cache mounted so rebuilds can reuse it.
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}

WORKDIR /deps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move this echo outside of the if statement so that it always prints the default device type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be confusing if we set `DEVICE=gpu`, but we are echoing `Default DEVICE=tpu`? Or do you suggest we always echo the device information we are currently using?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We won't get the device as TPU if we set it to GPU. This piece of code only sets the device type to TPU if no device type is assigned. So if you assign the device type as GPU, you will get GPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see what you are suggesting. I added an additional line to always print the current DEVICE.