-
Notifications
You must be signed in to change notification settings - Fork 9
/
ray_tpu.py
89 lines (66 loc) · 2.67 KB
/
ray_tpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import functools
import os
import subprocess
import time
import glob
import requests
from fabric import Connection
@functools.lru_cache()
def get_bearer():
    """Return an OAuth access token obtained from the local gcloud CLI.

    The token is fetched on the first call and memoized for the remainder of
    the process (``lru_cache`` with no arguments varying).

    NOTE(review): gcloud access tokens expire; a long-running process keeps
    reusing the first token fetched here — confirm that is acceptable.
    """
    token_bytes = subprocess.check_output("gcloud auth print-access-token", shell=True)
    # gcloud prints the token followed by a newline; strip surrounding whitespace.
    return token_bytes.decode("utf-8").strip()
@functools.lru_cache()
def get_project():
    """Return the active project from the local gcloud configuration.

    Reads ``core.project`` via ``gcloud config list`` once and memoizes the
    result for the remainder of the process.
    """
    command = 'gcloud config list --format "value(core.project)"'
    raw_output = subprocess.check_output(command, shell=True)
    # Drop the trailing newline emitted by gcloud.
    return raw_output.decode("utf-8").strip()
def check_tpu(name, zone, *, timeout=60):
    """Fetch the Cloud TPU API (v2alpha1) description of a TPU node.

    Args:
        name: TPU node name.
        zone: location the node lives in (used as the ``locations/`` path
            segment of the request URL).
        timeout: seconds to wait for the API (passed to ``requests.get``).
            Without a timeout an unresponsive endpoint would block forever;
            pass ``None`` to restore that unbounded behaviour.

    Returns:
        The decoded JSON body of the response. The HTTP status is not
        checked, so an error response's JSON body is returned as-is rather
        than raised.

    Raises:
        requests.Timeout: if the API does not respond within ``timeout``.
    """
    headers = {
        "Authorization": f"Bearer {get_bearer()}",
    }
    url = (
        f"https://tpu.googleapis.com/v2alpha1/projects/{get_project()}"
        f"/locations/{zone}/nodes/{name}"
    )
    response = requests.get(url, headers=headers, timeout=timeout)
    return response.json()
def get_connection(
    name,
    zone,
):
    """Build one fabric ``Connection`` per network endpoint of a TPU node.

    Every connection authenticates with the key at
    ``~/.ssh/google_compute_engine``. Constructing a ``Connection`` does not
    by itself open the SSH session.

    Args:
        name: TPU node name (forwarded to ``check_tpu``).
        zone: TPU zone (forwarded to ``check_tpu``).

    Returns:
        A list of ``Connection`` objects, one per entry of the node's
        ``networkEndpoints``, in API order.

    Raises:
        KeyError: if the API response has no ``networkEndpoints`` —
            presumably because the node does not exist or the request
            returned an error body. The message includes the response.
    """
    info = check_tpu(name, zone)
    if "networkEndpoints" not in info:
        # Still a KeyError (as before) so existing handlers keep working, but
        # with enough context to diagnose the failed lookup.
        raise KeyError(
            f"TPU {name!r} in zone {zone!r} has no 'networkEndpoints'; "
            f"API response: {info}"
        )
    # The same key file is used for every host; resolve it once.
    key_filename = os.path.expanduser("~/.ssh/google_compute_engine")
    return [
        Connection(endpoint["ipAddress"], connect_kwargs={"key_filename": key_filename})
        for endpoint in info["networkEndpoints"]
    ]
def start_ray(conn, address):
    """Sync the local project onto one TPU host and (re)start Ray there.

    Remote relative paths resolve against the SSH session's working
    directory. The steps are:

    1. Wipe and rebuild the remote ``bloom-jax-inference/`` tree so that it
       mirrors the local layout, and copy ``key.json`` into it.
    2. Install ``scripts/ray_tpu.sh`` as the start-up script
       ``/tmp/ray-tpu.sh``.
    3. Force-stop any running Ray instance (best-effort).
    4. Run the start-up script with ``address`` as its only argument and
       print the result.

    Args:
        conn: fabric ``Connection`` to the TPU host.
        address: passed verbatim to the start-up script. Presumably the Ray
            head address — confirm against ``scripts/ray_tpu.sh``.

    Raises:
        Whatever fabric/invoke raise when an upload or remote command fails.
        The ``ray stop`` step is the exception: its errors are ignored.
    """
    # (local glob pattern, remote destination directory). The remote tree
    # mirrors the local layout under bloom-jax-inference/.
    uploads = (
        ("*.py", "bloom-jax-inference/"),  # run files
        ("*.sh", "bloom-jax-inference/"),  # top-level shell scripts
        ("bloom_inference/*.py", "bloom-jax-inference/bloom_inference/"),  # CPU/TPU managers
        ("bloom_inference/modeling_bloom/*.py",
         "bloom-jax-inference/bloom_inference/modeling_bloom/"),  # modeling files
        ("scripts/*.sh", "bloom-jax-inference/scripts/"),  # helper scripts
    )

    # start afresh each launch (temporarily)
    conn.run("sudo rm -rf *.py bloom-jax-inference")
    # Create every destination directory up front: an SFTP upload into a
    # directory that does not exist fails.
    conn.run(
        "mkdir -p bloom-jax-inference/bloom_inference/modeling_bloom "
        "bloom-jax-inference/scripts"
    )
    for pattern, remote_dir in uploads:
        for local_path in glob.glob(pattern):
            conn.put(local_path, remote_dir)
    # copy key file into the project root
    conn.put("key.json", "bloom-jax-inference/")

    # Install the start-up script at the exact path executed below, then make
    # it executable.
    conn.put("scripts/ray_tpu.sh", "/tmp/ray-tpu.sh")
    conn.sudo("chmod +x /tmp/ray-tpu.sh", hide=True)

    # Best-effort stop of any existing Ray instance. Errors are deliberately
    # ignored, but KeyboardInterrupt/SystemExit still propagate.
    try:
        conn.run("ray stop -f", hide=True)
    except Exception:
        pass
    time.sleep(1)

    # run start-up script and display its result
    out = conn.run(f"bash /tmp/ray-tpu.sh {address}", hide=False)
    print(out)