
check TPU support and benchmark, add vscode run config for tpu ssh remote development
kingoflolz committed Mar 1, 2021
1 parent 61a82fd commit e9a65cb
Showing 8 changed files with 83 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@ __pycache__/
*.egg-info
.pytest_cache
.ipynb_checkpoints
venv/

thumbs.db
.DS_Store
16 changes: 16 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [

{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"python.pythonPath": "/usr/bin/python3"
}
13 changes: 11 additions & 2 deletions README.md
@@ -9,7 +9,7 @@ CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a
The ViT model and checkpoints have been ported to Haiku, while preserving the same output. See `tests/test_consistency.py` for details.

No JIT/pmap is performed, but pure inference functions for both the text and image encoders are provided from the
`clip_jax.load()` function, which should be easy to run/parallelize how you wish.
`clip_jax.load()` function, which should be easy to run/parallelize how you wish. See `tests/bench_tpu.py` for an example of using pmap.

## Usage Example

@@ -28,8 +28,17 @@ image_embed = image_fn(jax_params, image)
text_embed = text_fn(jax_params, text)
```
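
The functions returned by `clip_jax.load()` are pure, so the single-device pattern above extends to multiple devices with `jax.pmap`. A minimal sketch of the replication pattern, using a toy stand-in for the encoder (the real `image_fn` follows the same leading-device-axis convention):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy stand-in for an encoder: a pure apply function plus its params.
def apply_fn(params, x):
    return jnp.dot(x, params["w"])

devices = jax.local_devices()
params = {"w": jnp.ones((4, 2))}

# Replicate the params once per device; pmap then expects a leading
# device axis on every input.
params = jax.device_put_replicated(params, devices)
p_apply = jax.pmap(apply_fn)

x = np.ones((len(devices), 3, 4), dtype=np.float32)
out = p_apply(params, x)
print(out.shape)  # (num_devices, 3, 2)
```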

## TPU performance

On a TPU v3-8 with the JAX tpu-vm alpha runtime (`tests/bench_tpu.py`):
```
10.1361s to compile model
43.9599s for 16 batches
5963.25 examples/s
```
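
As a sanity check, the examples/s figure follows from the numbers above: 16 batches on the 8 cores of a v3-8, at a per-core batch size of 2048:

```python
# Throughput implied by the benchmark output above.
batches = 16
cores = 8          # TPU v3-8 exposes 8 cores
batch_size = 2048  # per core
elapsed = 43.9599  # seconds for the 16 batches

throughput = batches * cores * batch_size / elapsed
print(f"{throughput:.2f} examples/s")  # ~5963.25
```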

## TODOs
- [ ] Test on TPUs
- [x] Test on TPUs
- [ ] Easier control over precision and device placement
- [ ] Mixed precision training support
- [ ] Support RN50 model
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,6 +3,6 @@ regex
tqdm
torch~=1.7.1
torchvision~=0.8.2
git+https://github.com/deepmind/dm-haiku
dm-haiku
jax
jaxlib
44 changes: 44 additions & 0 deletions tests/bench_tpu.py
@@ -0,0 +1,44 @@
import numpy as np
from PIL import Image
import jax
import time

import clip_jax

image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load('ViT-B/32', "cpu")

batch_size = 2048

devices = jax.local_devices()

print(f"jax devices: {devices}")

jax_params = jax.device_put_replicated(jax_params, devices)
image_fn = jax.pmap(image_fn)
text_fn = jax.pmap(text_fn)

jax_image = np.expand_dims(jax_preprocess(Image.open("CLIP.png")), (0, 1))
jax_image = np.repeat(jax_image, len(devices), axis=0)
jax_image = np.repeat(jax_image, batch_size, axis=1)

jax_text = np.expand_dims(clip_jax.tokenize(["a diagram"]), 0)
jax_text = np.repeat(jax_text, len(devices), axis=0)
jax_text = np.repeat(jax_text, batch_size, axis=1)

start = time.time()
jax_image_embed = image_fn(jax_params, jax_image)
jax_text_embed = text_fn(jax_params, jax_text)
total = time.time() - start
print(f"{total:.06}s to compile model")

start = time.time()
for i in range(16):
jax_image_embed = np.array(image_fn(jax_params, jax_image))
jax_text_embed = np.array(text_fn(jax_params, jax_text))

total = time.time() - start

print(f"{total:.06}s for 16 batches@bs={batch_size} per core")
print(f"{16*len(devices) * batch_size/total:.06} examples/s")

print("done!")
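
The timed loop above forces each result back to a `np.array` so it measures completed work rather than JAX's asynchronous dispatch. An equivalent pattern, sketched here with a toy jitted function, keeps the result on device and uses `block_until_ready()` instead of a host copy:

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x).sum())
x = jnp.ones(1024)
f(x).block_until_ready()  # warm-up: compilation stays outside the timed region

start = time.time()
for _ in range(16):
    out = f(x)
out.block_until_ready()  # wait for the async dispatched work to finish
elapsed = time.time() - start
print(f"{elapsed:.4f}s for 16 calls")
```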
4 changes: 3 additions & 1 deletion tests/test_consistency.py
@@ -21,4 +21,6 @@
pyt_text_embed = pytorch_clip.encode_text(pyt_text)

assert np.allclose(np.array(jax_image_embed), pyt_image_embed.cpu().detach().numpy(), atol=0.01, rtol=0.01)
assert np.allclose(np.array(jax_text_embed), pyt_text_embed.cpu().detach().numpy(), atol=0.01, rtol=0.01)
assert np.allclose(np.array(jax_text_embed), pyt_text_embed.cpu().detach().numpy(), atol=0.01, rtol=0.01)

print("done!")
4 changes: 4 additions & 0 deletions tpu_install.sh
@@ -0,0 +1,4 @@
pip3 install --upgrade --user jax jaxlib==0.1.61
pip3 install -r requirements.txt
pip3 install -e .
pip3 install git+https://github.com/openai/CLIP.git
