diff --git a/livekit-rtc/livekit/rtc/__init__.py b/livekit-rtc/livekit/rtc/__init__.py index a13925e6..c04d96a0 100644 --- a/livekit-rtc/livekit/rtc/__init__.py +++ b/livekit-rtc/livekit/rtc/__init__.py @@ -56,7 +56,7 @@ RemoteTrackPublication, TrackPublication, ) -from .transcription import TranscriptionSegment +from .transcription import Transcription, TranscriptionSegment from .version import __version__ from .video_frame import ( VideoFrame, @@ -110,6 +110,7 @@ "LocalTrackPublication", "RemoteTrackPublication", "TrackPublication", + "Transcription", "TranscriptionSegment", "VideoFrame", "VideoSource", diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index 8523edce..7d7351f9 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -34,7 +34,7 @@ RemoteTrackPublication, TrackPublication, ) -from .transcription import TranscriptionSegment +from .transcription import Transcription class PublishTrackError(Exception): @@ -131,13 +131,7 @@ async def publish_data( if cb.publish_data.error: raise PublishDataError(cb.publish_data.error) - async def publish_transcription( - self, - participant_identity: str, - track_id: str, - segments: List[TranscriptionSegment], - language: str, - ) -> None: + async def publish_transcription(self, transcription: Transcription) -> None: req = proto_ffi.FfiRequest() proto_segments = [ ProtoTranscriptionSegment( @@ -147,13 +141,15 @@ async def publish_transcription( end_time=s.end_time, final=s.final, ) - for s in segments + for s in transcription.segments ] + # fmt: off req.publish_transcription.local_participant_handle = self._ffi_handle.handle - req.publish_transcription.participant_identity = participant_identity + req.publish_transcription.participant_identity = transcription.participant_identity req.publish_transcription.segments.extend(proto_segments) - req.publish_transcription.track_id = track_id - req.publish_transcription.language = language + req.publish_transcription.track_id = transcription.track_id + req.publish_transcription.language = transcription.language + # fmt: on queue = FfiClient.instance.queue.subscribe() try: resp = FfiClient.instance.request(req) diff --git a/livekit-rtc/livekit/rtc/transcription.py b/livekit-rtc/livekit/rtc/transcription.py index 76db4a57..3694e0d0 100644 --- a/livekit-rtc/livekit/rtc/transcription.py +++ b/livekit-rtc/livekit/rtc/transcription.py @@ -1,6 +1,15 @@ +from typing import List from dataclasses import dataclass +@dataclass +class Transcription: + participant_identity: str + track_id: str + segments: List["TranscriptionSegment"] + language: str + + @dataclass class TranscriptionSegment: id: str