Skip to content

Commit

Permalink
Ver 1.4.4: fix bug in stopping audio playback
Browse files Browse the repository at this point in the history
  • Loading branch information
AnyaCoder committed Nov 17, 2024
1 parent ad6cb43 commit 75af8a6
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 66 deletions.
132 changes: 100 additions & 32 deletions fish/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def __init__(self):
self.state = ChatState()
self.thread_pool = QThreadPool.globalInstance()
self.loop = asyncio.get_event_loop()
self.async_msg_task = None
self.initUI()
self.init_messages()

Expand Down Expand Up @@ -537,10 +538,15 @@ def initUI(self):
self.clear_button.clicked.connect(self.clear_messages)
self.clear_button.setStyleSheet(CLEAN_QSS)

self.stop_button = QPushButton(_t("ChatWidget.stop_btn"))
self.stop_button.clicked.connect(self.stop_message_task)
self.stop_button.setStyleSheet(CLEAN_QSS)

input_layout.addWidget(self.voice_mode_button)
input_layout.addWidget(self.input_field)
input_layout.addWidget(self.send_button)
input_layout.addWidget(self.clear_button)
input_layout.addWidget(self.stop_button)
main_layout.addLayout(input_layout)

# Add the settings button as an overlay in the top-right corner
Expand Down Expand Up @@ -670,10 +676,7 @@ def add_message(
)
self.scroll_layout.addWidget(bubble)
QApplication.processEvents()
# Drag scroll bar to the bottom
self.scroll_area.verticalScrollBar().setValue(
self.scroll_area.verticalScrollBar().maximum()
)

return bubble

def clear_messages(self):
Expand Down Expand Up @@ -715,10 +718,19 @@ def start_message_task(self, *, text: str = None, audio: str = None):
message_worker.update_bubble_signal.connect(self.on_update_bubble)
message_worker.update_duration_signal.connect(self.on_update_duration)
message_worker.update_text_signal.connect(self.on_update_text)

# worker -> QRunnable -> QThreadPool
async_task = AsyncTaskWorker(message_worker)
self.thread_pool.start(async_task)
self.async_msg_task = AsyncTaskWorker(message_worker)
self.thread_pool.start(self.async_msg_task)
pass

def stop_message_task(self):
    """Cancel the currently tracked message task, if any, and forget it."""
    running = self.async_msg_task
    if not running:
        # Nothing has been started (or it was already stopped).
        return
    running.cancel()
    self.async_msg_task = None

def on_message_task_finished(self, audio):
    """Remember the audio produced by a finished message task and log it.

    ``audio`` is appended to ``self.audio_files`` unchanged.
    """
    self.audio_files.append(audio)
    logger.info("Message Task Complete")

def on_add_message(self, text, is_sender, is_voice, audio, duration):
Expand Down Expand Up @@ -746,21 +758,24 @@ def on_update_text(self, text):
item: MessageBubble = self.scroll_layout.itemAt(num_bubbles - 1).widget()
item.update_text(text)

def on_message_task_finished(self, audio):
self.audio_files.append(audio)
logger.info("Message Task Complete")
pass

def update_bubble_size(self, mode: str = "all"):
start_idx = self.scroll_layout.count() - 1 if mode == "last" else 0
for i in range(start_idx, self.scroll_layout.count(), 1):
item = self.scroll_layout.itemAt(i).widget()
if isinstance(item, MessageBubble):
item.msg.setFixedWidth(item.get_dynamic_width(self.width()))
if mode == "last":
# Drag scroll bar to the bottom
self.scroll_area.verticalScrollBar().setValue(
self.scroll_area.verticalScrollBar().maximum()
)
pass

