Skip to content

Commit

Permalink
Use a IOInfoClient factory
Browse files Browse the repository at this point in the history
  • Loading branch information
biglittlebigben committed Oct 29, 2024
1 parent 1a2d250 commit 3826ec1
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cmd/livekit-sip/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func runService(c *cli.Context) error {
return err
}

sipsrv, err := sip.NewService(conf, mon, log, psrpcClient)
sipsrv, err := sip.NewService(conf, mon, log, func(projectID string) rpc.IOInfoClient { return psrpcClient })
if err != nil {
return err
}
Expand Down
20 changes: 11 additions & 9 deletions pkg/sip/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ type Client struct {
activeCalls map[LocalTag]*outboundCall
byRemote map[RemoteTag]*outboundCall

handler Handler
ioClient rpc.IOInfoClient
handler Handler
getIOClient GetIOInfoClient
}

func NewClient(conf *config.Config, log logger.Logger, mon *stats.Monitor, ioClient rpc.IOInfoClient) *Client {
func NewClient(conf *config.Config, log logger.Logger, mon *stats.Monitor, getIOClient GetIOInfoClient) *Client {
if log == nil {
log = logger.GetLogger()
}
c := &Client{
conf: conf,
log: log,
mon: mon,
ioClient: ioClient,
getIOClient: getIOClient,
activeCalls: make(map[LocalTag]*outboundCall),
byRemote: make(map[RemoteTag]*outboundCall),
}
Expand Down Expand Up @@ -156,6 +156,8 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea
"toUser", req.CallTo,
)

ioClient := c.getIOClient(req.ProjectId)

callInfo := c.createSIPCallInfo(req)
defer func() {
switch retErr {
Expand All @@ -167,8 +169,8 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea
callInfo.Error = retErr.Error()
}

if c.ioClient != nil {
c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
if ioClient != nil {
ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
CallInfo: callInfo,
})
}
Expand Down Expand Up @@ -202,7 +204,7 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea
enabledFeatures: req.EnabledFeatures,
}
log.Infow("Creating SIP participant")
call, err := c.newCall(ctx, c.conf, log, LocalTag(req.SipCallId), roomConf, sipConf, callInfo)
call, err := c.newCall(ctx, c.conf, log, LocalTag(req.SipCallId), roomConf, sipConf, callInfo, ioClient)
if err != nil {
return nil, err
}
Expand All @@ -217,8 +219,8 @@ func (c *Client) createSIPParticipant(ctx context.Context, req *rpc.InternalCrea
callInfo.CallStatus = livekit.SIPCallStatus_SCS_DISCONNECTED
}

if c.ioClient != nil {
c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
if ioClient != nil {
ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
CallInfo: callInfo,
})
}
Expand Down
47 changes: 26 additions & 21 deletions pkg/sip/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (s *Server) handleInviteAuth(log logger.Logger, req *sip.Request, tx sip.Se
}

func (s *Server) onInvite(req *sip.Request, tx sip.ServerTransaction) {
callInfo, err := s.processInvite(req, tx)
callInfo, ioClient, err := s.processInvite(req, tx)

if callInfo != nil {
if err != nil {
Expand All @@ -130,22 +130,22 @@ func (s *Server) onInvite(req *sip.Request, tx sip.ServerTransaction) {
}
callInfo.EndedAt = time.Now().UnixNano()

if s.ioClient != nil {
s.ioClient.UpdateSIPCallState(context.Background(), &rpc.UpdateSIPCallStateRequest{
if ioClient != nil {
ioClient.UpdateSIPCallState(context.Background(), &rpc.UpdateSIPCallStateRequest{
CallInfo: callInfo,
})
}
}
}

func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*livekit.SIPCallInfo, error) {
func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*livekit.SIPCallInfo, rpc.IOInfoClient, error) {
ctx := context.Background()
s.mon.InviteReqRaw(stats.Inbound)
src, err := netip.ParseAddrPort(req.Source())
if err != nil {
tx.Terminate()
s.log.Errorw("cannot parse source IP", err, "fromIP", src)
return nil, psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
return nil, nil, psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
}
callID := lksip.NewCallID()
log := s.log.WithValues(
Expand All @@ -168,19 +168,13 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv
CreatedAt: time.Now().UnixNano(),
}

if s.ioClient != nil {
s.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
CallInfo: callInfo,
})
}

if err := cc.ValidateInvite(); err != nil {
if s.conf.HideInboundPort {
cc.Drop()
} else {
cc.RespondAndDrop(sip.StatusBadRequest, "Bad request")
}
return callInfo, psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed"))
return callInfo, nil, psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed"))
}
ctx, span := tracer.Start(ctx, "Server.onInvite")
defer span.End()
Expand All @@ -201,7 +195,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv
cmon.InviteErrorShort("auth-error")
log.Warnw("Rejecting inbound, auth check failed", err)
cc.RespondAndDrop(sip.StatusServiceUnavailable, "Try again later")
return callInfo, psrpc.NewError(psrpc.PermissionDenied, errors.Wrap(err, "rejecting inbound, auth check failed"))
return callInfo, nil, psrpc.NewError(psrpc.PermissionDenied, errors.Wrap(err, "rejecting inbound, auth check failed"))
}
if r.ProjectID != "" {
log = log.WithValues("projectID", r.ProjectID)
Expand All @@ -210,17 +204,25 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv
log = log.WithValues("sipTrunk", r.TrunkID)
callInfo.TrunkId = r.TrunkID
}

ioClient := s.getIOClient(r.ProjectID)
if ioClient != nil {
ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
CallInfo: callInfo,
})
}

