diff --git a/callback.go b/callback.go index a1ac15f..4e3a515 100644 --- a/callback.go +++ b/callback.go @@ -46,6 +46,7 @@ type ParticipantCallback struct { OnTrackUnpublished func(publication *RemoteTrackPublication, rp *RemoteParticipant) OnDataReceived func(data []byte, params DataReceiveParams) // Deprecated: Use OnDataPacket instead OnDataPacket func(data DataPacket, params DataReceiveParams) + OnTranscriptionReceived func(transcriptionSegments []*TranscriptionSegment, p Participant, publication *RemoteTrackPublication) } func NewParticipantCallback() *ParticipantCallback { @@ -66,6 +67,8 @@ func NewParticipantCallback() *ParticipantCallback { OnTrackUnpublished: func(publication *RemoteTrackPublication, rp *RemoteParticipant) {}, OnDataReceived: func(data []byte, params DataReceiveParams) {}, OnDataPacket: func(data DataPacket, params DataReceiveParams) {}, + OnTranscriptionReceived: func(transcriptionSegments []*TranscriptionSegment, p Participant, publication *RemoteTrackPublication) { + }, } } @@ -115,6 +118,9 @@ func (cb *ParticipantCallback) Merge(other *ParticipantCallback) { if other.OnDataPacket != nil { cb.OnDataPacket = other.OnDataPacket } + if other.OnTranscriptionReceived != nil { + cb.OnTranscriptionReceived = other.OnTranscriptionReceived + } } type DisconnectionReason string diff --git a/engine.go b/engine.go index e8c4014..d9d6cfc 100644 --- a/engine.go +++ b/engine.go @@ -81,6 +81,7 @@ type RTCEngine struct { OnRestarted func(*livekit.JoinResponse) OnResuming func() OnResumed func() + OnTranscription func(*livekit.Transcription) // callbacks to get data CbGetLocalParticipantSID func() string @@ -561,6 +562,10 @@ func (e *RTCEngine) handleDataPacket(msg webrtc.DataChannelMessage) { if e.OnDataPacket != nil { e.OnDataPacket(identity, msg.SipDtmf) } + case *livekit.DataPacket_Transcription: + if e.OnTranscription != nil { + e.OnTranscription(msg.Transcription) + } } } diff --git a/room.go b/room.go index 6e04ced..ea3694c 100644 --- a/room.go +++ b/room.go @@ -199,6 +199,7 @@ func NewRoom(callback *RoomCallback) *Room { engine.OnResumed = r.handleResumed engine.OnLocalTrackUnpublished = r.handleLocalTrackUnpublished engine.OnTrackRemoteMuted = r.handleTrackRemoteMuted + engine.OnTranscription = r.handleTranscriptionReceived // callbacks engine can use to get data engine.CbGetLocalParticipantSID = r.getLocalParticipantSID @@ -715,6 +716,19 @@ func (r *Room) handleLocalTrackUnpublished(msg *livekit.TrackUnpublishedResponse } } +func (r *Room) handleTranscriptionReceived(transcription *livekit.Transcription) { + // find the participant + if transcription.TranscribedParticipantIdentity == r.LocalParticipant.Identity() { + // if sent by itself, do not handle data + return + } + p := r.GetParticipantByIdentity(transcription.TranscribedParticipantIdentity) + publication := p.getPublication(transcription.TrackId) + transcriptionSegments := ExtractTranscriptionSegments(transcription) + + r.callback.OnTranscriptionReceived(transcriptionSegments, p, publication) +} + func (r *Room) sendSyncState() { subscriber, ok := r.engine.Subscriber() if !ok || subscriber.pc.RemoteDescription() == nil { diff --git a/transcription.go b/transcription.go new file mode 100644 index 0000000..a1bf631 --- /dev/null +++ b/transcription.go @@ -0,0 +1,47 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lksdk + +import ( + "github.com/livekit/protocol/livekit" +) + +type TranscriptionSegment struct { + ID string + Text string + Language string + StartTime uint64 + EndTime uint64 + Final bool +} + +func ExtractTranscriptionSegments(transcription *livekit.Transcription) []*TranscriptionSegment { + var segments []*TranscriptionSegment + if transcription == nil { + return segments + } + segments = make([]*TranscriptionSegment, len(transcription.Segments)) + for i := range transcription.Segments { + segments[i] = &TranscriptionSegment{ + ID: transcription.Segments[i].Id, + Text: transcription.Segments[i].Text, + Language: transcription.Segments[i].Language, + StartTime: transcription.Segments[i].StartTime, + EndTime: transcription.Segments[i].EndTime, + Final: transcription.Segments[i].Final, + } + } + return segments +}