From e9a65cb704cf47ca8527d5fd35719a7a955864b4 Mon Sep 17 00:00:00 2001
From: Ben Wang
Date: Mon, 1 Mar 2021 11:01:06 +1100
Subject: [PATCH] check TPU support and benchmark, add vscode run config for
 tpu ssh remote development

---
 .gitignore                |  1 +
 .vscode/launch.json       | 16 ++++++++++++++
 .vscode/settings.json     |  3 +++
 README.md                 | 13 ++++++++++--
 requirements.txt          |  2 +-
 tests/bench_tpu.py        | 44 +++++++++++++++++++++++++++++++++++++++
 tests/test_consistency.py |  4 +++-
 tpu_install.sh            |  4 ++++
 8 files changed, 83 insertions(+), 4 deletions(-)
 create mode 100644 .vscode/launch.json
 create mode 100644 .vscode/settings.json
 create mode 100644 tests/bench_tpu.py
 create mode 100755 tpu_install.sh

diff --git a/.gitignore b/.gitignore
index 321f181f9..1f5d15b8e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ __pycache__/
 *.egg-info
 .pytest_cache
 .ipynb_checkpoints
+venv/
 
 thumbs.db
 .DS_Store
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 000000000..5dba21720
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,16 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        }
+    ]
+}
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 000000000..615aafb03
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "python.pythonPath": "/usr/bin/python3"
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 64802521e..e25bd1876 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a
 
 The ViT model and checkpoints have been ported to Haiku, while preserving the same output. See
 `tests/test_consistency.py` for details. No JIT/pmap is performed, but pure inference
 functions for both the text and image encoders are provided from the
-`clip_jax.load()` function which should be easy to run/parallelize how you wish.
+`clip_jax.load()` function which should be easy to run/parallelize how you wish. See `tests/bench_tpu.py` for an example of using pmap.
 
 ## Usage Example
 
@@ -28,8 +28,17 @@ image_embed = image_fn(jax_params, image)
 text_embed = text_fn(jax_params, text)
 ```
 
+## TPU performance
+
+On a TPU v3-8 with the JAX tpu-vm alpha (`tests/bench_tpu.py`):
+```
+10.1361s to compile model
+43.9599s for 16 batches
+5963.25 examples/s
+```
+
 ## TODOs
-- [ ] Test on TPUs
+- [x] Test on TPUs
 - [ ] Easier control over precision and device placement
 - [ ] Mixed precision training support
 - [ ] Support RN50 model
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b87d9887a..7966b6d80 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,6 @@ regex
 tqdm
 torch~=1.7.1
 torchvision~=0.8.2
-git+https://github.com/deepmind/dm-haiku
+dm-haiku
 jax
 jaxlib
\ No newline at end of file
diff --git a/tests/bench_tpu.py b/tests/bench_tpu.py
new file mode 100644
index 000000000..1e36d58c7
--- /dev/null
+++ b/tests/bench_tpu.py
@@ -0,0 +1,44 @@
+import numpy as np
+from PIL import Image
+import jax
+import time
+
+import clip_jax
+
+image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load('ViT-B/32', "cpu")
+
+batch_size = 2048
+
+devices = jax.local_devices()
+
+print(f"jax devices: {devices}")
+
+jax_params = jax.device_put_replicated(jax_params, devices)
+image_fn = jax.pmap(image_fn)
+text_fn = jax.pmap(text_fn)
+
+jax_image = np.expand_dims(jax_preprocess(Image.open("CLIP.png")), (0, 1))
+jax_image = np.repeat(jax_image, len(devices), axis=0)
+jax_image = np.repeat(jax_image, batch_size, axis=1)
+
+jax_text = np.expand_dims(clip_jax.tokenize(["a diagram"]), 0)
+jax_text = np.repeat(jax_text, len(devices), axis=0)
+jax_text = np.repeat(jax_text, batch_size, axis=1)
+
+start = time.time()
+jax_image_embed = image_fn(jax_params, jax_image)
+jax_text_embed = text_fn(jax_params, jax_text)
+total = time.time() - start
+print(f"{total:.06}s to compile model")
+
+start = time.time()
+for i in range(16):
+    jax_image_embed = np.array(image_fn(jax_params, jax_image))
+    jax_text_embed = np.array(text_fn(jax_params, jax_text))
+
+total = time.time() - start
+
+print(f"{total:.06}s for 16 batches @ bs={batch_size} per core")
+print(f"{16 * len(devices) * batch_size / total:.06} examples/s")
+
+print("done!")
\ No newline at end of file
diff --git a/tests/test_consistency.py b/tests/test_consistency.py
index c2231dc43..f52e7dc47 100644
--- a/tests/test_consistency.py
+++ b/tests/test_consistency.py
@@ -21,4 +21,6 @@
 pyt_text_embed = pytorch_clip.encode_text(pyt_text)
 
 assert np.allclose(np.array(jax_image_embed), pyt_image_embed.cpu().detach().numpy(), atol=0.01, rtol=0.01)
-assert np.allclose(np.array(jax_text_embed), pyt_text_embed.cpu().detach().numpy(), atol=0.01, rtol=0.01)
\ No newline at end of file
+assert np.allclose(np.array(jax_text_embed), pyt_text_embed.cpu().detach().numpy(), atol=0.01, rtol=0.01)
+
+print("done!")
\ No newline at end of file
diff --git a/tpu_install.sh b/tpu_install.sh
new file mode 100755
index 000000000..8692a2835
--- /dev/null
+++ b/tpu_install.sh
@@ -0,0 +1,4 @@
+pip3 install --upgrade --user jax jaxlib==0.1.61
+pip3 install -r requirements.txt
+pip3 install -e .
+pip3 install git+https://github.com/openai/CLIP.git
\ No newline at end of file
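
For reference, the throughput figure in the README hunk follows directly from the parameters in `tests/bench_tpu.py`: 16 timed batches × 8 TPU v3 cores × 2048 examples per core = 262,144 examples, and 262,144 / 43.9599 s ≈ 5963 examples/s, matching the reported 5963.25 examples/s. This is the script's own accounting; each iteration also runs the text encoder on an equally sized batch, so the combined image-plus-text throughput is effectively double this number.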
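
The benchmark also doubles as a usage example for the data-parallel pattern the README now points to: replicate the params across devices with `jax.device_put_replicated`, wrap the pure inference functions in `jax.pmap`, and give every input a leading `(num_devices, per_device_batch, ...)` shape. A minimal sketch of that pattern, stripped of the timing code; the model name and image path match the benchmark, while `per_device_batch` is an illustrative small value rather than the benchmark's 2048:

```python
import jax
import numpy as np
from PIL import Image

import clip_jax

# load() returns pure inference functions plus the params pytree.
image_fn, text_fn, jax_params, preprocess = clip_jax.load('ViT-B/32', "cpu")

devices = jax.local_devices()
per_device_batch = 4  # illustrative; tests/bench_tpu.py uses 2048

# Copy the params to every device once, then pmap the inference function
# so each device processes its own slice of the batch in parallel.
replicated_params = jax.device_put_replicated(jax_params, devices)
p_image_fn = jax.pmap(image_fn)

# pmap maps over the leading axis, so inputs are shaped
# (num_devices, per_device_batch, ...).
image = np.expand_dims(preprocess(Image.open("CLIP.png")), (0, 1))
batch = np.repeat(image, len(devices), axis=0)
batch = np.repeat(batch, per_device_batch, axis=1)

embeds = p_image_fn(replicated_params, batch)
# Collapse the device axis back into one flat batch of embeddings.
embeds = np.array(embeds).reshape(len(devices) * per_device_batch, -1)
print(embeds.shape)
```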