def keyPressEvent(self, event):
if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter):
if event.modifiers() == Qt.Modifier.CTRL and event.key() == Qt.Key.Key_B:
self.stop_message_task()
event.accept()
elif event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter):
if self.input_field.hasFocus():
self.send_message_text()
else:
Expand Down Expand Up @@ -820,8 +835,9 @@ def __init__(
self.system_prompt = system_prompt
self.system_audios = system_audios
self.loop = loop
self._task = None

async def send_message_async(self):
async def send_message_async(self, cancel_event: asyncio.Event):
text = self.input_text
audio = self.input_audio
agent = self.agent
Expand Down Expand Up @@ -875,32 +891,84 @@ async def send_message_async(self):
self.update_bubble_signal.emit("last")

# Step 3: Generate audio and text segments in real-time
async def infostream_generator():
async def wave_generator(
    audio_data: bytes,
    cancel_event: asyncio.Event,
    chunk_size: int = 32768,
):
    """Yield ``audio_data`` in ``chunk_size``-byte pieces until cancelled.

    Args:
        audio_data: Raw audio bytes for one speech segment.
        cancel_event: Polled with ``is_set()`` before every full chunk; once
            set, no further audio is produced.
        chunk_size: Size of each yielded piece in bytes. The default of 32 KiB
            corresponds to roughly 0.37 s of audio, assuming 16-bit mono at
            44.1 kHz (TODO confirm against the player's stream format).

    Yields:
        Consecutive slices of ``audio_data``; the final slice may be shorter.
        If cancellation is observed, a single empty ``b""`` chunk is yielded
        instead of the remaining data.

    Raises:
        ValueError: If ``chunk_size`` is not positive (the slicing loop would
            otherwise never advance).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    total = len(audio_data)
    offset = 0

    while offset + chunk_size <= total:
        # Cutting off the wav stream is how the async audio player is stopped.
        if cancel_event.is_set():
            break
        yield audio_data[offset : offset + chunk_size]
        offset += chunk_size

    if cancel_event.is_set():
        yield b""
    elif offset < total:
        # Trailing partial chunk.
        yield audio_data[offset:]

async def infostream_generator(cancel_event: asyncio.Event):
total_seg_time = 0.0
yield wav_chunk_header()
async for event in agent.stream(
chat_ctx={"messages": self.state.conversation}
):
if event.type == FishE2EEventType.SPEECH_SEGMENT:
self.state.append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
total_seg_time += len(event.vq_codes[0]) / 21
yield bytes(event.frame.data)
self.update_duration_signal.emit(total_seg_time)
elif event.type == FishE2EEventType.TEXT_SEGMENT:
self.state.append_to_chat_ctx(ServeTextPart(text=event.text))
self.update_text_signal.emit(
self.state.repr_message(self.state.conversation[-1])
)
self.update_bubble_signal.emit("last")
yield wav_chunk_header() # Initial header

try:
async for event in agent.stream(
chat_ctx={"messages": self.state.conversation}
):
if cancel_event.is_set():
break

if event.type == FishE2EEventType.SPEECH_SEGMENT:
self.state.append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
total_seg_time += len(event.vq_codes[0]) / 21

audio_data = bytes(event.frame.data)
async for chunk in wave_generator(audio_data, cancel_event):
yield chunk

self.update_duration_signal.emit(total_seg_time)

elif event.type == FishE2EEventType.TEXT_SEGMENT:
self.state.append_to_chat_ctx(ServeTextPart(text=event.text))
self.update_text_signal.emit(
self.state.repr_message(self.state.conversation[-1])
)
self.update_bubble_signal.emit("last")

except asyncio.CancelledError:
logger.warning("Infostream generator was cancelled.")
raise # Re-raise to assure interruption

# Step 4: Play audio (streaming)

audio_player = AudioPlayWorker(audio_path=temp_wavfile, streaming=True)
await audio_player.run_async(infostream_generator())
audio_player.set_chunks(infostream_generator(cancel_event))
await audio_player.run_async()
self.finished.emit(temp_wavfile)

def run(self):
# Run the message coroutine to completion on this worker's event loop (self.loop)
self.loop.run_until_complete(self.send_message_async())
self.cancel_event = asyncio.Event()
self._task = self.loop.create_task(self.send_message_async(self.cancel_event))
self._task.add_done_callback(self.on_task_done)
try:
self.loop.run_until_complete(self._task)
except asyncio.CancelledError:
pass # Don't show redundant error

def cancel(self):
    """Request cancellation of the running message task, if there is one.

    The task runs on ``self.loop`` inside ``run()``, while this method may be
    invoked from a different thread (for example a button handler). asyncio
    tasks are not thread-safe, so the actual ``Task.cancel()`` call is handed
    to the loop with ``call_soon_threadsafe`` instead of being made directly.
    """
    task = self._task
    if not task:
        return

    # The event is only polled via is_set() by the generators, so setting it
    # directly makes the stop request visible to them immediately.
    self.cancel_event.set()

    if not task.done():
        # Also wakes the loop if it is currently idle waiting for I/O.
        self.loop.call_soon_threadsafe(task.cancel)

    self._task = None

def on_task_done(self, task: asyncio.Task):
    """Done-callback that logs how the message task ended.

    Logs a warning for a cancelled task, an error for a task that raised,
    and an info message otherwise.
    """
    if task.cancelled():
        logger.warning("Task was cancelled")
        return

    # Safe to query: exception() only raises for cancelled/unfinished tasks.
    failure = task.exception()
    if failure:
        logger.error(f"Task encountered an exception: {failure}")
        return

    logger.info("Task completed successfully")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions fish/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,7 @@ def start_conversion(self):
**kwargs,
)
self.tts_worker.finished_signal.connect(self.on_conversion_finished)
self.tts_worker.error_signal.connect(self.stop_conversion)
self.tts_worker.packet_delay.connect(
lambda t: self.latency_label.setText(
_t("action.latency").format(latency=(t * 1000.0))
Expand Down
83 changes: 50 additions & 33 deletions fish/modules/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import wave
from pathlib import Path
from typing import Iterator, List
from typing import AsyncIterator, Iterator, List

import numpy as np
import ormsgpack
Expand Down Expand Up @@ -145,24 +145,22 @@ def __init__(
self,
audio_path: str,
streaming: bool,
iterable_chunks: Iterator[bytes] = None,
frames_per_buffer: int = 16384,
):
super().__init__()
self.audio_path = audio_path
self.streaming = streaming
self.iterable_chunks = iterable_chunks
self.frames_per_buffer = frames_per_buffer
self.is_interrupted = False
self.iterable_chunks = None

self.is_interrupted = False
self.elapsed = 0
self.p = None
self.stream = None

self.time_worker = TimeWorker(pause_time=0.1)
self.time_worker.time_signal.connect(self.calc_elapsed)

# Sync Methods:
def calc_elapsed(self, elapsed):
self.elapsed = elapsed
self.packet_delay.emit(elapsed)
Expand Down Expand Up @@ -190,6 +188,8 @@ def start_audio_streaming(self):

def audio_streaming(self):
first_packet_time = None
if not self.iterable_chunks:
return
for chunk in self.iterable_chunks:
if self.is_interrupted:
break
Expand All @@ -203,29 +203,11 @@ def audio_streaming(self):
first_packet_time = self.elapsed
self.time_worker.stop()

def stop_audio_streaming(self):
if self.streaming and self.stream:
self.stream.stop_stream()
self.stream.close()
self.p.terminate()
self.f.close()

def run(self):
self.start_audio_streaming()
if self.iterable_chunks:
self.audio_streaming()
self.stop_audio_streaming()
if not self.is_interrupted:
self.finished_signal.emit(self.audio_path)

def stop(self):
self.is_interrupted = True
self.time_worker.stop()

# Async Methods:
async def async_audio_streaming(self, async_chunks):
async def async_audio_streaming(self):
first_packet_time = None
async for chunk in async_chunks:
if not self.iterable_chunks:
return
async for chunk in self.iterable_chunks:
if self.is_interrupted:
break
if self.streaming:
Expand All @@ -238,17 +220,44 @@ async def async_audio_streaming(self, async_chunks):
first_packet_time = self.elapsed
self.time_worker.stop()

async def run_async(self, async_chunks):
self.time_worker.start()
def stop_audio_streaming(self):
if self.streaming and self.stream:
self.stream.stop_stream()
self.stream.close()
self.p.terminate()
self.f.close()
logger.info("Playback Finished")

def set_chunks(self, chunks: Iterator[bytes] | AsyncIterator[bytes] = None):
self.iterable_chunks = chunks

def run(self):
logger.info("Sync Playback Started")
self.start_audio_streaming()
await self.async_audio_streaming(async_chunks)
self.audio_streaming()
self.stop_audio_streaming()
if not self.is_interrupted:
logger.info("Sync Playback Finished")
self.finished_signal.emit(self.audio_path)

async def run_async(self):
    """Play the configured async chunk source and release audio resources.

    Opens the output via ``start_audio_streaming()``, awaits
    ``async_audio_streaming()``, and always calls ``stop_audio_streaming()``
    afterwards. ``finished_signal`` is emitted with ``self.audio_path`` only
    when streaming ended normally and playback was not interrupted.

    Raises:
        Whatever ``async_audio_streaming()`` raises, including
        ``asyncio.CancelledError`` when the surrounding task is cancelled;
        cleanup still runs before the exception propagates.
    """
    logger.info("Async Playback Started")
    self.start_audio_streaming()
    try:
        await self.async_audio_streaming()
    finally:
        # Without this, cancelling the task mid-playback would skip cleanup
        # and leave the stream and file open.
        self.stop_audio_streaming()

    if not self.is_interrupted:
        logger.info("Async Playback Finished")
        self.finished_signal.emit(self.audio_path)

def stop(self):
    """Mark playback as interrupted, stop the timer worker, and log it."""
    # The streaming loops check this flag before handling each chunk.
    self.is_interrupted = True
    self.time_worker.stop()
    logger.info("Playback Stopped")


class TTSWorker(AudioPlayWorker):
error_signal = pyqtSignal()

def __init__(
self,
ref_files: List[str],
Expand All @@ -267,6 +276,7 @@ def __init__(
self.text = text
self.api_key = api_key
self.kwargs = kwargs
self.streaming = streaming

def get_pre_files(self):
return [f for f in self.ref_files if not f.endswith(".lab")]
Expand Down Expand Up @@ -316,14 +326,15 @@ def run(self):
},
)
response.raise_for_status()
self.iterable_chunks = response.iter_content(
chunk_size=self.frames_per_buffer
)
audio_chunks = response.iter_content(chunk_size=self.frames_per_buffer)
self.set_chunks(audio_chunks)
super().run()
except requests.RequestException as e:
logger.error(f"Network request failed: {e}")
self.error_signal.emit()
finally:
self.stop() # Ensure the thread stops gracefully if there's an error
response.close()


class AudioRecordWorker(QThread):
Expand Down Expand Up @@ -413,3 +424,9 @@ def __init__(self, worker):

def run(self):
self.worker.run()

def cancel(self):
    """Forward a cancel request to the wrapped worker."""
    wrapped = self.worker
    wrapped.cancel()

def stop(self):
    """Forward a stop request to the wrapped worker."""
    wrapped = self.worker
    wrapped.stop()
1 change: 1 addition & 0 deletions locales/en_US.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,5 @@ ChatWidget:
send_btn: "Send"
clear_confirm: "Are you sure to clear chat history?"
clear_btn: "Clear"
stop_btn: "Stop"
recording: "Recording: {dur:.1f} s"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fish-speech-gui"
version = "1.4.3"
version = "1.4.4"
description = "fish-speech comfortable GUI"
readme = "README.md"
requires-python = "<3.12,>=3.10"
Expand Down

0 comments on commit 75af8a6

Please sign in to comment.