diff --git a/.changeset/heavy-donkeys-check.md b/.changeset/heavy-donkeys-check.md new file mode 100644 index 000000000..4ab1d86db --- /dev/null +++ b/.changeset/heavy-donkeys-check.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +Fix bug where if the tts_source was a string but before_tts_cb returned AsyncIterable[str], the transcript would not be synthesized. \ No newline at end of file diff --git a/livekit-agents/livekit/agents/pipeline/agent_output.py b/livekit-agents/livekit/agents/pipeline/agent_output.py index 8025859bb..f5824155f 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_output.py +++ b/livekit-agents/livekit/agents/pipeline/agent_output.py @@ -170,8 +170,10 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None: if isinstance(tts_source, Awaitable): tts_source = await tts_source - co = _str_synthesis_task(tts_source, transcript_source, handle) - elif isinstance(tts_source, str): + if isinstance(transcript_source, Awaitable): + transcript_source = await transcript_source + + if isinstance(tts_source, str): co = _str_synthesis_task(tts_source, transcript_source, handle) else: co = _stream_synthesis_task(tts_source, transcript_source, handle) @@ -187,16 +189,28 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None: @utils.log_exceptions(logger=logger) -async def _str_synthesis_task( - tts_text: str, transcript: str, handle: SynthesisHandle +async def _read_transcript_task( + transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle ) -> None: - """synthesize speech from a string""" + if isinstance(transcript_source, str): + handle._tr_fwd.push_text(transcript_source) + else: + async for seg in transcript_source: + if not handle._tr_fwd.closed: + handle._tr_fwd.push_text(seg) + if not handle.tts_forwarder.closed: - handle.tts_forwarder.push_text(transcript) handle.tts_forwarder.mark_text_segment_end() + +@utils.log_exceptions(logger=logger) +async def _str_synthesis_task( + tts_text: str, transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle +) -> None: + """synthesize speech from a string""" start_time = time.time() first_frame = True + read_transcript_atask: asyncio.Task | None = None try: async for audio in handle._tts.synthesize(tts_text): @@ -210,6 +224,9 @@ async def _str_synthesis_task( "streamed": False, }, ) + read_transcript_atask = asyncio.create_task( + _read_transcript_task(transcript_source, handle) + ) frame = audio.frame @@ -221,11 +238,14 @@ async def _str_synthesis_task( if not handle.tts_forwarder.closed: handle.tts_forwarder.mark_audio_segment_end() + if read_transcript_atask is not None: + await read_transcript_atask + @utils.log_exceptions(logger=logger) async def _stream_synthesis_task( tts_source: AsyncIterable[str], - transcript_source: AsyncIterable[str], + transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle, ) -> None: """synthesize speech from streamed text""" @@ -254,16 +274,6 @@ async def _read_generated_audio_task(): if handle._tr_fwd and not handle._tr_fwd.closed: handle._tr_fwd.mark_audio_segment_end() - @utils.log_exceptions(logger=logger) - async def _read_transcript_task(): - async for seg in transcript_source: - if not handle._tr_fwd.closed: - handle._tr_fwd.push_text(seg) - - if not handle.tts_forwarder.closed: - handle.tts_forwarder.mark_text_segment_end() - - # otherwise, stream the text to the TTS tts_stream = handle._tts.stream() read_tts_atask: asyncio.Task | None = None read_transcript_atask: asyncio.Task | None = None @@ -273,7 +283,9 @@ async def _read_transcript_task(): if read_tts_atask is None: # start the task when we receive the first text segment (so start_time is more accurate) read_tts_atask = asyncio.create_task(_read_generated_audio_task()) - read_transcript_atask = asyncio.create_task(_read_transcript_task()) + read_transcript_atask = asyncio.create_task( + _read_transcript_task(transcript_source, handle) + ) tts_stream.push_text(seg) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 06fcc2ff5..82c8e40b6 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -12,7 +12,7 @@ from .._constants import ATTRIBUTE_AGENT_STATE from .._types import AgentState from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream -from .agent_output import AgentOutput, SynthesisHandle +from .agent_output import AgentOutput, SpeechSource, SynthesisHandle from .agent_playout import AgentPlayout from .human_input import HumanInput from .log import logger @@ -28,7 +28,7 @@ BeforeTTSCallback = Callable[ ["VoicePipelineAgent", Union[str, AsyncIterable[str]]], - Union[str, AsyncIterable[str], Awaitable[str]], + SpeechSource, ] @@ -784,7 +784,7 @@ def _synthesize_agent_speech( tts_source = self._opts.before_tts_cb(self, og_source) if tts_source is None: - logger.error("before_tts_cb must return str or AsyncIterable[str]") + raise ValueError("before_tts_cb must return str or AsyncIterable[str]") return self._agent_output.synthesize( speech_id=speech_id,