
Support building GPU docker image for MaxDiffusion Model #121

Merged: 9 commits into main, Oct 29, 2024

Conversation

@sizhit2 (Collaborator) commented Oct 8, 2024

Add a device option in docker_build_dependency_image.sh. The default option is still TPU. To build the docker image for GPU, just run bash docker_build_dependency_image.sh DEVICE=gpu.
Add maxdiffusion_gpu_dependencies.Dockerfile.
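The KEY=value invocation above can be sketched as follows. This is a hypothetical parsing loop, not the script's verified implementation; the real docker_build_dependency_image.sh may parse its arguments differently.

```shell
# Hypothetical sketch of how KEY=value CLI arguments (e.g. DEVICE=gpu)
# become environment variables; the real script's parsing may differ.
set -- MODE=stable DEVICE=gpu   # simulates: bash docker_build_dependency_image.sh DEVICE=gpu
for ARG in "$@"; do
  if [[ "$ARG" == *"="* ]]; then
    export "$ARG"
  fi
done
echo "DEVICE=${DEVICE:-tpu}"    # falls back to the TPU default when DEVICE is unset
```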

parambole previously approved these changes Oct 8, 2024
@parambole (Collaborator) left a comment

LGTM. Thanks for adding this functionality.

Nit: Can you please update the README with the instructions?

@@ -0,0 +1,49 @@
# syntax=docker/dockerfile:experimental
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
Collaborator:

@sizhit2 

Nit:

We could probably add a link to where the base image is derived from, or a short description of why it is the default, even though navigating via the base image link would be sufficient.
Also, probably add a note that this pulls in the latest jax:base, which is built nightly AFAIR, and that the tag's contents change over time; only a dated tag (the image built on that day) is stable, i.e. does not change.

Super Nit:

If you feel like contributing back to the JAX Toolbox build process, you could add a maxdiffusion test as one of its CI/CD tests, since maxdiffusion now depends on the jax:base image from JAX Toolbox.
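Concretely, the pinning this comment asks about could look like the fragment below. The dated tag shown is illustrative only, not a verified tag name; check the ghcr.io/nvidia/jax registry for the actual dated tags.

```dockerfile
# Floating tag: rebuilt nightly, so image contents drift over time.
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
# Pinned, dated tag (illustrative name): contents do not change.
# ARG BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-08
```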

@parambole parambole added the enhancement New feature or request label Oct 8, 2024
@wang2yn84 (Collaborator) left a comment

Thank you for adding the GPU support! I left a couple of comments.

@@ -46,6 +46,11 @@ if [[ -z ${MODE} ]]; then
  echo "Default MODE=${MODE}"
fi

if [[ -z ${DEVICE} ]]; then
  export DEVICE=tpu
  echo "Default DEVICE=${DEVICE}"
Collaborator:

Can we move this echo outside of the if statement so that it always prints the default device type?

Collaborator Author:

Wouldn't it be confusing if we set DEVICE=gpu, but we are echoing Default DEVICE=tpu? Or do you suggest we always echo the device information we are currently using?

Collaborator:

We won't get the device as TPU if we set it to GPU. This piece of code only sets the device type to TPU if no device type is assigned. So if you assign the device type as GPU, you will get GPU.

Collaborator Author:

I see what you are suggesting. I added an additional line to always print the current DEVICE.
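The resolved behavior can be sketched as follows: the default is still set inside the if, but the device actually in use is always echoed afterward. This is a sketch of the outcome of the discussion, not the verbatim final diff.

```shell
# Sketch of the resolution: default only when unset, always print what is used.
if [[ -z ${DEVICE} ]]; then
  export DEVICE=tpu
  echo "Default DEVICE=${DEVICE}"
fi
# Added line: always report the device the build will use.
echo "Running with DEVICE=${DEVICE}"
```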

if [[ ! -v BASEIMAGE ]]; then
  echo "Erroring out because BASEIMAGE is unset, please set it!"
  exit 1
fi

if [[ ${DEVICE} == "gpu" ]]; then
Collaborator:

Can you add a "pinned" mode here?

Collaborator Author:

Done.
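The pinned mode requested above could look roughly like this. The flag and file names (MODE, constraints_gpu.txt) follow the PR discussion and file list, but the exact install command is an assumption, not the merged code.

```shell
# Hypothetical sketch of a pinned GPU install path:
# when MODE=pinned and DEVICE=gpu, constrain pip to pinned versions.
DEVICE=gpu
MODE=pinned
if [[ "${DEVICE}" == "gpu" && "${MODE}" == "pinned" ]]; then
  PIP_CMD="pip3 install -r requirements.txt -c constraints_gpu.txt"
fi
echo "${PIP_CMD}"
```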

…etup, add hardware gpu option in yml, add jax multi-host support for gpu
…ment to fix import error, add jax[cuda] install instruction more non-pinned mode when device is GPU
@sizhit2 changed the title from "Adding gpu docker dependency file" to "Adding GPU support for MaxDiffusion Model" Oct 23, 2024
setup.sh: outdated review thread, resolved
constraints_gpu.txt: outdated review thread, resolved
constraints_gpu.txt: outdated review thread, resolved
@sizhit2 changed the title from "Adding GPU support for MaxDiffusion Model" to "Support building GPU docker image for MaxDiffusion Model" Oct 28, 2024
setup.sh: outdated review thread, resolved
Comment on lines +18 to +78
set_nccl_gpudirect_tcpx_specific_configuration() {
  if [[ "$USE_GPUDIRECT" == "tcpx" ]] || [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
    export CUDA_DEVICE_MAX_CONNECTIONS=1
    export NCCL_CROSS_NIC=0
    export NCCL_DEBUG=INFO
    export NCCL_DYNAMIC_CHUNK_SIZE=524288
    export NCCL_NET_GDR_LEVEL=PIX
    export NCCL_NVLS_ENABLE=0
    export NCCL_P2P_NET_CHUNKSIZE=524288
    export NCCL_P2P_NVL_CHUNKSIZE=1048576
    export NCCL_P2P_PCI_CHUNKSIZE=524288
    export NCCL_PROTO=Simple
    export NCCL_SOCKET_IFNAME=eth0
    export NVTE_FUSED_ATTN=1
    export TF_CPP_MAX_LOG_LEVEL=100
    export TF_CPP_VMODULE=profile_guided_latency_estimator=10
    export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
    shopt -s globstar nullglob
    IFS=:$IFS
    set -- /usr/local/cuda-*/compat
    export LD_LIBRARY_PATH="${1+:"$*"}:${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64"
    IFS=${IFS#?}
    shopt -u globstar nullglob

    if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then
      echo "Using GPUDirect-TCPX"
      export NCCL_ALGO=Ring
      export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
      export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
      export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
      export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
      export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
      export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
      export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
      export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
      export NCCL_MAX_NCHANNELS=12
      export NCCL_MIN_NCHANNELS=12
      export NCCL_NSOCKS_PERTHREAD=4
      export NCCL_P2P_PXN_LEVEL=0
      export NCCL_SOCKET_NTHREADS=1
    elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
      echo "Using GPUDirect-TCPFasTrak"
      export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
      export NCCL_ALGO=Ring,Tree
      export NCCL_BUFFSIZE=8388608
      export NCCL_FASTRAK_CTRL_DEV=eth0
      export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
      export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
      export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
      export NCCL_FASTRAK_NUM_FLOWS=2
      export NCCL_FASTRAK_USE_LLCM=1
      export NCCL_FASTRAK_USE_SNAP=1
      export NCCL_MIN_NCHANNELS=4
      export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
      export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
      export NCCL_TUNER_PLUGIN=libnccl-tuner.so
    fi
  else
    echo "NOT using GPUDirect"
  fi
}
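Everything in the function above keys off USE_GPUDIRECT. A reduced sketch of just that dispatch (the function name below is invented for illustration; the NCCL exports are elided):

```shell
# Reduced sketch of the mode dispatch in
# set_nccl_gpudirect_tcpx_specific_configuration; exports elided.
select_gpudirect_mode() {
  case "${USE_GPUDIRECT}" in
    tcpx)    echo "Using GPUDirect-TCPX" ;;
    fastrak) echo "Using GPUDirect-TCPFasTrak" ;;
    *)       echo "NOT using GPUDirect" ;;
  esac
}
USE_GPUDIRECT=fastrak
MODE_MSG=$(select_gpudirect_mode)
echo "${MODE_MSG}"
```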
Collaborator:

@sizhit2 Currently, are we deploying the workload using XPK? I am trying to understand which component determines the source of truth for these flags.

Collaborator:

@parambole, yes, we are deploying through xpk. It will execute this file and call this function. Can you elaborate on "which component determines the source of truth for these flags"? What does that mean?

Collaborator:

@wang2yn84 Oh I see, in the past there was a conversation around XPK setting the required env variables. But as you mentioned, I still see XPK calling this file for that.

ref: https://github.com/AI-Hypercomputer/xpk/blob/d88b092dc71a9d3f6d06fc8984370de0fbcbbe51/src/xpk/core/core.py#L2133

> "which component determines the source of truth for these flags"?

Here the two components are XPK and the workload. I thought XPK was taking care of setting the env vars and all the other setup related to running a workload on the A3 series, like installing the NCCL plugin and running the side-car container.

Collaborator:

I see. Do you have a pointer to that? What's the conclusion?

Collaborator:

> I see. Do you have a pointer to that? What's the conclusion?

I am not sure what the conclusion was. @gobbleturk might have more context.

@parambole parambole merged commit 51e1db1 into main Oct 29, 2024
3 checks passed
Labels: enhancement (New feature or request)
3 participants