diff --git a/cmd/livekit-sip/main.go b/cmd/livekit-sip/main.go index 7855dc1..977de9c 100644 --- a/cmd/livekit-sip/main.go +++ b/cmd/livekit-sip/main.go @@ -91,7 +91,7 @@ func runService(c *cli.Context) error { return err } - sipsrv, err := sip.NewService(conf, mon, log, psrpcClient) + sipsrv, err := sip.NewService(conf, mon, log, func(projectID string) rpc.IOInfoClient { return psrpcClient }) if err != nil { return err } diff --git a/pkg/sip/client.go b/pkg/sip/client.go index 8c616ea..10a44d9 100644 --- a/pkg/sip/client.go +++ b/pkg/sip/client.go @@ -50,11 +50,11 @@ type Client struct { activeCalls map[LocalTag]*outboundCall byRemote map[RemoteTag]*outboundCall - handler Handler - ioClient rpc.IOInfoClient + handler Handler + getIOClient GetIOInfoClient } -func NewClient(conf *config.Config, log logger.Logger, mon *stats.Monitor, ioClient rpc.IOInfoClient) *Client { +func NewClient(conf *config.Config, log logger.Logger, mon *stats.Monitor, getIOClient GetIOInfoClient) *Client { if log == nil { log = logger.GetLogger() } @@ -62,7 +62,7 @@ func NewClient(conf *config.Config, log logger.Logger, mon *stats.Monitor, ioCli conf: conf, log: log, mon: mon, - ioClient: ioClient, + getIOClient: getIOClient, activeCalls: make(map[LocalTag]*outboundCall), byRemote: make(map[RemoteTag]*outboundCall), } @@ -156,6 +156,8 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea "toUser", req.CallTo, ) + ioClient := c.getIOClient(req.ProjectId) + callInfo := c.createSIPCallInfo(req) defer func() { switch retErr { @@ -167,8 +169,8 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea callInfo.Error = retErr.Error() } - if c.ioClient != nil { - c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ + if ioClient != nil { + ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ CallInfo: callInfo, }) } @@ -202,7 +204,7 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea enabledFeatures: req.EnabledFeatures, } log.Infow("Creating SIP participant") - call, err := c.newCall(ctx, c.conf, log, LocalTag(req.SipCallId), roomConf, sipConf, callInfo) + call, err := c.newCall(ctx, c.conf, log, LocalTag(req.SipCallId), roomConf, sipConf, callInfo, ioClient) if err != nil { return nil, err } @@ -217,8 +219,8 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea callInfo.CallStatus = livekit.SIPCallStatus_SCS_DISCONNECTED } - if c.ioClient != nil { - c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ + if ioClient != nil { + ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ CallInfo: callInfo, }) } diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go index 364a7c5..0135d4e 100644 --- a/pkg/sip/inbound.go +++ b/pkg/sip/inbound.go @@ -119,7 +119,7 @@ func (s *Server) handleInviteAuth(log logger.Logger, req *sip.Request, tx sip.Se } func (s *Server) onInvite(req *sip.Request, tx sip.ServerTransaction) { - callInfo, err := s.processInvite(req, tx) + callInfo, ioClient, err := s.processInvite(req, tx) if callInfo != nil { if err != nil { @@ -130,22 +130,22 @@ func (s *Server) onInvite(req *sip.Request, tx sip.ServerTransaction) { } callInfo.EndedAt = time.Now().UnixNano() - if s.ioClient != nil { - s.ioClient.UpdateSIPCallState(context.Background(), &rpc.UpdateSIPCallStateRequest{ + if ioClient != nil { + ioClient.UpdateSIPCallState(context.Background(), &rpc.UpdateSIPCallStateRequest{ CallInfo: callInfo, }) } } } -func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*livekit.SIPCallInfo, error) { +func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*livekit.SIPCallInfo, rpc.IOInfoClient, error) { ctx := context.Background() s.mon.InviteReqRaw(stats.Inbound) src, err := netip.ParseAddrPort(req.Source()) if err != nil { tx.Terminate() s.log.Errorw("cannot parse source IP", err, "fromIP", src) - return nil, psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP")) + return nil, nil, psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP")) } callID := lksip.NewCallID() log := s.log.WithValues( @@ -168,19 +168,13 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv CreatedAt: time.Now().UnixNano(), } - if s.ioClient != nil { - s.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ - CallInfo: callInfo, - }) - } - if err := cc.ValidateInvite(); err != nil { if s.conf.HideInboundPort { cc.Drop() } else { cc.RespondAndDrop(sip.StatusBadRequest, "Bad request") } - return callInfo, psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed")) + return callInfo, nil, psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed")) } ctx, span := tracer.Start(ctx, "Server.onInvite") defer span.End() @@ -201,7 +195,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv cmon.InviteErrorShort("auth-error") log.Warnw("Rejecting inbound, auth check failed", err) cc.RespondAndDrop(sip.StatusServiceUnavailable, "Try again later") - return callInfo, psrpc.NewError(psrpc.PermissionDenied, errors.Wrap(err, "rejecting inbound, auth check failed")) + return callInfo, nil, psrpc.NewError(psrpc.PermissionDenied, errors.Wrap(err, "rejecting inbound, auth check failed")) } if r.ProjectID != "" { log = log.WithValues("projectID", r.ProjectID) @@ -210,17 +204,25 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv log = log.WithValues("sipTrunk", r.TrunkID) callInfo.TrunkId = r.TrunkID } + + ioClient := s.getIOClient(r.ProjectID) + if ioClient != nil { + ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ + CallInfo: callInfo, + }) + } + switch r.Result { case AuthDrop: cmon.InviteErrorShort("flood") log.Debugw("Dropping inbound flood") cc.Drop() - return callInfo, psrpc.NewErrorf(psrpc.PermissionDenied, "call was not authorized by trunk configuration") + return callInfo, ioClient, psrpc.NewErrorf(psrpc.PermissionDenied, "call was not authorized by trunk configuration") case AuthNotFound: cmon.InviteErrorShort("no-rule") log.Warnw("Rejecting inbound, doesn't match any Trunks", nil) cc.RespondAndDrop(sip.StatusNotFound, "Does not match any SIP Trunks") - return callInfo, psrpc.NewErrorf(psrpc.NotFound, "no trunk configuration for call") + return callInfo, ioClient, psrpc.NewErrorf(psrpc.NotFound, "no trunk configuration for call") case AuthPassword: if s.conf.HideInboundPort { // We will send password request anyway, so might as well signal that the progress is made. @@ -229,21 +231,21 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv if !s.handleInviteAuth(log, req, tx, from.User, r.Username, r.Password) { cmon.InviteErrorShort("unauthorized") // handleInviteAuth will generate the SIP Response as needed - return callInfo, psrpc.NewErrorf(psrpc.PermissionDenied, "invalid crendentials were provided") + return callInfo, ioClient, psrpc.NewErrorf(psrpc.PermissionDenied, "invalid crendentials were provided") } fallthrough case AuthAccept: // ok } - call := s.newInboundCall(log, cmon, cc, src, callInfo, nil) + call := s.newInboundCall(log, cmon, cc, src, callInfo, ioClient, nil) call.joinDur = joinDur err = call.handleInvite(call.ctx, req, r.TrunkID, s.conf) if err != nil { - return callInfo, err + return callInfo, ioClient, err } - return callInfo, nil + return callInfo, ioClient, nil } func (s *Server) onBye(req *sip.Request, tx sip.ServerTransaction) { @@ -307,6 +309,7 @@ type inboundCall struct { log logger.Logger cc *sipInbound mon *stats.CallMonitor + ioClient rpc.IOInfoClient extraAttrs map[string]string ctx context.Context cancel func() @@ -328,6 +331,7 @@ func (s *Server) newInboundCall( cc *sipInbound, src netip.AddrPort, callInfo *livekit.SIPCallInfo, + ioClient rpc.IOInfoClient, extra map[string]string, ) *inboundCall { @@ -339,6 +343,7 @@ func (s *Server) newInboundCall( cc: cc, src: src, callInfo: callInfo, + ioClient: ioClient, extraAttrs: extra, dtmf: make(chan dtmf.Event, 10), lkRoom: NewRoom(log), // we need it created earlier so that the audio mixer is available for pin prompts @@ -487,8 +492,8 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI c.started.Break() - if c.s.ioClient != nil { - c.s.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ + if c.ioClient != nil { + c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ CallInfo: c.callInfo, }) } diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go index 84b05a9..2de7218 100644 --- a/pkg/sip/outbound.go +++ b/pkg/sip/outbound.go @@ -62,6 +62,7 @@ type sipOutboundConfig struct { type outboundCall struct { c *Client log logger.Logger + ioClient rpc.IOInfoClient cc *sipOutbound media *MediaPort callInfo *livekit.SIPCallInfo @@ -76,7 +77,7 @@ type outboundCall struct { sipConf sipOutboundConfig } -func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Logger, id LocalTag, room RoomConfig, sipConf sipOutboundConfig, callInfo *livekit.SIPCallInfo) (*outboundCall, error) { +func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Logger, id LocalTag, room RoomConfig, sipConf sipOutboundConfig, callInfo *livekit.SIPCallInfo, ioClient rpc.IOInfoClient) (*outboundCall, error) { if sipConf.maxCallDuration <= 0 || sipConf.maxCallDuration > maxCallDuration { sipConf.maxCallDuration = maxCallDuration } @@ -100,6 +101,7 @@ func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Lo }, contact), sipConf: sipConf, callInfo: callInfo, + ioClient: ioClient, } call.mon = c.mon.NewCall(stats.Outbound, sipConf.host, sipConf.address) @@ -142,8 +144,8 @@ func (c *outboundCall) Start(ctx context.Context) { c.callInfo.StartedAt = time.Now().UnixNano() c.callInfo.CallStatus = livekit.SIPCallStatus_SCS_ACTIVE - if c.c.ioClient != nil { - c.c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ + if c.ioClient != nil { + c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{ CallInfo: c.callInfo, }) } diff --git a/pkg/sip/server.go b/pkg/sip/server.go index cd27cc5..cbb83a9 100644 --- a/pkg/sip/server.go +++ b/pkg/sip/server.go @@ -113,7 +113,7 @@ type Server struct { log logger.Logger mon *stats.Monitor sipSrv *sipgo.Server - ioClient rpc.IOInfoClient + getIOClient GetIOInfoClient sipListeners []io.Closer sipUnhandled RequestHandler @@ -136,7 +136,7 @@ type inProgressInvite struct { challenge digest.Challenge } -func NewServer(conf *config.Config, log logger.Logger, mon *stats.Monitor, ioClient rpc.IOInfoClient) *Server { +func NewServer(conf *config.Config, log logger.Logger, mon *stats.Monitor, getIOClient GetIOInfoClient) *Server { if log == nil { log = logger.GetLogger() } @@ -144,7 +144,7 @@ func NewServer(conf *config.Config, log logger.Logger, mon *stats.Monitor, ioCli log: log, conf: conf, mon: mon, - ioClient: ioClient, + getIOClient: getIOClient, activeCalls: make(map[RemoteTag]*inboundCall), byLocal: make(map[LocalTag]*inboundCall), } diff --git a/pkg/sip/service.go b/pkg/sip/service.go index 8cfc5e2..8d2ee77 100644 --- a/pkg/sip/service.go +++ b/pkg/sip/service.go @@ -58,7 +58,9 @@ type transferKey struct { TransferTo string } -func NewService(conf *config.Config, mon *stats.Monitor, log logger.Logger, ioClient rpc.IOInfoClient) (*Service, error) { +type GetIOInfoClient func(projectID string) rpc.IOInfoClient + +func NewService(conf *config.Config, mon *stats.Monitor, log logger.Logger, getIOClient GetIOInfoClient) (*Service, error) { if log == nil { log = logger.GetLogger() } @@ -66,8 +68,8 @@ func NewService(conf *config.Config, mon *stats.Monitor, log logger.Logger, ioCl conf: conf, log: log, mon: mon, - cli: NewClient(conf, log, mon, ioClient), - srv: NewServer(conf, log, mon, ioClient), + cli: NewClient(conf, log, mon, getIOClient), + srv: NewServer(conf, log, mon, getIOClient), pendingTransfers: make(map[transferKey]chan struct{}), } var err error diff --git a/pkg/sip/service_test.go b/pkg/sip/service_test.go index c143567..abfe25b 100644 --- a/pkg/sip/service_test.go +++ b/pkg/sip/service_test.go @@ -95,7 +95,7 @@ func testInvite(t *testing.T, h Handler, hidden bool, from, to string, test func SIPPort: sipPort, SIPPortListen: sipPort, RTPPort: rtcconfig.PortRange{Start: testPortRTPMin, End: testPortRTPMax}, - }, mon, logger.GetLogger(), nil) + }, mon, logger.GetLogger(), func(projectID string) rpc.IOInfoClient { return nil }) require.NoError(t, err) require.NotNil(t, s) t.Cleanup(s.Stop) diff --git a/test/integration/sip_test.go b/test/integration/sip_test.go index 9f4e71f..5a0b3c0 100644 --- a/test/integration/sip_test.go +++ b/test/integration/sip_test.go @@ -79,7 +79,7 @@ func runSIPServer(t testing.TB, lk *LiveKit) *SIPServer { if err != nil { t.Fatal(err) } - sipsrv, err := sip.NewService(conf, mon, log, psrpcCli) + sipsrv, err := sip.NewService(conf, mon, log, func(projectID string) rpc.IOInfoClient { return psrpcCli }) if err != nil { t.Fatal(err) }