Improvements to end of turn plugin (#1195)
Co-authored-by: jeradf <[email protected]>
Co-authored-by: Long Chen <[email protected]>
Co-authored-by: Jayesh Parmar <[email protected]>
4 people authored Dec 10, 2024
1 parent 6b7f21b commit 6b4e903
Showing 10 changed files with 155 additions and 31 deletions.
8 changes: 8 additions & 0 deletions .changeset/strange-apes-hide.md
@@ -0,0 +1,8 @@
+---
+"livekit-plugins-azure": minor
+"livekit-plugins-turn-detector": patch
+"livekit-plugins-openai": patch
+"livekit-agents": patch
+---
+
+Improvements to end of turn plugin, ensure STT language settings.
5 changes: 4 additions & 1 deletion examples/voice-pipeline-agent/README.md
@@ -34,7 +34,10 @@ export OPENAI_API_KEY=<your OpenAI API key>
 
 ### Install requirements:
 
-`pip install -r requirements.txt`
+```
+pip install -r requirements.txt
+python minimal_assistant.py download-files
+```
 
 ### Run the agent worker:
35 changes: 22 additions & 13 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -160,8 +160,12 @@ class AgentTranscriptionOptions:
     representing the hyphenated parts of the word."""
 
 
-class _EOUModel(Protocol):
-    async def predict_eou(self, chat_ctx: ChatContext) -> float: ...
+class _TurnDetector(Protocol):
+    # When endpoint probability is below this threshold we think the user is not finished speaking
+    # so we will use a long delay
+    def unlikely_threshold(self) -> float: ...
+    def supports_language(self, language: str | None) -> bool: ...
+    async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float: ...
 
 
 class VoicePipelineAgent(utils.EventEmitter[EventTypes]):
@@ -179,7 +183,7 @@ def __init__(
         stt: stt.STT,
         llm: LLM,
         tts: tts.TTS,
-        turn_detector: _EOUModel | None = None,
+        turn_detector: _TurnDetector | None = None,
         chat_ctx: ChatContext | None = None,
         fnc_ctx: FunctionContext | None = None,
         allow_interruptions: bool = True,
@@ -595,7 +599,9 @@ def _on_final_transcript(ev: stt.SpeechEvent) -> None:
         ):
             self._synthesize_agent_reply()
 
-        self._deferred_validation.on_human_final_transcript(new_transcript)
+        self._deferred_validation.on_human_final_transcript(
+            new_transcript, ev.alternatives[0].language
+        )
 
         words = self._opts.transcription.word_tokenizer.tokenize(
             text=new_transcript
@@ -1105,24 +1111,22 @@ class _DeferredReplyValidation:
 
     LATE_TRANSCRIPT_TOLERANCE = 1.5  # late compared to end of speech
 
-    # When endpoint probability is below this threshold we think the user is not finished speaking
-    # so we will use a long delay
-    UNLIKELY_ENDPOINT_THRESHOLD = 0.15
 
     # Long delay to use when the model thinks the user is still speaking
     # TODO: make this configurable
     UNLIKELY_ENDPOINT_DELAY = 6
 
     def __init__(
         self,
         validate_fnc: Callable[[], None],
         min_endpointing_delay: float,
-        turn_detector: _EOUModel | None,
+        turn_detector: _TurnDetector | None,
         agent: VoicePipelineAgent,
     ) -> None:
         self._turn_detector = turn_detector
         self._validate_fnc = validate_fnc
         self._validating_task: asyncio.Task | None = None
         self._last_final_transcript: str = ""
+        self._last_language: str | None = None
         self._last_recv_end_of_speech_time: float = 0.0
         self._speaking = False
@@ -1134,8 +1138,9 @@ def __init__(
     def validating(self) -> bool:
         return self._validating_task is not None and not self._validating_task.done()
 
-    def on_human_final_transcript(self, transcript: str) -> None:
+    def on_human_final_transcript(self, transcript: str, language: str | None) -> None:
         self._last_final_transcript += " " + transcript.strip()  # type: ignore
+        self._last_language = language
 
         if self._speaking:
             return
@@ -1193,9 +1198,13 @@ def _run(self, delay: float) -> None:
         @utils.log_exceptions(logger=logger)
         async def _run_task(chat_ctx: ChatContext, delay: float) -> None:
             await asyncio.sleep(delay)
-            if self._turn_detector is not None:
-                eou_prob = await self._turn_detector.predict_eou(chat_ctx)
-                if eou_prob < self.UNLIKELY_ENDPOINT_THRESHOLD:
+            if (
+                self._turn_detector is not None
+                and self._turn_detector.supports_language(self._last_language)
+            ):
+                eot_prob = await self._turn_detector.predict_end_of_turn(chat_ctx)
+                unlikely_threshold = self._turn_detector.unlikely_threshold()
+                if eot_prob < unlikely_threshold:
                     await asyncio.sleep(self.UNLIKELY_ENDPOINT_DELAY)
 
             self._reset_states()
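
As a reference point (not part of this commit), here is a minimal sketch of a custom detector satisfying the `_TurnDetector` protocol; the class name and its punctuation heuristic are hypothetical:

```python
from livekit.agents.llm import ChatContext


class KeywordTurnDetector:
    """Hypothetical detector satisfying the _TurnDetector protocol."""

    def unlikely_threshold(self) -> float:
        # below this end-of-turn probability, validation waits UNLIKELY_ENDPOINT_DELAY
        return 0.15

    def supports_language(self, language: str | None) -> bool:
        # the agent only consults the detector for languages it claims to support
        return language is not None and language.lower().startswith("en")

    async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float:
        # toy heuristic: trailing punctuation suggests a finished turn
        last = chat_ctx.messages[-1].content if chat_ctx.messages else ""
        if isinstance(last, str) and last.rstrip().endswith((".", "?", "!")):
            return 0.9
        return 0.1
```

An instance would be passed to `VoicePipelineAgent` via the `turn_detector` argument, just like the bundled model.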
47 changes: 37 additions & 10 deletions livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py
@@ -16,6 +16,7 @@
 import contextlib
 import os
 import weakref
+from copy import deepcopy
 from dataclasses import dataclass
 
 from livekit import rtc
@@ -55,7 +56,10 @@ def __init__(
         segmentation_silence_timeout_ms: int | None = None,
         segmentation_max_time_ms: int | None = None,
         segmentation_strategy: str | None = None,
-        languages: list[str] = [],  # when empty, auto-detect the language
+        # Azure handles multiple languages and can auto-detect the language used; it requires the candidate set to be provided
+        languages: list[str] = ["en-US"],
+        # for compatibility with other STT plugins
+        language: str | None = None,
     ):
         """
         Create a new instance of Azure STT.
@@ -83,6 +87,9 @@ def __init__(
                 "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION must be set"
             )
 
+        if language:
+            languages = [language]
+
         self._config = STTOptions(
             speech_key=speech_key,
             speech_region=speech_region,
@@ -109,18 +116,28 @@ async def _recognize_impl(
     def stream(
         self,
         *,
+        languages: list[str] | None = None,
         language: str | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
     ) -> "SpeechStream":
-        stream = SpeechStream(stt=self, opts=self._config, conn_options=conn_options)
+        config = deepcopy(self._config)
+        if language and not languages:
+            languages = [language]
+        if languages:
+            config.languages = languages
+        stream = SpeechStream(stt=self, opts=config, conn_options=conn_options)
         self._streams.add(stream)
         return stream
 
-    def update_options(self, *, language: str | None = None):
-        if language is not None:
-            self._config.languages = [language]
+    def update_options(
+        self, *, language: str | None = None, languages: list[str] | None = None
+    ):
+        if language and not languages:
+            languages = [language]
+        if languages is not None:
+            self._config.languages = languages
         for stream in self._streams:
-            stream.update_options(language=language)
+            stream.update_options(languages=languages)
 
 
 class SpeechStream(stt.SpeechStream):
@@ -139,9 +156,13 @@ def __init__(
         self._loop = asyncio.get_running_loop()
         self._reconnect_event = asyncio.Event()
 
-    def update_options(self, *, language: str | None = None):
-        if language:
-            self._opts.languages = [language]
+    def update_options(
+        self, *, language: str | None = None, languages: list[str] | None = None
+    ):
+        if language and not languages:
+            languages = [language]
+        if languages:
+            self._opts.languages = languages
         self._reconnect_event.set()
 
     async def _run(self) -> None:
@@ -206,6 +227,9 @@ def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs):
         if not text:
             return
 
+        if not detected_lg and self._opts.languages:
+            detected_lg = self._opts.languages[0]
+
         final_data = stt.SpeechData(
             language=detected_lg, confidence=1.0, text=evt.result.text
         )
@@ -224,6 +248,9 @@ def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs):
         if not text:
             return
 
+        if not detected_lg and self._opts.languages:
+            detected_lg = self._opts.languages[0]
+
         interim_data = stt.SpeechData(
             language=detected_lg, confidence=0.0, text=evt.result.text
         )
@@ -303,7 +330,7 @@ def _create_speech_recognizer(
     )
 
     auto_detect_source_language_config = None
-    if config.languages:
+    if config.languages and len(config.languages) > 1:
         auto_detect_source_language_config = (
             speechsdk.languageconfig.AutoDetectSourceLanguageConfig(
                 languages=config.languages
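
A short usage sketch of the updated Azure options (illustrative; assumes Azure credentials are configured through the usual environment variables):

```python
from livekit.plugins import azure

# a single language, for consistency with other STT plugins
stt = azure.STT(language="en-US")

# or give Azure a candidate set; auto-detection is enabled only when
# more than one language is listed
stt = azure.STT(languages=["en-US", "es-ES", "fr-FR"])

# swap the candidate set later; open streams reconnect with the new config
stt.update_options(languages=["de-DE"])
```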
livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py
@@ -148,14 +148,18 @@ async def _recognize_impl(
             ),
             model=self._opts.model,
             language=config.language,
-            response_format="json",
+            # verbose_json returns language and other details
+            response_format="verbose_json",
             timeout=httpx.Timeout(30, connect=conn_options.timeout),
         )
 
         return stt.SpeechEvent(
             type=stt.SpeechEventType.FINAL_TRANSCRIPT,
             alternatives=[
-                stt.SpeechData(text=resp.text or "", language=language or "")
+                stt.SpeechData(
+                    text=resp.text or "",
+                    language=resp.language or config.language or "",
+                )
             ],
         )
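
For context, a sketch of what `verbose_json` adds on the OpenAI side (assumes the v1 `openai` client; whisper reports full language names such as "english", matching the full-name handling in the turn detector below):

```python
import asyncio

from openai import AsyncOpenAI

client = AsyncOpenAI()


async def transcribe(path: str) -> None:
    with open(path, "rb") as audio:
        resp = await client.audio.transcriptions.create(
            model="whisper-1",
            file=audio,
            response_format="verbose_json",  # adds language, duration, segments
        )
    # with the default "json" format only resp.text is available;
    # verbose_json also tags the detected language, e.g. "english"
    print(resp.language, resp.text)


asyncio.run(transcribe("sample.wav"))
```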
46 changes: 46 additions & 0 deletions livekit-plugins/livekit-plugins-turn-detector/README.md
@@ -1,2 +1,48 @@
 # LiveKit Plugins Turn Detector
 
+This plugin introduces end-of-turn detection for LiveKit Agents, using a custom open-weight model to determine when a user has finished speaking.
+
+Traditional voice agents use VAD (voice activity detection) for end-of-turn detection. However, VAD models lack language understanding, often causing false positives where the agent interrupts the user before they finish speaking.
+
+By leveraging a language model specifically trained for this task, this plugin offers a more accurate and robust method of end-of-turn detection. The current version supports English only and should not be used when targeting other languages.
+
+## Installation
+
+```bash
+pip install livekit-plugins-turn-detector
+```
+
+## Usage
+
+This plugin is designed to be used with the `VoicePipelineAgent`:
+
+```python
+from livekit.plugins import turn_detector
+
+agent = VoicePipelineAgent(
+    ...
+    turn_detector=turn_detector.EOUModel(),
+)
+```
+
+## Running your agent
+
+This plugin requires model files. Before starting your agent for the first time, or when building Docker images for deployment, run the following command to download the model files:
+
+```bash
+python my_agent.py download-files
+```
+
+## Model system requirements
+
+The end-of-turn model is optimized to run on CPUs with modest system requirements. It is designed to run on the same server hosting your agents. On a 4-core server instance, it completes inference in under 100ms with minimal CPU usage.
+
+The model requires 1.5GB of RAM and runs within a shared inference server, supporting multiple concurrent sessions.
+
+We are working to reduce the CPU and memory requirements in future releases.
+
+## License
+
+The plugin source code is licensed under the Apache-2.0 license.
+
+The end-of-turn model is licensed under the [LiveKit Model License](https://huggingface.co/livekit/turn-detector/blob/main/LICENSE).
livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py
@@ -12,7 +12,7 @@
 
 from .log import logger
 
-HG_MODEL = "livekit/opt-125m-endpoint-detector-2"
+HG_MODEL = "livekit/turn-detector"
 PUNCS = string.punctuation.replace("'", "")
 MAX_HISTORY = 4
 
@@ -113,12 +113,30 @@ def run(self, data: bytes) -> bytes | None:
 
 
 class EOUModel:
-    def __init__(self, inference_executor: InferenceExecutor | None = None) -> None:
+    def __init__(
+        self,
+        inference_executor: InferenceExecutor | None = None,
+        unlikely_threshold: float = 0.15,
+    ) -> None:
         self._executor = (
             inference_executor or get_current_job_context().inference_executor
         )
+        self._unlikely_threshold = unlikely_threshold
+
+    def unlikely_threshold(self) -> float:
+        return self._unlikely_threshold
+
+    def supports_language(self, language: str | None) -> bool:
+        if language is None:
+            return False
+        parts = language.lower().split("-")
+        # certain providers use language codes (Deepgram, AssemblyAI), others use full names (like OpenAI)
+        return parts[0] == "en" or parts[0] == "english"
 
     async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
+        return await self.predict_end_of_turn(chat_ctx)
+
+    async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float:
         messages = []
 
         for msg in chat_ctx.messages:
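
A quick sketch of the new `EOUModel` surface (run inside an agent job, where an inference executor is available from the job context, or pass `inference_executor` explicitly):

```python
from livekit.plugins import turn_detector

# the constructor argument replaces the previously hardcoded
# UNLIKELY_ENDPOINT_THRESHOLD in the pipeline agent
detector = turn_detector.EOUModel(unlikely_threshold=0.1)

print(detector.unlikely_threshold())          # 0.1
print(detector.supports_language("en-US"))    # True: language code
print(detector.supports_language("english"))  # True: full name
print(detector.supports_language("fr-FR"))    # False: English only for now
```

The `predict_eou` wrapper keeps the old method name working while `predict_end_of_turn` becomes the protocol entry point.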
livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py
@@ -1,3 +1,3 @@
 import logging
 
-logger = logging.getLogger("livekit.plugins.eou")
+logger = logging.getLogger("livekit.plugins.turn_detector")
9 changes: 7 additions & 2 deletions livekit-plugins/livekit-plugins-turn-detector/setup.py
@@ -49,8 +49,13 @@
     license="Apache-2.0",
     packages=setuptools.find_namespace_packages(include=["livekit.*"]),
     python_requires=">=3.9.0",
-    install_requires=["livekit-agents>=0.11", "transformers>=4.46", "numpy>=1.26"],
-    package_data={"livekit.plugins.eou": ["py.typed"]},
+    install_requires=[
+        "livekit-agents>=0.11",
+        "transformers>=4.46",
+        "numpy>=1.26",
+        "torch>=2.0",
+    ],
+    package_data={"livekit.plugins.turn_detector": ["py.typed"]},
     project_urls={
         "Documentation": "https://docs.livekit.io",
         "Website": "https://livekit.io/",
4 changes: 4 additions & 0 deletions tests/test_stt.py
@@ -109,6 +109,10 @@ async def _stream_output():
 
         if event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT:
             text += event.alternatives[0].text
+            # ensure STT is tagging languages correctly
+            language = event.alternatives[0].language
+            assert language is not None
+            assert language.lower().startswith("en")
 
         if event.type == agents.stt.SpeechEventType.END_OF_SPEECH:
             recv_start = False
