From 75af8a66a455cce820e862ea7298bbf50d79affd Mon Sep 17 00:00:00 2001 From: AnyaCoder Date: Mon, 18 Nov 2024 00:05:03 +0800 Subject: [PATCH] Ver 1.4.4: fix bug in stopping audio playback --- fish/chat.py | 132 +++++++++++++++++++++++++++++++---------- fish/gui.py | 1 + fish/modules/worker.py | 83 +++++++++++++++----------- locales/en_US.yaml | 1 + pyproject.toml | 2 +- 5 files changed, 153 insertions(+), 66 deletions(-) diff --git a/fish/chat.py b/fish/chat.py index 8f4a7f9..7108bda 100644 --- a/fish/chat.py +++ b/fish/chat.py @@ -485,6 +485,7 @@ def __init__(self): self.state = ChatState() self.thread_pool = QThreadPool.globalInstance() self.loop = asyncio.get_event_loop() + self.async_msg_task = None self.initUI() self.init_messages() @@ -537,10 +538,15 @@ def initUI(self): self.clear_button.clicked.connect(self.clear_messages) self.clear_button.setStyleSheet(CLEAN_QSS) + self.stop_button = QPushButton(_t("ChatWidget.stop_btn")) + self.stop_button.clicked.connect(self.stop_message_task) + self.stop_button.setStyleSheet(CLEAN_QSS) + input_layout.addWidget(self.voice_mode_button) input_layout.addWidget(self.input_field) input_layout.addWidget(self.send_button) input_layout.addWidget(self.clear_button) + input_layout.addWidget(self.stop_button) main_layout.addLayout(input_layout) # Add the settings button as an overlay in the top-right corner @@ -670,10 +676,7 @@ def add_message( ) self.scroll_layout.addWidget(bubble) QApplication.processEvents() - # Drag scroll bar to the bottom - self.scroll_area.verticalScrollBar().setValue( - self.scroll_area.verticalScrollBar().maximum() - ) + return bubble def clear_messages(self): @@ -715,10 +718,19 @@ def start_message_task(self, *, text: str = None, audio: str = None): message_worker.update_bubble_signal.connect(self.on_update_bubble) message_worker.update_duration_signal.connect(self.on_update_duration) message_worker.update_text_signal.connect(self.on_update_text) - # worker -> QRunnable -> QThreadPool - async_task = AsyncTaskWorker(message_worker) - self.thread_pool.start(async_task) + self.async_msg_task = AsyncTaskWorker(message_worker) + self.thread_pool.start(self.async_msg_task) + pass + + def stop_message_task(self): + if self.async_msg_task: + self.async_msg_task.cancel() + self.async_msg_task = None + + def on_message_task_finished(self, audio): + self.audio_files.append(audio) + logger.info("Message Task Complete") pass def on_add_message(self, text, is_sender, is_voice, audio, duration): @@ -746,21 +758,24 @@ def on_update_text(self, text): item: MessageBubble = self.scroll_layout.itemAt(num_bubbles - 1).widget() item.update_text(text) - def on_message_task_finished(self, audio): - self.audio_files.append(audio) - logger.info("Message Task Complete") - pass - def update_bubble_size(self, mode: str = "all"): start_idx = self.scroll_layout.count() - 1 if mode == "last" else 0 for i in range(start_idx, self.scroll_layout.count(), 1): item = self.scroll_layout.itemAt(i).widget() if isinstance(item, MessageBubble): item.msg.setFixedWidth(item.get_dynamic_width(self.width())) + if mode == "last": + # Drag scroll bar to the bottom + self.scroll_area.verticalScrollBar().setValue( + self.scroll_area.verticalScrollBar().maximum() + ) pass def keyPressEvent(self, event): - if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter): + if event.modifiers() == Qt.Modifier.CTRL and event.key() == Qt.Key.Key_B: + self.stop_message_task() + event.accept() + elif event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter): if self.input_field.hasFocus(): self.send_message_text() else: @@ -820,8 +835,9 @@ def __init__( self.system_prompt = system_prompt self.system_audios = system_audios self.loop = loop + self._task = None - async def send_message_async(self): + async def send_message_async(self, cancel_event: asyncio.Event): text = self.input_text audio = self.input_audio agent = self.agent @@ -875,32 +891,84 @@ async def send_message_async(self): self.update_bubble_signal.emit("last") # Step 3: Generate audio and text segments in real-time - async def infostream_generator(): + async def wave_generator(audio_data: bytes, cancel_event: asyncio.Event): + chunk_size = 32768 # 32KB = 16K samples = 16384 / 44100 = 0.372 s + offset = 0 + + while offset + chunk_size <= len(audio_data): + # one method to stop async audioplayer is to cut off the wav-stream + if cancel_event.is_set(): + break + yield audio_data[offset : offset + chunk_size] + offset += chunk_size + + if cancel_event.is_set(): + yield b"" + elif offset < len(audio_data): + yield audio_data[offset:] + + async def infostream_generator(cancel_event: asyncio.Event): total_seg_time = 0.0 - yield wav_chunk_header() - async for event in agent.stream( - chat_ctx={"messages": self.state.conversation} - ): - if event.type == FishE2EEventType.SPEECH_SEGMENT: - self.state.append_to_chat_ctx(ServeVQPart(codes=event.vq_codes)) - total_seg_time += len(event.vq_codes[0]) / 21 - yield bytes(event.frame.data) - self.update_duration_signal.emit(total_seg_time) - elif event.type == FishE2EEventType.TEXT_SEGMENT: - self.state.append_to_chat_ctx(ServeTextPart(text=event.text)) - self.update_text_signal.emit( - self.state.repr_message(self.state.conversation[-1]) - ) - self.update_bubble_signal.emit("last") + yield wav_chunk_header() # Initial header + + try: + async for event in agent.stream( + chat_ctx={"messages": self.state.conversation} + ): + if cancel_event.is_set(): + break + + if event.type == FishE2EEventType.SPEECH_SEGMENT: + self.state.append_to_chat_ctx(ServeVQPart(codes=event.vq_codes)) + total_seg_time += len(event.vq_codes[0]) / 21 + + audio_data = bytes(event.frame.data) + async for chunk in wave_generator(audio_data, cancel_event): + yield chunk + + self.update_duration_signal.emit(total_seg_time) + + elif event.type == FishE2EEventType.TEXT_SEGMENT: + self.state.append_to_chat_ctx(ServeTextPart(text=event.text)) + self.update_text_signal.emit( + self.state.repr_message(self.state.conversation[-1]) + ) + self.update_bubble_signal.emit("last") + + except asyncio.CancelledError: + logger.warning("Infostream generator was cancelled.") + raise # Re-raise to assure interruption # Step 4: Play audio (streaming) + audio_player = AudioPlayWorker(audio_path=temp_wavfile, streaming=True) - await audio_player.run_async(infostream_generator()) + audio_player.set_chunks(infostream_generator(cancel_event)) + await audio_player.run_async() self.finished.emit(temp_wavfile) def run(self): # Run asynchronous tasks in a new event loop using asyncio.run - self.loop.run_until_complete(self.send_message_async()) + self.cancel_event = asyncio.Event() + self._task = self.loop.create_task(self.send_message_async(self.cancel_event)) + self._task.add_done_callback(self.on_task_done) + try: + self.loop.run_until_complete(self._task) + except asyncio.CancelledError: + pass # Don't show redundant error + + def cancel(self): + if self._task: + self.cancel_event.set() + self._task.cancel() + self._task = None + + def on_task_done(self, task: asyncio.Task): + if task.cancelled(): + logger.warning("Task was cancelled") + elif task.exception(): + logger.error(f"Task encountered an exception: {task.exception()}") + else: + logger.info("Task completed successfully") if __name__ == "__main__": diff --git a/fish/gui.py b/fish/gui.py index 0c0fd27..a263686 100644 --- a/fish/gui.py +++ b/fish/gui.py @@ -954,6 +954,7 @@ def start_conversion(self): **kwargs, ) self.tts_worker.finished_signal.connect(self.on_conversion_finished) + self.tts_worker.error_signal.connect(self.stop_conversion) self.tts_worker.packet_delay.connect( lambda t: self.latency_label.setText( _t("action.latency").format(latency=(t * 1000.0)) diff --git a/fish/modules/worker.py b/fish/modules/worker.py index f323329..3d59459 100644 --- a/fish/modules/worker.py +++ b/fish/modules/worker.py @@ -4,7 +4,7 @@ import time import wave from pathlib import Path -from typing import Iterator, List +from typing import AsyncIterator, Iterator, List import numpy as np import ormsgpack @@ -145,16 +145,15 @@ def __init__( self, audio_path: str, streaming: bool, - iterable_chunks: Iterator[bytes] = None, frames_per_buffer: int = 16384, ): super().__init__() self.audio_path = audio_path self.streaming = streaming - self.iterable_chunks = iterable_chunks self.frames_per_buffer = frames_per_buffer - self.is_interrupted = False + self.iterable_chunks = None + self.is_interrupted = False self.elapsed = 0 self.p = None self.stream = None @@ -162,7 +161,6 @@ def __init__( self.time_worker = TimeWorker(pause_time=0.1) self.time_worker.time_signal.connect(self.calc_elapsed) - # Sync Methods: def calc_elapsed(self, elapsed): self.elapsed = elapsed self.packet_delay.emit(elapsed) @@ -190,6 +188,8 @@ def start_audio_streaming(self): def audio_streaming(self): first_packet_time = None + if not self.iterable_chunks: + return for chunk in self.iterable_chunks: if self.is_interrupted: break @@ -203,29 +203,11 @@ def audio_streaming(self): first_packet_time = self.elapsed self.time_worker.stop() - def stop_audio_streaming(self): - if self.streaming and self.stream: - self.stream.stop_stream() - self.stream.close() - self.p.terminate() - self.f.close() - - def run(self): - self.start_audio_streaming() - if self.iterable_chunks: - self.audio_streaming() - self.stop_audio_streaming() - if not self.is_interrupted: - self.finished_signal.emit(self.audio_path) - - def stop(self): - self.is_interrupted = True - self.time_worker.stop() - - # Async Methods: - async def async_audio_streaming(self, async_chunks): + async def async_audio_streaming(self): first_packet_time = None - async for chunk in async_chunks: + if not self.iterable_chunks: + return + async for chunk in self.iterable_chunks: if self.is_interrupted: break if self.streaming: @@ -238,17 +220,44 @@ async def async_audio_streaming(self, async_chunks): first_packet_time = self.elapsed self.time_worker.stop() - async def run_async(self, async_chunks): - self.time_worker.start() + def stop_audio_streaming(self): + if self.streaming and self.stream: + self.stream.stop_stream() + self.stream.close() + self.p.terminate() + self.f.close() + logger.info("Playback Finished") + + def set_chunks(self, chunks: Iterator[bytes] | AsyncIterator[bytes] = None): + self.iterable_chunks = chunks + + def run(self): + logger.info("Sync Playback Started") self.start_audio_streaming() - await self.async_audio_streaming(async_chunks) + self.audio_streaming() self.stop_audio_streaming() if not self.is_interrupted: + logger.info("Sync Playback Finished") self.finished_signal.emit(self.audio_path) + + async def run_async(self): + logger.info("Async Playback Started") + self.start_audio_streaming() + await self.async_audio_streaming() + self.stop_audio_streaming() + if not self.is_interrupted: + logger.info("Async Playback Finished") + self.finished_signal.emit(self.audio_path) + + def stop(self): + self.is_interrupted = True self.time_worker.stop() + logger.info("Playback Stopped") class TTSWorker(AudioPlayWorker): + error_signal = pyqtSignal() + def __init__( self, ref_files: List[str], @@ -267,6 +276,7 @@ def __init__( self.text = text self.api_key = api_key self.kwargs = kwargs + self.streaming = streaming def get_pre_files(self): return [f for f in self.ref_files if not f.endswith(".lab")] @@ -316,14 +326,15 @@ def run(self): }, ) response.raise_for_status() - self.iterable_chunks = response.iter_content( - chunk_size=self.frames_per_buffer - ) + audio_chunks = response.iter_content(chunk_size=self.frames_per_buffer) + self.set_chunks(audio_chunks) super().run() except requests.RequestException as e: logger.error(f"Network request failed: {e}") + self.error_signal.emit() finally: self.stop() # Ensure the thread stops gracefully if there's an error + response.close() class AudioRecordWorker(QThread): @@ -413,3 +424,9 @@ def __init__(self, worker): def run(self): self.worker.run() + + def cancel(self): + self.worker.cancel() + + def stop(self): + self.worker.stop() diff --git a/locales/en_US.yaml b/locales/en_US.yaml index 13660fe..a617131 100644 --- a/locales/en_US.yaml +++ b/locales/en_US.yaml @@ -262,4 +262,5 @@ ChatWidget: send_btn: "Send" clear_confirm: "Are you sure to clear chat history?" clear_btn: "Clear" + stop_btn: "Stop" recording: "Recording: {dur:.1f} s" diff --git a/pyproject.toml b/pyproject.toml index 06f2cee..4336828 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fish-speech-gui" -version = "1.4.3" +version = "1.4.4" description = "fish-speech comfortable GUI" readme = "README.md" requires-python = "<3.12,>=3.10"