Skip to content

Commit

Permalink
Merge branch 'master' into auth_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
iSchluff committed Dec 26, 2024
2 parents 954c410 + 8399e70 commit dbe65e5
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 64 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
srtrelay
config.toml
/.vscode
# VSCode
/.vscode
# Jetbrains
/.idea
*.iml
3 changes: 2 additions & 1 deletion api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
"errors"
"log"
"net/http"
"sync"
Expand Down Expand Up @@ -49,7 +50,7 @@ func (s *Server) Listen(ctx context.Context) error {
go func() {
defer s.done.Done()
err := serv.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Println(err)
}
}()
Expand Down
25 changes: 12 additions & 13 deletions auth/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)

var (
requestDurations = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: metrics.Namespace,
Subsystem: "auth",
Name: "request_duration_seconds",
Help: "A histogram of auth http request latencies.",
Buckets: prometheus.DefBuckets,
NativeHistogramBucketFactor: 1.1,
},
[]string{"url", "application"},
)
var requestDurations = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: metrics.Namespace,
Subsystem: "auth",
Name: "request_duration_seconds",
Help: "A histogram of auth http request latencies.",
Buckets: prometheus.DefBuckets,
NativeHistogramBucketFactor: 1.1,
},
[]string{"url", "application"},
)

type httpAuth struct {
Expand All @@ -52,7 +50,7 @@ type HTTPAuthConfig struct {
}

// NewHttpAuth creates an Authenticator with a HTTP backend
func NewHTTPAuth(authConfig HTTPAuthConfig) *httpAuth {
func NewHTTPAuth(authConfig HTTPAuthConfig) Authenticator {
m := requestDurations.MustCurryWith(prometheus.Labels{"url": authConfig.URL, "application": authConfig.Application})
return &httpAuth{
config: authConfig,
Expand All @@ -74,6 +72,7 @@ func (h *httpAuth) Authenticate(streamid stream.StreamID) bool {
"call": {streamid.Mode().String()},
"app": {h.config.Application},
"name": {streamid.Name()},
"username": {streamid.Username()},
h.config.PasswordParam: {streamid.Password()},
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions srt/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ func (s *ServerImpl) registerForStats(ctx context.Context, conn *srtConn) {

func (s *ServerImpl) GetStatistics() []*relay.StreamStatistics {
streams := s.relay.GetStatistics()
for _, stream := range streams {
stream.URL = fmt.Sprintf("srt://%s?streamid=play/%s", s.config.PublicAddress, stream.Name)
for _, st := range streams {
st.URL = fmt.Sprintf("srt://%s?streamid=#!::m=request,r=%s", s.config.PublicAddress, st.Name) // New format
}
return streams
}
Expand Down
6 changes: 3 additions & 3 deletions srt/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestServerImpl_GetStatistics(t *testing.T) {
streams := s.GetStatistics()

expected := []*relay.StreamStatistics{
{Name: "s1", URL: "srt://testserver.de:1337?streamid=play/s1", Clients: 2, Created: streams[0].Created},
{Name: "s1", URL: "srt://testserver.de:1337?streamid=#!::m=request,r=s1", Clients: 2, Created: streams[0].Created}, // New Format
}
if err := compareStats(streams, expected); err != nil {
t.Error(err)
Expand All @@ -65,8 +65,8 @@ func (s *testSocket) Read(b []byte) (int, error) {
if !ok {
return 0, io.EOF
}
len := copy(b, buf)
return len, nil
length := copy(b, buf)
return length, nil
}

func (s *testSocket) Write(b []byte) (int, error) {
Expand Down
96 changes: 71 additions & 25 deletions stream/streamid.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ import (
"github.com/IGLOU-EU/go-wildcard/v2"
)

const IDPrefix = "#!::"

var (
InvalidSlashes = errors.New("Invalid number of slashes, must be 1 or 2")
InvalidMode = errors.New("Invalid mode")
MissingName = errors.New("Missing name after slash")
InvalidNamePassword = errors.New("Name/Password is not allowed to contain slashes")
InvalidValue = fmt.Errorf("Invalid value")
)

// Mode - client mode
Expand Down Expand Up @@ -41,9 +44,10 @@ type StreamID struct {
mode Mode
name string
password string
username string
}

// Creates new StreamID
// NewStreamID creates new StreamID
// returns error if mode is invalid.
// id is nil on error
func NewStreamID(name string, password string, mode Mode) (*StreamID, error) {
Expand All @@ -61,39 +65,77 @@ func NewStreamID(name string, password string, mode Mode) (*StreamID, error) {
}

// FromString reads a streamid from a string.
// The accepted stream id format is <mode>/<password>/<password>.
// The second slash and password is optional and defaults to empty.
// The accepted old stream id format is <mode>/<password>/<password>. The second slash and password is
// optional and defaults to empty. The new format is `#!::m=(request|publish),r=(stream-key),u=(username),s=(password)`
// If error is not nil then StreamID will remain unchanged.
func (s *StreamID) FromString(src string) error {
split := strings.Split(src, "/")

password := ""
if len(split) == 3 {
password = split[2]
} else if len(split) != 2 {
return InvalidSlashes
}
modeStr := split[0]
name := split[1]
if strings.HasPrefix(src, IDPrefix) {
for _, kv := range strings.Split(src[len(IDPrefix):], ",") {
kv2 := strings.SplitN(kv, "=", 2)
if len(kv2) != 2 {
return InvalidValue
}

if len(name) == 0 {
return MissingName
key, value := kv2[0], kv2[1]

switch key {
case "u":
s.username = value

case "r":
s.name = value

case "h":

case "s":
s.password = value

case "t":

case "m":
switch value {
case "request":
s.mode = ModePlay

case "publish":
s.mode = ModePublish

default:
return InvalidMode
}

default:
return fmt.Errorf("unsupported key '%s'", key)
}
}
} else {
split := strings.Split(src, "/")

s.password = ""
if len(split) == 3 {
s.password = split[2]
} else if len(split) != 2 {
return InvalidSlashes
}
modeStr := split[0]
s.name = split[1]

switch modeStr {
case "play":
s.mode = ModePlay
case "publish":
s.mode = ModePublish
default:
return InvalidMode
}
}

var mode Mode
switch modeStr {
case "play":
mode = ModePlay
case "publish":
mode = ModePublish
default:
return InvalidMode
if len(s.name) == 0 {
return MissingName
}

s.str = src
s.mode = mode
s.name = name
s.password = password
return nil
}

Expand Down Expand Up @@ -140,3 +182,7 @@ func (s StreamID) Name() string {
func (s StreamID) Password() string {
return s.password
}

func (s StreamID) Username() string {
return s.username
}
77 changes: 58 additions & 19 deletions stream/streamid_test.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,72 @@
package stream

import (
"errors"
"fmt"
"testing"
)

func TestParseStreamID(t *testing.T) {
tests := []struct {
name string
streamID string
wantMode Mode
wantName string
wantPass string
wantErr error
name string
streamID string
wantMode Mode
wantName string
wantPass string
wantUsername string
wantErr error
}{
{"MissingSlash", "s1", 0, "", "", InvalidSlashes},
{"InvalidName", "play//s1", 0, "", "", MissingName},
{"InvalidMode", "foobar/bla", 0, "", "", InvalidMode},
{"InvalidSlash", "foobar/bla//", 0, "", "", InvalidSlashes},
{"EmptyPass", "play/s1/", ModePlay, "s1", "", nil},
{"ValidPass", "play/s1/#![äöü", ModePlay, "s1", "#![äöü", nil},
{"ValidPlay", "play/s1", ModePlay, "s1", "", nil},
{"ValidPublish", "publish/abcdef", ModePublish, "abcdef", "", nil},
{"ValidPlaySpace", "play/bla fasel", ModePlay, "bla fasel", "", nil},
// Old school
{"MissingSlash", "s1", 0, "", "", "", InvalidSlashes},
{"InvalidName", "play//s1", 0, "", "", "", MissingName},
{"InvalidMode", "foobar/bla", 0, "", "", "", InvalidMode},
{"InvalidSlash", "foobar/bla//", 0, "", "", "", InvalidSlashes},
{"EmptyPass", "play/s1/", ModePlay, "s1", "", "", nil},
{"ValidPass", "play/s1/#![äöü", ModePlay, "s1", "#![äöü", "", nil},
{"ValidPlay", "play/s1", ModePlay, "s1", "", "", nil},
{"ValidPublish", "publish/abcdef", ModePublish, "abcdef", "", "", nil},
{"ValidPlaySpace", "play/bla fasel", ModePlay, "bla fasel", "", "", nil},
// New hotness - Bad
{"NewInvalidPubEmptyName", "#!::m=publish", ModePublish, "", "", "", MissingName},
{"NewInvalidPlayEmptyName", "#!::m=request", ModePlay, "", "", "", MissingName},
{"NewInvalidPubBadKey", "#!::m=publish,y=bar", ModePublish, "", "", "", fmt.Errorf("unsupported key '%s'", "y")},
{"NewInvalidPlayBadKey", "#!::m=request,x=foo", ModePlay, "", "", "", fmt.Errorf("unsupported key '%s'", "x")},
{"NewInvalidPubNoEquals", "#!::m=publish,r", ModePublish, "abc", "", "", InvalidValue},
{"NewInvalidPlayNoEquals", "#!::m=request,r", ModePlay, "abc", "", "", InvalidValue},
{"NewInvalidPubNoValue", "#!::m=publish,r=", ModePublish, "abc", "", "", MissingName},
{"NewInvalidPlayNoValue", "#!::m=request,s=", ModePlay, "abc", "", "", MissingName},
{"NewInvalidPubBadKey", "#!::m=publish,x=", ModePublish, "abc", "", "", fmt.Errorf("unsupported key '%s'", "x")},
{"NewInvalidPlayBadKey", "#!::m=request,y=", ModePlay, "abc", "", "", fmt.Errorf("unsupported key '%s'", "y")},
// New hotness - Standard
{"NewValidNameRequest", "#!::m=publish,r=abc", ModePublish, "abc", "", "", nil},
{"NewValidPlay", "#!::m=request,r=abc", ModePlay, "abc", "", "", nil},
{"NewValidNameRequestRev", "#!::r=abc,m=publish", ModePublish, "abc", "", "", nil},
{"NewValidPlayRev", "#!::r=abc,m=request", ModePlay, "abc", "", "", nil},
{"NewValidPassPub", "#!::m=publish,r=abc,s=bob", ModePublish, "abc", "bob", "", nil},
{"NewValidPassPlay", "#!::m=request,r=abc,s=alice", ModePlay, "abc", "alice", "", nil},
{"NewValidPassPubOrder", "#!::s=bob,m=publish,r=abc123", ModePublish, "abc123", "bob", "", nil},
{"NewValidPassPlayOrder", "#!::m=request,s=alice,r=def", ModePlay, "def", "alice", "", nil},
{"NewValidPubUsername", "#!::s=bob,m=publish,r=abc123,u=eve", ModePublish, "abc123", "bob", "eve", nil},
{"NewValidPlayUsername", "#!::m=request,s=alice,r=def,u=bar", ModePlay, "def", "alice", "bar", nil},
{"NewValidPubUsernameOrder", "#!::s=bob,m=publish,u=eve,r=abc123", ModePublish, "abc123", "bob", "eve", nil},
{"NewValidPlayUsernameOrder", "#!::m=request,u=bar,s=alice,r=def", ModePlay, "def", "alice", "bar", nil},
// New Hotness - Unicode
{"NewValidUnicodePub", "#!::m=publish,r=#![äöü,s=bob", ModePublish, "#![äöü", "bob", "", nil},
{"NewValidUnicodePlay", "#!::m=request,r=#![äöü,s=alice", ModePlay, "#![äöü", "alice", "", nil},
{"NewValidUnicodePassPub", "#!::m=publish,s=#![äöü,r=bob", ModePublish, "bob", "#![äöü", "", nil},
{"NewValidUnicodePassPlay", "#!::m=request,s=#![äöü,r=alice", ModePlay, "alice", "#![äöü", "", nil},
{"NewValidUnicodeUserPub", "#!::s=bye,m=publish,u=#![äöü,r=art", ModePublish, "art", "bye", "#![äöü", nil},
{"NewValidUnicodeUserPlay", "#!::m=request,u=#![äöü,r=eve,s=hai", ModePlay, "eve", "hai", "#![äöü", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var streamid StreamID
err := streamid.FromString(tt.streamID)
if err != tt.wantErr {
t.Errorf("ParseStreamID() error = %v, wantErr %v", err, tt.wantErr)
}

if err != nil {
if err.Error() != tt.wantErr.Error() { // Only really care about str value for this, otherwise: if !errors.Is(err, tt.wantErr) {
t.Errorf("ParseStreamID() error = %v, wantErr %v", err, tt.wantErr)
}
if streamid.String() != "" {
t.Error("str should be empty on failed parse")
}
Expand All @@ -48,6 +84,9 @@ func TestParseStreamID(t *testing.T) {
if str := streamid.String(); str != tt.streamID {
t.Errorf("String() got String = %v, want %v", str, tt.streamID)
}
if str := streamid.Username(); str != tt.wantUsername {
t.Errorf("Username() got String = %v, want %v", str, tt.wantUsername)
}
})
}
}
Expand All @@ -72,7 +111,7 @@ func TestNewStreamID(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, err := NewStreamID(tt.argName, tt.argPassword, tt.argMode)
if err != tt.wantErr {
if !errors.Is(err, tt.wantErr) {
t.Errorf("ParseStreamID() error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil {
Expand Down

0 comments on commit dbe65e5

Please sign in to comment.