switch r.Result {
case AuthDrop:
cmon.InviteErrorShort("flood")
log.Debugw("Dropping inbound flood")
cc.Drop()
return callInfo, psrpc.NewErrorf(psrpc.PermissionDenied, "call was not authorized by trunk configuration")
return callInfo, ioClient, psrpc.NewErrorf(psrpc.PermissionDenied, "call was not authorized by trunk configuration")
case AuthNotFound:
cmon.InviteErrorShort("no-rule")
log.Warnw("Rejecting inbound, doesn't match any Trunks", nil)
cc.RespondAndDrop(sip.StatusNotFound, "Does not match any SIP Trunks")
return callInfo, psrpc.NewErrorf(psrpc.NotFound, "no trunk configuration for call")
return callInfo, ioClient, psrpc.NewErrorf(psrpc.NotFound, "no trunk configuration for call")
case AuthPassword:
if s.conf.HideInboundPort {
// We will send password request anyway, so might as well signal that the progress is made.
Expand All @@ -229,21 +231,21 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (*liv
if !s.handleInviteAuth(log, req, tx, from.User, r.Username, r.Password) {
cmon.InviteErrorShort("unauthorized")
// handleInviteAuth will generate the SIP Response as needed
return callInfo, psrpc.NewErrorf(psrpc.PermissionDenied, "invalid crendentials were provided")
return callInfo, ioClient, psrpc.NewErrorf(psrpc.PermissionDenied, "invalid crendentials were provided")
}
fallthrough
case AuthAccept:
// ok
}

call := s.newInboundCall(log, cmon, cc, src, callInfo, nil)
call := s.newInboundCall(log, cmon, cc, src, callInfo, ioClient, nil)
call.joinDur = joinDur
err = call.handleInvite(call.ctx, req, r.TrunkID, s.conf)
if err != nil {
return callInfo, err
return callInfo, ioClient, err
}

return callInfo, nil
return callInfo, ioClient, nil
}

func (s *Server) onBye(req *sip.Request, tx sip.ServerTransaction) {
Expand Down Expand Up @@ -307,6 +309,7 @@ type inboundCall struct {
log logger.Logger
cc *sipInbound
mon *stats.CallMonitor
ioClient rpc.IOInfoClient
extraAttrs map[string]string
ctx context.Context
cancel func()
Expand All @@ -328,6 +331,7 @@ func (s *Server) newInboundCall(
cc *sipInbound,
src netip.AddrPort,
callInfo *livekit.SIPCallInfo,
ioClient rpc.IOInfoClient,
extra map[string]string,
) *inboundCall {

Expand All @@ -339,6 +343,7 @@ func (s *Server) newInboundCall(
cc: cc,
src: src,
callInfo: callInfo,
ioClient: ioClient,
extraAttrs: extra,
dtmf: make(chan dtmf.Event, 10),
lkRoom: NewRoom(log), // we need it created earlier so that the audio mixer is available for pin prompts
Expand Down Expand Up @@ -487,8 +492,8 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI

c.started.Break()

if c.s.ioClient != nil {
c.s.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
if c.ioClient != nil {
c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
CallInfo: c.callInfo,
})
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/sip/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type sipOutboundConfig struct {
type outboundCall struct {
c *Client
log logger.Logger
ioClient rpc.IOInfoClient
cc *sipOutbound
media *MediaPort
callInfo *livekit.SIPCallInfo
Expand All @@ -76,7 +77,7 @@ type outboundCall struct {
sipConf sipOutboundConfig
}

func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Logger, id LocalTag, room RoomConfig, sipConf sipOutboundConfig, callInfo *livekit.SIPCallInfo) (*outboundCall, error) {
func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Logger, id LocalTag, room RoomConfig, sipConf sipOutboundConfig, callInfo *livekit.SIPCallInfo, ioClient rpc.IOInfoClient) (*outboundCall, error) {
if sipConf.maxCallDuration <= 0 || sipConf.maxCallDuration > maxCallDuration {
sipConf.maxCallDuration = maxCallDuration
}
Expand All @@ -100,6 +101,7 @@ func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Lo
}, contact),
sipConf: sipConf,
callInfo: callInfo,
ioClient: ioClient,
}

call.mon = c.mon.NewCall(stats.Outbound, sipConf.host, sipConf.address)
Expand Down Expand Up @@ -142,8 +144,8 @@ func (c *outboundCall) Start(ctx context.Context) {
c.callInfo.StartedAt = time.Now().UnixNano()
c.callInfo.CallStatus = livekit.SIPCallStatus_SCS_ACTIVE

if c.c.ioClient != nil {
c.c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
if c.ioClient != nil {
c.ioClient.UpdateSIPCallState(context.WithoutCancel(ctx), &rpc.UpdateSIPCallStateRequest{
CallInfo: c.callInfo,
})
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/sip/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type Server struct {
log logger.Logger
mon *stats.Monitor
sipSrv *sipgo.Server
ioClient rpc.IOInfoClient
getIOClient GetIOInfoClient
sipListeners []io.Closer
sipUnhandled RequestHandler

Expand All @@ -136,15 +136,15 @@ type inProgressInvite struct {
challenge digest.Challenge
}

func NewServer(conf *config.Config, log logger.Logger, mon *stats.Monitor, ioClient rpc.IOInfoClient) *Server {
func NewServer(conf *config.Config, log logger.Logger, mon *stats.Monitor, getIOClient GetIOInfoClient) *Server {
if log == nil {
log = logger.GetLogger()
}
s := &Server{
log: log,
conf: conf,
mon: mon,
ioClient: ioClient,
getIOClient: getIOClient,
activeCalls: make(map[RemoteTag]*inboundCall),
byLocal: make(map[LocalTag]*inboundCall),
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/sip/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,18 @@ type transferKey struct {
TransferTo string
}

func NewService(conf *config.Config, mon *stats.Monitor, log logger.Logger, ioClient rpc.IOInfoClient) (*Service, error) {
type GetIOInfoClient func(projectID string) rpc.IOInfoClient

func NewService(conf *config.Config, mon *stats.Monitor, log logger.Logger, getIOClient GetIOInfoClient) (*Service, error) {
if log == nil {
log = logger.GetLogger()
}
s := &Service{
conf: conf,
log: log,
mon: mon,
cli: NewClient(conf, log, mon, ioClient),
srv: NewServer(conf, log, mon, ioClient),
cli: NewClient(conf, log, mon, getIOClient),
srv: NewServer(conf, log, mon, getIOClient),
pendingTransfers: make(map[transferKey]chan struct{}),
}
var err error
Expand Down
2 changes: 1 addition & 1 deletion pkg/sip/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func testInvite(t *testing.T, h Handler, hidden bool, from, to string, test func
SIPPort: sipPort,
SIPPortListen: sipPort,
RTPPort: rtcconfig.PortRange{Start: testPortRTPMin, End: testPortRTPMax},
}, mon, logger.GetLogger(), nil)
}, mon, logger.GetLogger(), func(projectID string) rpc.IOInfoClient { return nil })
require.NoError(t, err)
require.NotNil(t, s)
t.Cleanup(s.Stop)
Expand Down
2 changes: 1 addition & 1 deletion test/integration/sip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func runSIPServer(t testing.TB, lk *LiveKit) *SIPServer {
if err != nil {
t.Fatal(err)
}
sipsrv, err := sip.NewService(conf, mon, log, psrpcCli)
sipsrv, err := sip.NewService(conf, mon, log, func(projectID string) rpc.IOInfoClient { return psrpcCli })
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 3826ec1

Please sign in to comment.