Skip to content

Commit

Permalink
fix issue when before_tts_cb returns an async iterable but the tts_so…
Browse files Browse the repository at this point in the history
…urce is a str (#906)
  • Loading branch information
martin-purplefish authored Oct 15, 2024
1 parent cce8e08 commit a95dd8f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .changeset/heavy-donkeys-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

Fix bug where if the tts_source was a string but before_tts_cb returned AsyncIterable[str], the transcript would not be synthesized.
48 changes: 30 additions & 18 deletions livekit-agents/livekit/agents/pipeline/agent_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,10 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None:

if isinstance(tts_source, Awaitable):
tts_source = await tts_source
co = _str_synthesis_task(tts_source, transcript_source, handle)
elif isinstance(tts_source, str):
if isinstance(transcript_source, Awaitable):
transcript_source = await transcript_source

if isinstance(tts_source, str):
co = _str_synthesis_task(tts_source, transcript_source, handle)
else:
co = _stream_synthesis_task(tts_source, transcript_source, handle)
Expand All @@ -187,16 +189,28 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None:


@utils.log_exceptions(logger=logger)
async def _str_synthesis_task(
tts_text: str, transcript: str, handle: SynthesisHandle
async def _read_transcript_task(
transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle
) -> None:
"""synthesize speech from a string"""
if isinstance(transcript_source, str):
handle._tr_fwd.push_text(transcript_source)
else:
async for seg in transcript_source:
if not handle._tr_fwd.closed:
handle._tr_fwd.push_text(seg)

if not handle.tts_forwarder.closed:
handle.tts_forwarder.push_text(transcript)
handle.tts_forwarder.mark_text_segment_end()


@utils.log_exceptions(logger=logger)
async def _str_synthesis_task(
tts_text: str, transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle
) -> None:
"""synthesize speech from a string"""
start_time = time.time()
first_frame = True
read_transcript_atask: asyncio.Task | None = None

try:
async for audio in handle._tts.synthesize(tts_text):
Expand All @@ -210,6 +224,9 @@ async def _str_synthesis_task(
"streamed": False,
},
)
read_transcript_atask = asyncio.create_task(
_read_transcript_task(transcript_source, handle)
)

frame = audio.frame

Expand All @@ -221,11 +238,14 @@ async def _str_synthesis_task(
if not handle.tts_forwarder.closed:
handle.tts_forwarder.mark_audio_segment_end()

if read_transcript_atask is not None:
await read_transcript_atask


@utils.log_exceptions(logger=logger)
async def _stream_synthesis_task(
tts_source: AsyncIterable[str],
transcript_source: AsyncIterable[str],
transcript_source: AsyncIterable[str] | str,
handle: SynthesisHandle,
) -> None:
"""synthesize speech from streamed text"""
Expand Down Expand Up @@ -254,16 +274,6 @@ async def _read_generated_audio_task():
if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.mark_audio_segment_end()

@utils.log_exceptions(logger=logger)
async def _read_transcript_task():
async for seg in transcript_source:
if not handle._tr_fwd.closed:
handle._tr_fwd.push_text(seg)

if not handle.tts_forwarder.closed:
handle.tts_forwarder.mark_text_segment_end()

# otherwise, stream the text to the TTS
tts_stream = handle._tts.stream()
read_tts_atask: asyncio.Task | None = None
read_transcript_atask: asyncio.Task | None = None
Expand All @@ -273,7 +283,9 @@ async def _read_transcript_task():
if read_tts_atask is None:
# start the task when we receive the first text segment (so start_time is more accurate)
read_tts_atask = asyncio.create_task(_read_generated_audio_task())
read_transcript_atask = asyncio.create_task(_read_transcript_task())
read_transcript_atask = asyncio.create_task(
_read_transcript_task(transcript_source, handle)
)

tts_stream.push_text(seg)

Expand Down
6 changes: 3 additions & 3 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .._constants import ATTRIBUTE_AGENT_STATE
from .._types import AgentState
from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
from .agent_output import AgentOutput, SynthesisHandle
from .agent_output import AgentOutput, SpeechSource, SynthesisHandle
from .agent_playout import AgentPlayout
from .human_input import HumanInput
from .log import logger
Expand All @@ -28,7 +28,7 @@

BeforeTTSCallback = Callable[
["VoicePipelineAgent", Union[str, AsyncIterable[str]]],
Union[str, AsyncIterable[str], Awaitable[str]],
SpeechSource,
]


Expand Down Expand Up @@ -784,7 +784,7 @@ def _synthesize_agent_speech(

tts_source = self._opts.before_tts_cb(self, og_source)
if tts_source is None:
logger.error("before_tts_cb must return str or AsyncIterable[str]")
raise ValueError("before_tts_cb must return str or AsyncIterable[str]")

return self._agent_output.synthesize(
speech_id=speech_id,
Expand Down

0 comments on commit a95dd8f

Please sign in to comment.