-
Notifications
You must be signed in to change notification settings - Fork 15
/
server.py
168 lines (130 loc) · 4.14 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import threading
from contextlib import asynccontextmanager
from time import sleep

import requests
from fastapi import (
    FastAPI,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from listener_counter import ListenerCounter
from logger import log_info, log_warn
from websocket_connection_manager import WebSocketConnectionManager
# Index of the audio track currently being served ("<index>.mp3").
# Starts at -1; advance() increments it and wraps from 9 back to 0, so it is
# in the range 0..9 once advance() has run at least once.
current_index = -1
# The threading.Timer armed by advance() to call advance() again after 60 s.
# None until advance() has run for the first time.
t = None
# Base URL of the inference server; assigned in lifespan() from the
# INFERENCE_SERVER_URL environment variable.
inference_url = ""
# Value sent in the "Authorization: key <api_key>" header of every request to
# the inference server; assigned in lifespan() from the API_KEY environment
# variable.
api_key = ""
# Shared helpers used by the /ws endpoint: the connection manager broadcasts
# messages to connected sockets, and the counter tracks listeners by address.
ws_connection_manager = WebSocketConnectionManager()
listener_counter = ListenerCounter()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Configure the server on startup and stop the track timer on shutdown.

    Startup: reads ``INFERENCE_SERVER_URL`` (falling back to
    ``http://localhost:8001`` when unset or empty) and ``API_KEY`` from the
    environment into the module globals, then calls ``advance()`` to select
    the first track and arm the periodic timer.

    Shutdown: cancels the pending timer, if any, so it does not fire again.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global inference_url, api_key

    inference_url = os.environ.get("INFERENCE_SERVER_URL") or "http://localhost:8001"
    # Keep api_key a string: a missing variable would otherwise be formatted
    # into the Authorization header as the literal text "None".
    api_key = os.environ.get("API_KEY", "")

    advance()
    try:
        yield
    finally:
        # Runs even when the application exits the lifespan with an error.
        if t:
            t.cancel()
def _save_clip(res, path):
    """Stream the body of the HTTP response *res* into the file at *path*."""
    with open(path, "wb") as f:
        for chunk in res.iter_content(chunk_size=8192):
            f.write(chunk)


def generate_new_audio():
    """Ask the inference server for new clips and store them on disk.

    Work is only done when ``current_index`` is 0 or 5; for any other index
    (or when no inference URL is configured) the function returns at once.
    At index 0 the new clips are written to ``5.mp3``..``9.mp3``; at index 5
    they are written to ``0.mp3``..``4.mp3``.

    Steps:
      1. POST ``/generate`` to start inference.
      2. Poll ``/clips/0`` every 30 seconds until it answers 200, then save it.
      3. Fetch ``/clips/1``..``/clips/4``; a clip that does not answer 200 is
         skipped and the existing file for that slot is left as it is.

    Any ``requests`` error (connection failure, broken stream, ...) is logged
    as a warning and ends the run, leaving the previously stored files in
    place. The function blocks while polling, so it is intended to run in its
    own thread.
    """
    if not inference_url:
        return

    # The target slot range is fixed up front from the index at call time.
    if current_index == 0:
        offset = 5
    elif current_index == 5:
        offset = 0
    else:
        return

    headers = {"Authorization": f"key {api_key}"}
    unreachable_msg = (
        "inference server potentially unreachable. recycling cached audio for now."
    )

    log_info("requesting new audio...")
    try:
        requests.post(f"{inference_url}/generate", headers=headers)
    except requests.RequestException:
        log_warn(unreachable_msg)
        return

    # Poll for the first clip; a non-200 answer means inference is still running.
    while True:
        try:
            # The context manager closes the streamed response on every path,
            # so polling does not accumulate open connections.
            with requests.post(
                f"{inference_url}/clips/0",
                stream=True,
                headers=headers,
            ) as res:
                if res.status_code == status.HTTP_200_OK:
                    print("inference complete! downloading new clips")
                    _save_clip(res, f"{offset}.mp3")
                    break
                print(res.status_code)
                print("still generating...")
        except requests.RequestException:
            log_warn(unreachable_msg)
            return
        sleep(30)

    # Remaining clips 1..4 map to slots offset+1..offset+4.
    for clip in range(1, 5):
        try:
            with requests.post(
                f"{inference_url}/clips/{clip}",
                stream=True,
                headers=headers,
            ) as res:
                if res.status_code != status.HTTP_200_OK:
                    continue
                _save_clip(res, f"{clip + offset}.mp3")
        except requests.RequestException:
            log_warn(unreachable_msg)
            return

    log_info("audio generated.")
def advance():
    """Step to the next track and schedule the following step.

    Increments the global ``current_index`` (wrapping from 9 back to 0),
    starts ``generate_new_audio`` in a separate thread, and stores a new
    60-second ``threading.Timer`` in the global ``t`` that calls this
    function again.
    """
    global current_index, t

    current_index = 0 if current_index == 9 else current_index + 1

    # Generation can block for a long time, so it runs off the timer thread.
    worker = threading.Thread(target=generate_new_audio)
    worker.start()

    t = threading.Timer(60, advance)
    t.start()
# The ASGI application; lifespan() runs its startup and shutdown steps.
app = FastAPI(lifespan=lifespan)
@app.get("/current.mp3")
def get_current_audio():
    """Return the MP3 file whose name matches the current track index."""
    track_path = f"{current_index}.mp3"
    return FileResponse(track_path)
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Track a client's listening state and broadcast the listener count.

    The socket is registered with the connection manager. If the client
    address is unavailable the socket is closed and unregistered at once.
    Otherwise the current listener count is broadcast, and the endpoint then
    reacts to text messages from the client:

    * ``"listening"`` - add the client's address to the listener counter
    * ``"paused"``    - remove the client's address from the listener counter

    After either message the updated count is broadcast; any other message is
    ignored. When the connection ends, the address is removed from the
    counter, the socket is unregistered, and the count is broadcast once more.

    Args:
        ws: The incoming WebSocket connection.
    """
    # Presumably this also accepts the handshake; see the manager's connect().
    await ws_connection_manager.connect(ws)

    if not ws.client:
        await ws.close()
        ws_connection_manager.disconnect(ws)
        return
    addr, _ = ws.client

    try:
        await ws_connection_manager.broadcast(f"{listener_counter.count()}")
        while True:
            msg = await ws.receive_text()
            if msg == "listening":
                listener_counter.add_listener(addr)
            elif msg == "paused":
                listener_counter.remove_listener(addr)
            else:
                continue
            await ws_connection_manager.broadcast(f"{listener_counter.count()}")
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens below.
        pass
    finally:
        # Always release this client's state, including on unexpected errors
        # or task cancellation, which are left to propagate instead of being
        # silently swallowed.
        listener_counter.remove_listener(addr)
        ws_connection_manager.disconnect(ws)
        await ws_connection_manager.broadcast(f"{listener_counter.count()}")
# Serve the static front end from the "web" directory at the root path.
# NOTE(review): registered after the routes above, presumably so this
# catch-all mount does not shadow /current.mp3 and /ws - confirm.
app.mount("/", StaticFiles(directory="web", html=True), name="web")