diff --git a/allowed_ips_test.go b/allowed_ips_test.go index 55884c33..da4f49b7 100644 --- a/allowed_ips_test.go +++ b/allowed_ips_test.go @@ -24,19 +24,17 @@ package signaling import ( "net" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAllowedIps(t *testing.T) { + require := require.New(t) a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24") - if err != nil { - t.Fatal(err) - } - if a.Empty() { - t.Fatal("should not be empty") - } - if expected := `[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`; a.String() != expected { - t.Errorf("expected %s, got %s", expected, a.String()) - } + require.NoError(err) + require.False(a.Empty()) + require.Equal(`[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`, a.String()) allowed := []string{ "127.0.0.1", @@ -51,22 +49,18 @@ func TestAllowedIps(t *testing.T) { for _, addr := range allowed { t.Run(addr, func(t *testing.T) { - ip := net.ParseIP(addr) - if ip == nil { - t.Errorf("error parsing %s", addr) - } else if !a.Allowed(ip) { - t.Errorf("should allow %s", addr) + assert := assert.New(t) + if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) { + assert.True(a.Allowed(ip), "should allow %s", addr) } }) } for _, addr := range notAllowed { t.Run(addr, func(t *testing.T) { - ip := net.ParseIP(addr) - if ip == nil { - t.Errorf("error parsing %s", addr) - } else if a.Allowed(ip) { - t.Errorf("should not allow %s", addr) + assert := assert.New(t) + if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) { + assert.False(a.Allowed(ip), "should not allow %s", addr) } }) } diff --git a/api_backend_test.go b/api_backend_test.go index bcb32139..724075d9 100644 --- a/api_backend_test.go +++ b/api_backend_test.go @@ -24,42 +24,36 @@ package signaling import ( "net/http" "testing" + + "github.com/stretchr/testify/assert" ) func TestBackendChecksum(t *testing.T) { t.Parallel() + assert := assert.New(t) rnd := newRandomString(32) body := []byte{1, 2, 3, 4, 5} secret := []byte("shared-secret") check1 := CalculateBackendChecksum(rnd, body, secret) check2 := CalculateBackendChecksum(rnd, body, secret) - if check1 != check2 { - t.Errorf("Expected equal checksums, got %s and %s", check1, check2) - } + assert.Equal(check1, check2, "Expected equal checksums") - if !ValidateBackendChecksumValue(check1, rnd, body, secret) { - t.Errorf("Checksum %s could not be validated", check1) - } - if ValidateBackendChecksumValue(check1[1:], rnd, body, secret) { - t.Errorf("Checksum %s should not be valid", check1[1:]) - } - if ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret) { - t.Errorf("Checksum %s should not be valid", check1[:len(check1)-1]) - } + assert.True(ValidateBackendChecksumValue(check1, rnd, body, secret), "Checksum should be valid") + assert.False(ValidateBackendChecksumValue(check1[1:], rnd, body, secret), "Checksum should not be valid") + assert.False(ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret), "Checksum should not be valid") request := &http.Request{ Header: make(http.Header), } request.Header.Set("Spreed-Signaling-Random", rnd) request.Header.Set("Spreed-Signaling-Checksum", check1) - if !ValidateBackendChecksum(request, body, secret) { - t.Errorf("Checksum %s could not be validated from request", check1) - } + assert.True(ValidateBackendChecksum(request, body, secret), "Checksum could not be validated from request") } func TestValidNumbers(t *testing.T) { t.Parallel() + assert := assert.New(t) valid := []string{ "+12", "+12345", @@ -72,13 +66,9 @@ func TestValidNumbers(t *testing.T) { "+123-45", } for _, number := range valid { - if !isValidNumber(number) { - t.Errorf("number %s should be valid", number) - } + assert.True(isValidNumber(number), "number %s should be valid", number) } for _, number := range invalid { - if isValidNumber(number) { - t.Errorf("number %s should not be valid", number) - } + assert.False(isValidNumber(number), "number %s should not be valid", number) } } diff --git a/api_signaling_test.go b/api_signaling_test.go index f7bec376..b33169ba 100644 --- a/api_signaling_test.go +++ b/api_signaling_test.go @@ -24,9 +24,10 @@ package signaling import ( "encoding/json" "fmt" - "reflect" "sort" "testing" + + "github.com/stretchr/testify/assert" ) type testCheckValid interface { @@ -53,40 +54,34 @@ func wrapMessage(messageType string, msg testCheckValid) *ClientMessage { } func testMessages(t *testing.T, messageType string, valid_messages []testCheckValid, invalid_messages []testCheckValid) { + t.Helper() + assert := assert.New(t) for _, msg := range valid_messages { - if err := msg.CheckValid(); err != nil { - t.Errorf("Message %+v should be valid, got %s", msg, err) - } + assert.NoError(msg.CheckValid(), "Message %+v should be valid", msg) + // If the inner message is valid, it should also be valid in a wrapped // ClientMessage. - if wrapped := wrapMessage(messageType, msg); wrapped == nil { - t.Errorf("Unknown message type: %s", messageType) - } else if err := wrapped.CheckValid(); err != nil { - t.Errorf("Message %+v should be valid, got %s", wrapped, err) + if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) { + assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped) } } for _, msg := range invalid_messages { - if err := msg.CheckValid(); err == nil { - t.Errorf("Message %+v should not be valid", msg) - } + assert.Error(msg.CheckValid(), "Message %+v should not be valid", msg) // If the inner message is invalid, it should also be invalid in a // wrapped ClientMessage. - if wrapped := wrapMessage(messageType, msg); wrapped == nil { - t.Errorf("Unknown message type: %s", messageType) - } else if err := wrapped.CheckValid(); err == nil { - t.Errorf("Message %+v should not be valid", wrapped) + if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) { + assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped) } } } func TestClientMessage(t *testing.T) { t.Parallel() + assert := assert.New(t) // The message needs a type. msg := ClientMessage{} - if err := msg.CheckValid(); err == nil { - t.Errorf("Message %+v should not be valid", msg) - } + assert.Error(msg.CheckValid()) } func TestHelloClientMessage(t *testing.T) { @@ -229,9 +224,8 @@ func TestHelloClientMessage(t *testing.T) { msg := ClientMessage{ Type: "hello", } - if err := msg.CheckValid(); err == nil { - t.Errorf("Message %+v should not be valid", msg) - } + assert := assert.New(t) + assert.Error(msg.CheckValid()) } func TestMessageClientMessage(t *testing.T) { @@ -311,9 +305,8 @@ func TestMessageClientMessage(t *testing.T) { msg := ClientMessage{ Type: "message", } - if err := msg.CheckValid(); err == nil { - t.Errorf("Message %+v should not be valid", msg) - } + assert := assert.New(t) + assert.Error(msg.CheckValid()) } func TestByeClientMessage(t *testing.T) { @@ -330,9 +323,8 @@ func TestByeClientMessage(t *testing.T) { msg := ClientMessage{ Type: "bye", } - if err := msg.CheckValid(); err != nil { - t.Errorf("Message %+v should be valid, got %s", msg, err) - } + assert := assert.New(t) + assert.NoError(msg.CheckValid()) } func TestRoomClientMessage(t *testing.T) { @@ -349,42 +341,31 @@ func TestRoomClientMessage(t *testing.T) { msg := ClientMessage{ Type: "room", } - if err := msg.CheckValid(); err == nil { - t.Errorf("Message %+v should not be valid", msg) - } + assert := assert.New(t) + assert.Error(msg.CheckValid()) } func TestErrorMessages(t *testing.T) { t.Parallel() + assert := assert.New(t) id := "request-id" msg := ClientMessage{ Id: id, } err1 := msg.NewErrorServerMessage(&Error{}) - if err1.Id != id { - t.Errorf("Expected id %s, got %+v", id, err1) - } - if err1.Type != "error" || err1.Error == nil { - t.Errorf("Expected type \"error\", got %+v", err1) - } + assert.Equal(id, err1.Id, "%+v", err1) + assert.Equal("error", err1.Type, "%+v", err1) + assert.NotNil(err1.Error, "%+v", err1) err2 := msg.NewWrappedErrorServerMessage(fmt.Errorf("test-error")) - if err2.Id != id { - t.Errorf("Expected id %s, got %+v", id, err2) - } - if err2.Type != "error" || err2.Error == nil { - t.Errorf("Expected type \"error\", got %+v", err2) - } - if err2.Error.Code != "internal_error" { - t.Errorf("Expected code \"internal_error\", got %+v", err2) - } - if err2.Error.Message != "test-error" { - t.Errorf("Expected message \"test-error\", got %+v", err2) + assert.Equal(id, err2.Id, "%+v", err2) + assert.Equal("error", err2.Type, "%+v", err2) + if assert.NotNil(err2.Error, "%+v", err2) { + assert.Equal("internal_error", err2.Error.Code, "%+v", err2) + assert.Equal("test-error", err2.Error.Message, "%+v", err2) } // Test "error" interface - if err2.Error.Error() != "test-error" { - t.Errorf("Expected error string \"test-error\", got %+v", err2) - } + assert.Equal("test-error", err2.Error.Error(), "%+v", err2) } func TestIsChatRefresh(t *testing.T) { @@ -397,9 +378,7 @@ func TestIsChatRefresh(t *testing.T) { Data: data_true, }, } - if !msg.IsChatRefresh() { - t.Error("message should be detected as chat refresh") - } + assert.True(t, msg.IsChatRefresh()) data_false := []byte("{\"type\":\"chat\",\"chat\":{\"refresh\":false}}") msg = ServerMessage{ @@ -408,9 +387,7 @@ func TestIsChatRefresh(t *testing.T) { Data: data_false, }, } - if msg.IsChatRefresh() { - t.Error("message should not be detected as chat refresh") - } + assert.False(t, msg.IsChatRefresh()) } func assertEqualStrings(t *testing.T, expected, result []string) { @@ -427,27 +404,22 @@ func assertEqualStrings(t *testing.T, expected, result []string) { sort.Strings(result) } - if !reflect.DeepEqual(expected, result) { - t.Errorf("Expected %+v, got %+v", expected, result) - } + assert.Equal(t, expected, result) } func Test_Welcome_AddRemoveFeature(t *testing.T) { t.Parallel() + assert := assert.New(t) var msg WelcomeServerMessage assertEqualStrings(t, []string{}, msg.Features) msg.AddFeature("one", "two", "one") assertEqualStrings(t, []string{"one", "two"}, msg.Features) - if !sort.StringsAreSorted(msg.Features) { - t.Errorf("features should be sorted, got %+v", msg.Features) - } + assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features) msg.AddFeature("three") assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features) - if !sort.StringsAreSorted(msg.Features) { - t.Errorf("features should be sorted, got %+v", msg.Features) - } + assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features) msg.RemoveFeature("three", "one") assertEqualStrings(t, []string{"two"}, msg.Features) diff --git a/async_events_test.go b/async_events_test.go index 3253c832..b72a30a0 100644 --- a/async_events_test.go +++ b/async_events_test.go @@ -25,6 +25,8 @@ import ( "context" "strings" "testing" + + "github.com/stretchr/testify/require" ) var ( @@ -51,7 +53,7 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents { url := startLocalNatsServer(t) events, err := NewAsyncEvents(url) if err != nil { - t.Fatal(err) + require.NoError(t, err) } return events } @@ -59,7 +61,7 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents { func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents { events, err := NewAsyncEvents(NatsLoopbackUrl) if err != nil { - t.Fatal(err) + require.NoError(t, err) } t.Cleanup(func() { diff --git a/backend_client_test.go b/backend_client_test.go index 7effed4a..d0a4d779 100644 --- a/backend_client_test.go +++ b/backend_client_test.go @@ -24,17 +24,17 @@ package signaling import ( "context" "encoding/json" - "errors" "io" "net/http" "net/http/httptest" "net/url" - "reflect" "strings" "testing" "github.com/dlintw/goconf" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) { @@ -57,37 +57,29 @@ func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) { } data, err := json.Marshal(response) - if err != nil { - t.Fatal(err) - return - } + require.NoError(t, err) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - if _, err := w.Write(data); err != nil { - t.Error(err) - } + _, err = w.Write(data) + assert.NoError(t, err) } func TestPostOnRedirect(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/ocs/v2.php/two", http.StatusTemporaryRedirect) }) r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - return - } + require.NoError(err) var request map[string]string - if err := json.Unmarshal(body, &request); err != nil { - t.Fatal(err) - return - } + err = json.Unmarshal(body, &request) + require.NoError(err) returnOCS(t, w, body) }) @@ -96,9 +88,7 @@ func TestPostOnRedirect(t *testing.T) { defer server.Close() u, err := url.Parse(server.URL + "/ocs/v2.php/one") - if err != nil { - t.Fatal(err) - } + require.NoError(err) config := goconf.NewConfigFile() config.AddOption("backend", "allowed", u.Host) @@ -107,9 +97,7 @@ func TestPostOnRedirect(t *testing.T) { config.AddOption("backend", "allowhttp", "true") } client, err := NewBackendClient(config, 1, "0.0", nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) ctx := context.Background() request := map[string]string{ @@ -117,18 +105,17 @@ func TestPostOnRedirect(t *testing.T) { } var response map[string]string err = client.PerformJSONRequest(ctx, u, request, &response) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if response == nil || !reflect.DeepEqual(request, response) { - t.Errorf("Expected %+v, got %+v", request, response) + if assert.NotNil(t, response) { + assert.Equal(t, request, response) } } func TestPostOnRedirectDifferentHost(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "http://domain.invalid/ocs/v2.php/two", http.StatusTemporaryRedirect) @@ -137,9 +124,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) { defer server.Close() u, err := url.Parse(server.URL + "/ocs/v2.php/one") - if err != nil { - t.Fatal(err) - } + require.NoError(err) config := goconf.NewConfigFile() config.AddOption("backend", "allowed", u.Host) @@ -148,9 +133,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) { config.AddOption("backend", "allowhttp", "true") } client, err := NewBackendClient(config, 1, "0.0", nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) ctx := context.Background() request := map[string]string{ @@ -160,41 +143,33 @@ func TestPostOnRedirectDifferentHost(t *testing.T) { err = client.PerformJSONRequest(ctx, u, request, &response) if err != nil { // The redirect to a different host should have failed. - if !errors.Is(err, ErrNotRedirecting) { - t.Fatal(err) - } + require.ErrorIs(err, ErrNotRedirecting) } else { - t.Fatal("The redirect should have failed") + require.Fail("The redirect should have failed") } } func TestPostOnRedirectStatusFound(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound) }) r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - return - } - - if len(body) > 0 { - t.Errorf("Should not have received any body, got %s", string(body)) - } + require.NoError(err) + assert.Empty(string(body), "Should not have received any body, got %s", string(body)) returnOCS(t, w, []byte("{}")) }) server := httptest.NewServer(r) defer server.Close() u, err := url.Parse(server.URL + "/ocs/v2.php/one") - if err != nil { - t.Fatal(err) - } + require.NoError(err) config := goconf.NewConfigFile() config.AddOption("backend", "allowed", u.Host) @@ -203,9 +178,7 @@ func TestPostOnRedirectStatusFound(t *testing.T) { config.AddOption("backend", "allowhttp", "true") } client, err := NewBackendClient(config, 1, "0.0", nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) ctx := context.Background() request := map[string]string{ @@ -213,18 +186,16 @@ func TestPostOnRedirectStatusFound(t *testing.T) { } var response map[string]string err = client.PerformJSONRequest(ctx, u, request, &response) - if err != nil { - t.Error(err) - } - - if len(response) > 0 { - t.Errorf("Expected empty response, got %+v", response) + if assert.NoError(err) { + assert.Empty(response, "Expected empty response, got %+v", response) } } func TestHandleThrottled(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { returnOCS(t, w, []byte("[]")) @@ -233,9 +204,7 @@ func TestHandleThrottled(t *testing.T) { defer server.Close() u, err := url.Parse(server.URL + "/ocs/v2.php/one") - if err != nil { - t.Fatal(err) - } + require.NoError(err) config := goconf.NewConfigFile() config.AddOption("backend", "allowed", u.Host) @@ -244,9 +213,7 @@ func TestHandleThrottled(t *testing.T) { config.AddOption("backend", "allowhttp", "true") } client, err := NewBackendClient(config, 1, "0.0", nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) ctx := context.Background() request := map[string]string{ @@ -254,9 +221,7 @@ func TestHandleThrottled(t *testing.T) { } var response map[string]string err = client.PerformJSONRequest(ctx, u, request, &response) - if err == nil { - t.Error("should have triggered an error") - } else if !errors.Is(err, ErrThrottledResponse) { - t.Error(err) + if assert.Error(err) { + assert.ErrorIs(err, ErrThrottledResponse) } } diff --git a/backend_configuration_test.go b/backend_configuration_test.go index b2c2b277..e467cbd8 100644 --- a/backend_configuration_test.go +++ b/backend_configuration_test.go @@ -22,7 +22,6 @@ package signaling import ( - "bytes" "context" "net/url" "reflect" @@ -31,32 +30,30 @@ import ( "github.com/dlintw/goconf" "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) { for _, u := range valid_urls { u := u t.Run(u, func(t *testing.T) { + assert := assert.New(t) parsed, err := url.ParseRequestURI(u) - if err != nil { - t.Errorf("The url %s should be valid, got %s", u, err) + if !assert.NoError(err, "The url %s should be valid", u) { return } - if !config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should be allowed", u) - } - if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) { - t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret)) - } + assert.True(config.IsUrlAllowed(parsed), "The url %s should be allowed", u) + secret := config.GetSecret(parsed) + assert.Equal(string(testBackendSecret), string(secret), "Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret)) }) } for _, u := range invalid_urls { u := u t.Run(u, func(t *testing.T) { + assert := assert.New(t) parsed, _ := url.ParseRequestURI(u) - if config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should not be allowed", u) - } + assert.False(config.IsUrlAllowed(parsed), "The url %s should not be allowed", u) }) } } @@ -65,28 +62,24 @@ func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]str for _, entry := range valid_urls { entry := entry t.Run(entry[0], func(t *testing.T) { + assert := assert.New(t) u := entry[0] parsed, err := url.ParseRequestURI(u) - if err != nil { - t.Errorf("The url %s should be valid, got %s", u, err) + if !assert.NoError(err, "The url %s should be valid", u) { return } - if !config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should be allowed", u) - } + assert.True(config.IsUrlAllowed(parsed), "The url %s should be allowed", u) s := entry[1] - if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) { - t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret)) - } + secret := config.GetSecret(parsed) + assert.Equal(s, string(secret), "Expected secret %s for url %s, got %s", s, u, string(secret)) }) } for _, u := range invalid_urls { u := u t.Run(u, func(t *testing.T) { + assert := assert.New(t) parsed, _ := url.ParseRequestURI(u) - if config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should not be allowed", u) - } + assert.False(config.IsUrlAllowed(parsed), "The url %s should not be allowed", u) }) } } @@ -108,9 +101,7 @@ func TestIsUrlAllowed_Compat(t *testing.T) { config.AddOption("backend", "allowhttp", "true") config.AddOption("backend", "secret", string(testBackendSecret)) cfg, err := NewBackendConfiguration(config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } @@ -130,9 +121,7 @@ func TestIsUrlAllowed_CompatForceHttps(t *testing.T) { config.AddOption("backend", "allowed", "domain.invalid") config.AddOption("backend", "secret", string(testBackendSecret)) cfg, err := NewBackendConfiguration(config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } @@ -176,9 +165,7 @@ func TestIsUrlAllowed(t *testing.T) { config.AddOption("lala", "url", "https://otherdomain.invalid/") config.AddOption("lala", "secret", string(testBackendSecret)+"-lala") cfg, err := NewBackendConfiguration(config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) testBackends(t, cfg, valid_urls, invalid_urls) } @@ -194,9 +181,7 @@ func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) { config.AddOption("backend", "allowed", "") config.AddOption("backend", "secret", string(testBackendSecret)) cfg, err := NewBackendConfiguration(config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } @@ -215,9 +200,7 @@ func TestIsUrlAllowed_AllowAll(t *testing.T) { config.AddOption("backend", "allowed", "") config.AddOption("backend", "secret", string(testBackendSecret)) cfg, err := NewBackendConfiguration(config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } @@ -238,16 +221,16 @@ func TestParseBackendIds(t *testing.T) { {"backend1,backend2, backend1", []string{"backend1", "backend2"}}, } + assert := assert.New(t) for _, test := range testcases { ids := getConfiguredBackendIDs(test.s) - if !reflect.DeepEqual(ids, test.ids) { - t.Errorf("List of ids differs, expected %+v, got %+v", test.ids, ids) - } + assert.Equal(test.ids, ids, "List of ids differs for \"%s\"", test.s) } } func TestBackendReloadNoChange(t *testing.T) { CatchLogForTest(t) + require := require.New(t) current := testutil.ToFloat64(statsBackendsCurrent) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -257,9 +240,7 @@ func TestBackendReloadNoChange(t *testing.T) { original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") o_cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+2) new_config := goconf.NewConfigFile() @@ -270,20 +251,19 @@ func TestBackendReloadNoChange(t *testing.T) { new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") n_cfg, err := NewBackendConfiguration(new_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+4) o_cfg.Reload(original_config) checkStatsValue(t, statsBackendsCurrent, current+4) if !reflect.DeepEqual(n_cfg, o_cfg) { - t.Error("BackendConfiguration should be equal after Reload") + assert.Fail(t, "BackendConfiguration should be equal after Reload") } } func TestBackendReloadChangeExistingURL(t *testing.T) { CatchLogForTest(t) + require := require.New(t) current := testutil.ToFloat64(statsBackendsCurrent) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -293,9 +273,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") o_cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+2) new_config := goconf.NewConfigFile() @@ -307,9 +285,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") n_cfg, err := NewBackendConfiguration(new_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+4) original_config.RemoveOption("backend1", "url") @@ -319,12 +295,13 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { o_cfg.Reload(original_config) checkStatsValue(t, statsBackendsCurrent, current+4) if !reflect.DeepEqual(n_cfg, o_cfg) { - t.Error("BackendConfiguration should be equal after Reload") + assert.Fail(t, "BackendConfiguration should be equal after Reload") } } func TestBackendReloadChangeSecret(t *testing.T) { CatchLogForTest(t) + require := require.New(t) current := testutil.ToFloat64(statsBackendsCurrent) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -334,9 +311,7 @@ func TestBackendReloadChangeSecret(t *testing.T) { original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") o_cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+2) new_config := goconf.NewConfigFile() @@ -347,9 +322,7 @@ func TestBackendReloadChangeSecret(t *testing.T) { new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") n_cfg, err := NewBackendConfiguration(new_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+4) original_config.RemoveOption("backend1", "secret") @@ -358,12 +331,13 @@ func TestBackendReloadChangeSecret(t *testing.T) { o_cfg.Reload(original_config) checkStatsValue(t, statsBackendsCurrent, current+4) if !reflect.DeepEqual(n_cfg, o_cfg) { - t.Error("BackendConfiguration should be equal after Reload") + assert.Fail(t, "BackendConfiguration should be equal after Reload") } } func TestBackendReloadAddBackend(t *testing.T) { CatchLogForTest(t) + require := require.New(t) current := testutil.ToFloat64(statsBackendsCurrent) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1") @@ -371,9 +345,7 @@ func TestBackendReloadAddBackend(t *testing.T) { original_config.AddOption("backend1", "url", "http://domain1.invalid") original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") o_cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+1) new_config := goconf.NewConfigFile() @@ -385,9 +357,7 @@ func TestBackendReloadAddBackend(t *testing.T) { new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") new_config.AddOption("backend2", "sessionlimit", "10") n_cfg, err := NewBackendConfiguration(new_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+3) original_config.RemoveOption("backend", "backends") @@ -399,12 +369,13 @@ func TestBackendReloadAddBackend(t *testing.T) { o_cfg.Reload(original_config) checkStatsValue(t, statsBackendsCurrent, current+4) if !reflect.DeepEqual(n_cfg, o_cfg) { - t.Error("BackendConfiguration should be equal after Reload") + assert.Fail(t, "BackendConfiguration should be equal after Reload") } } func TestBackendReloadRemoveHost(t *testing.T) { CatchLogForTest(t) + require := require.New(t) current := testutil.ToFloat64(statsBackendsCurrent) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -414,9 +385,7 @@ func TestBackendReloadRemoveHost(t *testing.T) { original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") o_cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+2) new_config := goconf.NewConfigFile() @@ -425,9 +394,7 @@ func TestBackendReloadRemoveHost(t *testing.T) { new_config.AddOption("backend1", "url", "http://domain1.invalid") new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") n_cfg, err := NewBackendConfiguration(new_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+3) original_config.RemoveOption("backend", "backends") @@ -437,12 +404,13 @@ func TestBackendReloadRemoveHost(t *testing.T) { o_cfg.Reload(original_config) checkStatsValue(t, statsBackendsCurrent, current+2) if !reflect.DeepEqual(n_cfg, o_cfg) { - t.Error("BackendConfiguration should be equal after Reload") + assert.Fail(t, "BackendConfiguration should be equal after Reload") } } func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { CatchLogForTest(t) + require := require.New(t) current := testutil.ToFloat64(statsBackendsCurrent) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -452,9 +420,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") o_cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+2) new_config := goconf.NewConfigFile() @@ -463,9 +429,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/") new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") n_cfg, err := NewBackendConfiguration(new_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) checkStatsValue(t, statsBackendsCurrent, current+3) original_config.RemoveOption("backend", "backends") @@ -475,7 +439,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { o_cfg.Reload(original_config) checkStatsValue(t, statsBackendsCurrent, current+2) if !reflect.DeepEqual(n_cfg, o_cfg) { - t.Error("BackendConfiguration should be equal after Reload") + assert.Fail(t, "BackendConfiguration should be equal after Reload") } } @@ -500,6 +464,8 @@ func mustParse(s string) *url.URL { func TestBackendConfiguration_Etcd(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) url1 := "https://domain1.invalid/foo" @@ -513,9 +479,7 @@ func TestBackendConfiguration_Etcd(t *testing.T) { config.AddOption("backend", "backendprefix", "/backends") cfg, err := NewBackendConfiguration(config, client) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer cfg.Close() storage := cfg.storage.(*backendStorageEtcd) @@ -524,31 +488,25 @@ func TestBackendConfiguration_Etcd(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if err := storage.WaitForInitialized(ctx); err != nil { - t.Fatal(err) - } + require.NoError(storage.WaitForInitialized(ctx)) - if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 { - t.Errorf("Expected one backend, got %+v", backends) - } else if backends[0].url != url1 { - t.Errorf("Expected backend url %s, got %s", url1, backends[0].url) - } else if string(backends[0].secret) != initialSecret1 { - t.Errorf("Expected backend secret %s, got %s", initialSecret1, string(backends[0].secret)) - } else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { - t.Errorf("Expected backend %+v, got %+v", backends[0], backend) + if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) && + assert.Equal(url1, backends[0].url) && + assert.Equal(initialSecret1, string(backends[0].secret)) { + if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { + assert.Fail("Expected backend %+v, got %+v", backends[0], backend) + } } drainWakeupChannel(ch) SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+secret1+"\"}")) <-ch - if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 { - t.Errorf("Expected one backend, got %+v", backends) - } else if backends[0].url != url1 { - t.Errorf("Expected backend url %s, got %s", url1, backends[0].url) - } else if string(backends[0].secret) != secret1 { - t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret)) - } else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { - t.Errorf("Expected backend %+v, got %+v", backends[0], backend) + if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) && + assert.Equal(url1, backends[0].url) && + assert.Equal(secret1, string(backends[0].secret)) { + if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { + assert.Fail("Expected backend %+v, got %+v", backends[0], backend) + } } url2 := "https://domain1.invalid/bar" @@ -557,20 +515,16 @@ func TestBackendConfiguration_Etcd(t *testing.T) { drainWakeupChannel(ch) SetEtcdValue(etcd, "/backends/2_two", []byte("{\"url\":\""+url2+"\",\"secret\":\""+secret2+"\"}")) <-ch - if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 { - t.Errorf("Expected two backends, got %+v", backends) - } else if backends[0].url != url1 { - t.Errorf("Expected backend url %s, got %s", url1, backends[0].url) - } else if string(backends[0].secret) != secret1 { - t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret)) - } else if backends[1].url != url2 { - t.Errorf("Expected backend url %s, got %s", url2, backends[1].url) - } else if string(backends[1].secret) != secret2 { - t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret)) - } else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { - t.Errorf("Expected backend %+v, got %+v", backends[0], backend) - } else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] { - t.Errorf("Expected backend %+v, got %+v", backends[1], backend) + if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) && + assert.Equal(url1, backends[0].url) && + assert.Equal(secret1, string(backends[0].secret)) && + assert.Equal(url2, backends[1].url) && + assert.Equal(secret2, string(backends[1].secret)) { + if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { + assert.Fail("Expected backend %+v, got %+v", backends[0], backend) + } else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] { + assert.Fail("Expected backend %+v, got %+v", backends[1], backend) + } } url3 := "https://domain2.invalid/foo" @@ -579,70 +533,54 @@ func TestBackendConfiguration_Etcd(t *testing.T) { drainWakeupChannel(ch) SetEtcdValue(etcd, "/backends/3_three", []byte("{\"url\":\""+url3+"\",\"secret\":\""+secret3+"\"}")) <-ch - if backends := sortBackends(cfg.GetBackends()); len(backends) != 3 { - t.Errorf("Expected three backends, got %+v", backends) - } else if backends[0].url != url1 { - t.Errorf("Expected backend url %s, got %s", url1, backends[0].url) - } else if string(backends[0].secret) != secret1 { - t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret)) - } else if backends[1].url != url2 { - t.Errorf("Expected backend url %s, got %s", url2, backends[1].url) - } else if string(backends[1].secret) != secret2 { - t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret)) - } else if backends[2].url != url3 { - t.Errorf("Expected backend url %s, got %s", url3, backends[2].url) - } else if string(backends[2].secret) != secret3 { - t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[2].secret)) - } else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { - t.Errorf("Expected backend %+v, got %+v", backends[0], backend) - } else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] { - t.Errorf("Expected backend %+v, got %+v", backends[1], backend) - } else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] { - t.Errorf("Expected backend %+v, got %+v", backends[2], backend) + if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 3) && + assert.Equal(url1, backends[0].url) && + assert.Equal(secret1, string(backends[0].secret)) && + assert.Equal(url2, backends[1].url) && + assert.Equal(secret2, string(backends[1].secret)) && + assert.Equal(url3, backends[2].url) && + assert.Equal(secret3, string(backends[2].secret)) { + if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] { + assert.Fail("Expected backend %+v, got %+v", backends[0], backend) + } else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] { + assert.Fail("Expected backend %+v, got %+v", backends[1], backend) + } else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] { + assert.Fail("Expected backend %+v, got %+v", backends[2], backend) + } } drainWakeupChannel(ch) DeleteEtcdValue(etcd, "/backends/1_one") <-ch - if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 { - t.Errorf("Expected two backends, got %+v", backends) - } else if backends[0].url != url2 { - t.Errorf("Expected backend url %s, got %s", url2, backends[0].url) - } else if string(backends[0].secret) != secret2 { - t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[0].secret)) - } else if backends[1].url != url3 { - t.Errorf("Expected backend url %s, got %s", url3, backends[1].url) - } else if string(backends[1].secret) != secret3 { - t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[1].secret)) + if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) { + assert.Equal(url2, backends[0].url) + assert.Equal(secret2, string(backends[0].secret)) + assert.Equal(url3, backends[1].url) + assert.Equal(secret3, string(backends[1].secret)) } drainWakeupChannel(ch) DeleteEtcdValue(etcd, "/backends/2_two") <-ch - if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 { - t.Errorf("Expected one backend, got %+v", backends) - } else if backends[0].url != url3 { - t.Errorf("Expected backend url %s, got %s", url3, backends[0].url) - } else if string(backends[0].secret) != secret3 { - t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[0].secret)) + if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) { + assert.Equal(url3, backends[0].url) + assert.Equal(secret3, string(backends[0].secret)) } if _, found := storage.backends["domain1.invalid"]; found { - t.Errorf("Should have removed host information for %s", "domain1.invalid") + assert.Fail("Should have removed host information for %s", "domain1.invalid") } } func TestBackendCommonSecret(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) u1, err := url.Parse("http://domain1.invalid") - if err != nil { - t.Fatal(err) - } + require.NoError(err) u2, err := url.Parse("http://domain2.invalid") - if err != nil { - t.Fatal(err) - } + require.NoError(err) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") original_config.AddOption("backend", "secret", string(testBackendSecret)) @@ -650,19 +588,13 @@ func TestBackendCommonSecret(t *testing.T) { original_config.AddOption("backend2", "url", u2.String()) original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") cfg, err := NewBackendConfiguration(original_config, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if b1 := cfg.GetBackend(u1); b1 == nil { - t.Error("didn't get backend") - } else if !bytes.Equal(b1.Secret(), testBackendSecret) { - t.Errorf("expected secret %s, got %s", string(testBackendSecret), string(b1.Secret())) + if b1 := cfg.GetBackend(u1); assert.NotNil(b1) { + assert.Equal(string(testBackendSecret), string(b1.Secret())) } - if b2 := cfg.GetBackend(u2); b2 == nil { - t.Error("didn't get backend") - } else if !bytes.Equal(b2.Secret(), []byte(string(testBackendSecret)+"-backend2")) { - t.Errorf("expected secret %s, got %s", string(testBackendSecret)+"-backend2", string(b2.Secret())) + if b2 := cfg.GetBackend(u2); assert.NotNil(b2) { + assert.Equal(string(testBackendSecret)+"-backend2", string(b2.Secret())) } updated_config := goconf.NewConfigFile() @@ -673,14 +605,10 @@ func TestBackendCommonSecret(t *testing.T) { updated_config.AddOption("backend2", "url", u2.String()) cfg.Reload(updated_config) - if b1 := cfg.GetBackend(u1); b1 == nil { - t.Error("didn't get backend") - } else if !bytes.Equal(b1.Secret(), []byte(string(testBackendSecret)+"-backend1")) { - t.Errorf("expected secret %s, got %s", string(testBackendSecret)+"-backend1", string(b1.Secret())) + if b1 := cfg.GetBackend(u1); assert.NotNil(b1) { + assert.Equal(string(testBackendSecret)+"-backend1", string(b1.Secret())) } - if b2 := cfg.GetBackend(u2); b2 == nil { - t.Error("didn't get backend") - } else if !bytes.Equal(b2.Secret(), testBackendSecret) { - t.Errorf("expected secret %s, got %s", string(testBackendSecret), string(b2.Secret())) + if b2 := cfg.GetBackend(u2); assert.NotNil(b2) { + assert.Equal(string(testBackendSecret), string(b2.Secret())) } } diff --git a/backend_server_test.go b/backend_server_test.go index 0396e743..71769934 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -35,7 +35,6 @@ import ( "net/http/httptest" "net/textproto" "net/url" - "reflect" "strings" "sync" "testing" @@ -44,6 +43,8 @@ import ( "github.com/dlintw/goconf" "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -66,6 +67,7 @@ func CreateBackendServerForTestWithTurn(t *testing.T) (*goconf.ConfigFile, *Back } func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFile) (*goconf.ConfigFile, *BackendServer, AsyncEvents, *Hub, *mux.Router, *httptest.Server) { + require := require.New(t) r := mux.NewRouter() registerBackendHandler(t, r) @@ -77,9 +79,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil config = goconf.NewConfigFile() } u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) if strings.Contains(t.Name(), "Compat") { config.AddOption("backend", "allowed", u.Host) config.AddOption("backend", "secret", string(testBackendSecret)) @@ -98,16 +98,10 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil config.AddOption("geoip", "url", "none") events := getAsyncEventsForTest(t) hub, err := NewHub(config, events, nil, nil, nil, r, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) b, err := NewBackendServer(config, hub, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b.Start(r); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b.Start(r)) go hub.Run() @@ -126,6 +120,7 @@ func CreateBackendServerWithClusteringForTest(t *testing.T) (*BackendServer, *Ba } func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *goconf.ConfigFile, config2 *goconf.ConfigFile) (*BackendServer, *BackendServer, *Hub, *Hub, *httptest.Server, *httptest.Server) { + require := require.New(t) r1 := mux.NewRouter() registerBackendHandler(t, r1) @@ -150,9 +145,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config1 = goconf.NewConfigFile() } u1, err := url.Parse(server1.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) config1.AddOption("backend", "allowed", u1.Host) if u1.Scheme == "http" { config1.AddOption("backend", "allowhttp", "true") @@ -164,25 +157,19 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config1.AddOption("geoip", "url", "none") events1, err := NewAsyncEvents(nats) - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { events1.Close() }) client1, _ := NewGrpcClientsForTest(t, addr2) hub1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) if config2 == nil { config2 = goconf.NewConfigFile() } u2, err := url.Parse(server2.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) config2.AddOption("backend", "allowed", u2.Host) if u2.Scheme == "http" { config2.AddOption("backend", "allowhttp", "true") @@ -193,32 +180,20 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config2.AddOption("clients", "internalsecret", string(testInternalSecret)) config2.AddOption("geoip", "url", "none") events2, err := NewAsyncEvents(nats) - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { events2.Close() }) client2, _ := NewGrpcClientsForTest(t, addr1) hub2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) b1, err := NewBackendServer(config1, hub1, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b1.Start(r1); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b1.Start(r1)) b2, err := NewBackendServer(config2, hub2, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b2.Start(r2); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b2.Start(r2)) go hub1.Run() go hub2.Run() @@ -278,64 +253,54 @@ func expectRoomlistEvent(ch chan *AsyncMessage, msgType string) (*EventServerMes func TestBackendServer_NoAuth(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) roomId := "the-room-id" data := []byte{'{', '}'} request, err := http.NewRequest("POST", server.URL+"/api/v1/room/"+roomId, bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } + require.NoError(err) request.Header.Set("Content-Type", "application/json") client := &http.Client{} res, err := client.Do(request) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusForbidden { - t.Errorf("Expected error response, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusForbidden, res.StatusCode, "Expected error response, got %s: %s", res.Status, string(body)) } func TestBackendServer_InvalidAuth(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) roomId := "the-room-id" data := []byte{'{', '}'} request, err := http.NewRequest("POST", server.URL+"/api/v1/room/"+roomId, bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } + require.NoError(err) request.Header.Set("Content-Type", "application/json") request.Header.Set("Spreed-Signaling-Random", "hello") request.Header.Set("Spreed-Signaling-Checksum", "world") client := &http.Client{} res, err := client.Do(request) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusForbidden { - t.Errorf("Expected error response, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusForbidden, res.StatusCode, "Expected error response, got %s: %s", res.Status, string(body)) } func TestBackendServer_OldCompatAuth(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) roomId := "the-room-id" @@ -355,14 +320,10 @@ func TestBackendServer_OldCompatAuth(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) request, err := http.NewRequest("POST", server.URL+"/api/v1/room/"+roomId, bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } + require.NoError(err) request.Header.Set("Content-Type", "application/json") rnd := newRandomString(32) check := CalculateBackendChecksum(rnd, data, testBackendSecret) @@ -370,44 +331,36 @@ func TestBackendServer_OldCompatAuth(t *testing.T) { request.Header.Set("Spreed-Signaling-Checksum", check) client := &http.Client{} res, err := client.Do(request) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusOK { - t.Errorf("Expected success, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected success, got %s: %s", res.Status, string(body)) } func TestBackendServer_InvalidBody(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) roomId := "the-room-id" data := []byte{1, 2, 3, 4} // Invalid JSON res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusBadRequest { - t.Errorf("Expected error response, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusBadRequest, res.StatusCode, "Expected error response, got %s: %s", res.Status, string(body)) } func TestBackendServer_UnsupportedRequest(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) msg := &BackendServerRoomRequest{ @@ -415,22 +368,14 @@ func TestBackendServer_UnsupportedRequest(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "the-room-id" res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusBadRequest { - t.Errorf("Expected error response, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusBadRequest, res.StatusCode, "Expected error response, got %s: %s", res.Status, string(body)) } func TestBackendServer_RoomInvite(t *testing.T) { @@ -452,12 +397,12 @@ func (l *channelEventListener) ProcessAsyncUserMessage(message *AsyncMessage) { } func RunTestBackendServer_RoomInvite(t *testing.T) { + require := require.New(t) + assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) userid := "test-userid" roomProperties := json.RawMessage("{\"foo\":\"bar\"}") @@ -467,9 +412,7 @@ func RunTestBackendServer_RoomInvite(t *testing.T) { listener := &channelEventListener{ ch: eventsChan, } - if err := events.RegisterUserListener(userid, backend, listener); err != nil { - t.Fatal(err) - } + require.NoError(events.RegisterUserListener(userid, backend, listener)) defer events.UnregisterUserListener(userid, backend, listener) msg := &BackendServerRoomRequest{ @@ -486,32 +429,20 @@ func RunTestBackendServer_RoomInvite(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "the-room-id" res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s: %s", res.Status, string(body)) - event, err := expectRoomlistEvent(eventsChan, "invite") - if err != nil { - t.Error(err) - } else if event.Invite == nil { - t.Errorf("Expected invite, got %+v", event) - } else if event.Invite.RoomId != roomId { - t.Errorf("Expected room %s, got %+v", roomId, event) - } else if !bytes.Equal(event.Invite.Properties, roomProperties) { - t.Errorf("Room properties don't match: expected %s, got %s", string(roomProperties), string(event.Invite.Properties)) + if event, err := expectRoomlistEvent(eventsChan, "invite"); assert.NoError(err) { + if assert.NotNil(event.Invite) { + assert.Equal(roomId, event.Invite.RoomId) + assert.Equal(string(roomProperties), string(event.Invite.Properties)) + } } } @@ -526,41 +457,33 @@ func TestBackendServer_RoomDisinvite(t *testing.T) { } func RunTestBackendServer_RoomDisinvite(t *testing.T) { + require := require.New(t) + assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) backend := hub.backend.GetBackend(u) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + if room, err := client.JoinRoom(ctx, roomId); assert.NoError(err) { + assert.Equal(roomId, room.Room.RoomId) } // Ignore "join" events. - if err := client.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.DrainMessages(ctx)) roomProperties := json.RawMessage("{\"foo\":\"bar\"}") @@ -568,9 +491,7 @@ func RunTestBackendServer_RoomDisinvite(t *testing.T) { listener := &channelEventListener{ ch: eventsChan, } - if err := events.RegisterUserListener(testDefaultUserId, backend, listener); err != nil { - t.Fatal(err) - } + require.NoError(events.RegisterUserListener(testDefaultUserId, backend, listener)) defer events.UnregisterUserListener(testDefaultUserId, backend, listener) msg := &BackendServerRoomRequest{ @@ -588,92 +509,64 @@ func RunTestBackendServer_RoomDisinvite(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + require.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s: %s", res.Status, string(body)) - event, err := expectRoomlistEvent(eventsChan, "disinvite") - if err != nil { - t.Error(err) - } else if event.Disinvite == nil { - t.Errorf("Expected disinvite, got %+v", event) - } else if event.Disinvite.RoomId != roomId { - t.Errorf("Expected room %s, got %+v", roomId, event) - } else if len(event.Disinvite.Properties) > 0 { - t.Errorf("Room properties should be omitted, got %s", string(event.Disinvite.Properties)) - } else if event.Disinvite.Reason != "disinvited" { - t.Errorf("Reason should be disinvited, got %s", event.Disinvite.Reason) + if event, err := expectRoomlistEvent(eventsChan, "disinvite"); assert.NoError(err) { + if assert.NotNil(event.Disinvite) { + assert.Equal(roomId, event.Disinvite.RoomId) + assert.Equal("disinvited", event.Disinvite.Reason) + } + assert.Empty(string(event.Disinvite.Properties)) } - if message, err := client.RunUntilRoomlistDisinvite(ctx); err != nil { - t.Error(err) - } else if message.RoomId != roomId { - t.Errorf("Expected message for room %s, got %s", roomId, message.RoomId) + if message, err := client.RunUntilRoomlistDisinvite(ctx); assert.NoError(err) { + assert.Equal(roomId, message.RoomId) } if message, err := client.RunUntilMessage(ctx); err != nil && !websocket.IsCloseError(err, websocket.CloseNoStatusReceived) { - t.Errorf("Received unexpected error %s", err) + assert.Fail("Received unexpected error %s", err) } else if err == nil { - t.Errorf("Server should have closed the connection, received %+v", *message) + assert.Fail("Server should have closed the connection, received %+v", *message) } } func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } - + require.NoError(client1.SendHello(testDefaultUserId)) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId1 := "test-room1" - if _, err := client1.JoinRoom(ctx, roomId1); err != nil { - t.Fatal(err) - } - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + _, err = client1.JoinRoom(ctx, roomId1) + require.NoError(err) + require.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) roomId2 := "test-room2" - if _, err := client2.JoinRoom(ctx, roomId2); err != nil { - t.Fatal(err) - } - if err := client2.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } + _, err = client2.JoinRoom(ctx, roomId2) + require.NoError(err) + require.NoError(client2.RunUntilJoined(ctx, hello2.Hello)) msg := &BackendServerRoomRequest{ Type: "disinvite", @@ -689,38 +582,26 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId1, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) - if message, err := client1.RunUntilRoomlistDisinvite(ctx); err != nil { - t.Error(err) - } else if message.RoomId != roomId1 { - t.Errorf("Expected message for room %s, got %s", roomId1, message.RoomId) + if message, err := client1.RunUntilRoomlistDisinvite(ctx); assert.NoError(err) { + assert.Equal(roomId1, message.RoomId) } if message, err := client1.RunUntilMessage(ctx); err != nil && !websocket.IsCloseError(err, websocket.CloseNoStatusReceived) { - t.Errorf("Received unexpected error %s", err) + assert.NoError(err) } else if err == nil { - t.Errorf("Server should have closed the connection, received %+v", *message) + assert.Fail("Server should have closed the connection, received %+v", *message) } - if message, err := client2.RunUntilRoomlistDisinvite(ctx); err != nil { - t.Error(err) - } else if message.RoomId != roomId1 { - t.Errorf("Expected message for room %s, got %s", roomId1, message.RoomId) + if message, err := client2.RunUntilRoomlistDisinvite(ctx); assert.NoError(err) { + assert.Equal(roomId1, message.RoomId) } msg = &BackendServerRoomRequest{ @@ -734,26 +615,16 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { } data, err = json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err = performBackendRequest(server.URL+"/api/v1/room/"+roomId2, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err = io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) - if message, err := client2.RunUntilRoomlistUpdate(ctx); err != nil { - t.Error(err) - } else if message.RoomId != roomId2 { - t.Errorf("Expected message for room %s, got %s", roomId2, message.RoomId) + if message, err := client2.RunUntilRoomlistUpdate(ctx); assert.NoError(err) { + assert.Equal(roomId2, message.RoomId) } } @@ -768,23 +639,19 @@ func TestBackendServer_RoomUpdate(t *testing.T) { } func RunTestBackendServer_RoomUpdate(t *testing.T) { + require := require.New(t) + assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "the-room-id" emptyProperties := json.RawMessage("{}") backend := hub.backend.GetBackend(u) - if backend == nil { - t.Fatalf("Did not find backend") - } + require.NotNil(backend, "Did not find backend") room, err := hub.createRoom(roomId, emptyProperties, backend) - if err != nil { - t.Fatalf("Could not create room: %s", err) - } + require.NoError(err, "Could not create room") defer room.Close() userid := "test-userid" @@ -794,9 +661,7 @@ func RunTestBackendServer_RoomUpdate(t *testing.T) { listener := &channelEventListener{ ch: eventsChan, } - if err := events.RegisterUserListener(userid, backend, listener); err != nil { - t.Fatal(err) - } + require.NoError(events.RegisterUserListener(userid, backend, listener)) defer events.UnregisterUserListener(userid, backend, listener) msg := &BackendServerRoomRequest{ @@ -810,43 +675,27 @@ func RunTestBackendServer_RoomUpdate(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) - event, err := expectRoomlistEvent(eventsChan, "update") - if err != nil { - t.Error(err) - } else if event.Update == nil { - t.Errorf("Expected update, got %+v", event) - } else if event.Update.RoomId != roomId { - t.Errorf("Expected room %s, got %+v", roomId, event) - } else if !bytes.Equal(event.Update.Properties, roomProperties) { - t.Errorf("Room properties don't match: expected %s, got %s", string(roomProperties), string(event.Update.Properties)) + if event, err := expectRoomlistEvent(eventsChan, "update"); assert.NoError(err) { + if assert.NotNil(event.Update) { + assert.Equal(roomId, event.Update.RoomId) + assert.Equal(string(roomProperties), string(event.Update.Properties)) + } } // TODO: Use event to wait for asynchronous messages. time.Sleep(10 * time.Millisecond) room = hub.getRoom(roomId) - if room == nil { - t.Fatalf("Room %s does not exist", roomId) - } - if string(room.Properties()) != string(roomProperties) { - t.Errorf("Expected properties %s for room %s, got %s", string(roomProperties), room.Id(), string(room.Properties())) - } + require.NotNil(room, "Room %s does not exist", roomId) + assert.Equal(string(roomProperties), string(room.Properties())) } func TestBackendServer_RoomDelete(t *testing.T) { @@ -860,31 +709,26 @@ func TestBackendServer_RoomDelete(t *testing.T) { } func RunTestBackendServer_RoomDelete(t *testing.T) { + require := require.New(t) + assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "the-room-id" emptyProperties := json.RawMessage("{}") backend := hub.backend.GetBackend(u) - if backend == nil { - t.Fatalf("Did not find backend") - } - if _, err := hub.createRoom(roomId, emptyProperties, backend); err != nil { - t.Fatalf("Could not create room: %s", err) - } + require.NotNil(backend, "Did not find backend") + _, err = hub.createRoom(roomId, emptyProperties, backend) + require.NoError(err) userid := "test-userid" eventsChan := make(chan *AsyncMessage, 1) listener := &channelEventListener{ ch: eventsChan, } - if err := events.RegisterUserListener(userid, backend, listener); err != nil { - t.Fatal(err) - } + require.NoError(events.RegisterUserListener(userid, backend, listener)) defer events.UnregisterUserListener(userid, backend, listener) msg := &BackendServerRoomRequest{ @@ -897,43 +741,28 @@ func RunTestBackendServer_RoomDelete(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) // A deleted room is signalled as a "disinvite" event. - event, err := expectRoomlistEvent(eventsChan, "disinvite") - if err != nil { - t.Error(err) - } else if event.Disinvite == nil { - t.Errorf("Expected disinvite, got %+v", event) - } else if event.Disinvite.RoomId != roomId { - t.Errorf("Expected room %s, got %+v", roomId, event) - } else if len(event.Disinvite.Properties) > 0 { - t.Errorf("Room properties should be omitted, got %s", string(event.Disinvite.Properties)) - } else if event.Disinvite.Reason != "deleted" { - t.Errorf("Reason should be deleted, got %s", event.Disinvite.Reason) + if event, err := expectRoomlistEvent(eventsChan, "disinvite"); assert.NoError(err) { + if assert.NotNil(event.Disinvite) { + assert.Equal(roomId, event.Disinvite.RoomId) + assert.Empty(event.Disinvite.Properties) + assert.Equal("deleted", event.Disinvite.Reason) + } } // TODO: Use event to wait for asynchronous messages. time.Sleep(10 * time.Millisecond) room := hub.getRoom(roomId) - if room != nil { - t.Errorf("Room %s should have been deleted", roomId) - } + assert.Nil(room, "Room %s should have been deleted", roomId) } func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { @@ -941,6 +770,8 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -957,35 +788,23 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId) - if session2 == nil { - t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) - } + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) // Sessions have all permissions initially (fallback for old-style sessions). assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_MEDIA) @@ -995,24 +814,16 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Ignore "join" events. - if err := client1.DrainMessages(ctx); err != nil { - t.Error(err) - } - if err := client2.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client1.DrainMessages(ctx)) + assert.NoError(client2.DrainMessages(ctx)) msg := &BackendServerRoomRequest{ Type: "participants", @@ -1041,22 +852,14 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // The request could be sent to any of the backend servers. res, err := performBackendRequest(server1.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) // TODO: Use event to wait for asynchronous messages. time.Sleep(10 * time.Millisecond) @@ -1072,26 +875,22 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) session := hub.GetSessionByPublicId(hello.Hello.SessionId) - if session == nil { - t.Fatalf("Session %s does not exist", hello.Hello.SessionId) - } + assert.NotNil(session, "Session %s does not exist", hello.Hello.SessionId) // Sessions have all permissions initially (fallback for old-style sessions). assertSessionHasPermission(t, session, PERMISSION_MAY_PUBLISH_MEDIA) @@ -1099,18 +898,12 @@ func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) { // Join room by id. roomId := "test-room" - room, err := client.JoinRoom(ctx, roomId) - if err != nil { - t.Fatal(err) - } - if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Ignore "join" events. - if err := client.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.DrainMessages(ctx)) // Updating with empty permissions upgrades to non-old-style and removes // all previously available permissions. @@ -1133,21 +926,13 @@ func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) // TODO: Use event to wait for asynchronous messages. time.Sleep(10 * time.Millisecond) @@ -1159,47 +944,37 @@ func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) { func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) @@ -1236,23 +1011,17 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) }() // Ensure the first request is being processed. @@ -1289,23 +1058,17 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) }() wg.Wait() @@ -1314,38 +1077,26 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { } msg1_a, err := client1.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if in_call_1, err := checkMessageParticipantsInCall(msg1_a); err != nil { - t.Error(err) - } else if len(in_call_1.Users) != 2 { - msg1_b, err := client1.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if in_call_2, err := checkMessageParticipantsInCall(msg1_b); err != nil { - t.Error(err) - } else if len(in_call_2.Users) != 2 { - t.Errorf("Wrong number of users received: %d, expected 2", len(in_call_2.Users)) + assert.NoError(err) + if in_call_1, err := checkMessageParticipantsInCall(msg1_a); assert.NoError(err) { + if len(in_call_1.Users) != 2 { + msg1_b, err := client1.RunUntilMessage(ctx) + assert.NoError(err) + if in_call_2, err := checkMessageParticipantsInCall(msg1_b); assert.NoError(err) { + assert.Len(in_call_2.Users, 2) + } } } msg2_a, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if in_call_1, err := checkMessageParticipantsInCall(msg2_a); err != nil { - t.Error(err) - } else if len(in_call_1.Users) != 2 { - msg2_b, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if in_call_2, err := checkMessageParticipantsInCall(msg2_b); err != nil { - t.Error(err) - } else if len(in_call_2.Users) != 2 { - t.Errorf("Wrong number of users received: %d, expected 2", len(in_call_2.Users)) + assert.NoError(err) + if in_call_1, err := checkMessageParticipantsInCall(msg2_a); assert.NoError(err) { + if len(in_call_1.Users) != 2 { + msg2_b, err := client2.RunUntilMessage(ctx) + assert.NoError(err) + if in_call_2, err := checkMessageParticipantsInCall(msg2_b); assert.NoError(err) { + assert.Len(in_call_2.Users, 2) + } } } @@ -1353,20 +1104,16 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { defer cancel2() if msg1_c, _ := client1.RunUntilMessage(ctx2); msg1_c != nil { - if in_call_2, err := checkMessageParticipantsInCall(msg1_c); err != nil { - t.Error(err) - } else if len(in_call_2.Users) != 2 { - t.Errorf("Wrong number of users received: %d, expected 2", len(in_call_2.Users)) + if in_call_2, err := checkMessageParticipantsInCall(msg1_c); assert.NoError(err) { + assert.Len(in_call_2.Users, 2) } } ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second+100*time.Millisecond) defer cancel3() if msg2_c, _ := client2.RunUntilMessage(ctx3); msg2_c != nil { - if in_call_2, err := checkMessageParticipantsInCall(msg2_c); err != nil { - t.Error(err) - } else if len(in_call_2.Users) != 2 { - t.Errorf("Wrong number of users received: %d, expected 2", len(in_call_2.Users)) + if in_call_2, err := checkMessageParticipantsInCall(msg2_c); assert.NoError(err) { + assert.Len(in_call_2.Users, 2) } } } @@ -1376,6 +1123,8 @@ func TestBackendServer_InCallAll(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -1392,70 +1141,46 @@ func TestBackendServer_InCallAll(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId) - if session1 == nil { - t.Fatalf("Could not find session %s", hello1.Hello.SessionId) - } + require.NotNil(session1, "Could not find session %s", hello1.Hello.SessionId) session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId) - if session2 == nil { - t.Fatalf("Could not find session %s", hello2.Hello.SessionId) - } + require.NotNil(session2, "Could not find session %s", hello2.Hello.SessionId) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) room1 := hub1.getRoom(roomId) - if room1 == nil { - t.Fatalf("Could not find room %s in hub1", roomId) - } + require.NotNil(room1, "Could not find room %s in hub1", roomId) room2 := hub2.getRoom(roomId) - if room2 == nil { - t.Fatalf("Could not find room %s in hub2", roomId) - } + require.NotNil(room2, "Could not find room %s in hub2", roomId) - if room1.IsSessionInCall(session1) { - t.Errorf("Session %s should not be in room %s", session1.PublicId(), room1.Id()) - } - if room2.IsSessionInCall(session2) { - t.Errorf("Session %s should not be in room %s", session2.PublicId(), room2.Id()) - } + assert.False(room1.IsSessionInCall(session1), "Session %s should not be in room %s", session1.PublicId(), room1.Id()) + assert.False(room2.IsSessionInCall(session2), "Session %s should not be in room %s", session2.PublicId(), room2.Id()) var wg sync.WaitGroup wg.Add(1) @@ -1470,23 +1195,17 @@ func TestBackendServer_InCallAll(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } res, err := performBackendRequest(server1.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) }() wg.Wait() @@ -1494,49 +1213,39 @@ func TestBackendServer_InCallAll(t *testing.T) { return } - if msg1_a, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if in_call_1, err := checkMessageParticipantsInCall(msg1_a); err != nil { - t.Error(err) - } else if !in_call_1.All { - t.Errorf("All flag not set in message %+v", in_call_1) - } else if !bytes.Equal(in_call_1.InCall, []byte("7")) { - t.Errorf("Expected inCall flag 7, got %s", string(in_call_1.InCall)) + if msg1_a, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + if in_call_1, err := checkMessageParticipantsInCall(msg1_a); assert.NoError(err) { + assert.True(in_call_1.All, "All flag not set in message %+v", in_call_1) + assert.Equal("7", string(in_call_1.InCall)) + } } - if msg2_a, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if in_call_1, err := checkMessageParticipantsInCall(msg2_a); err != nil { - t.Error(err) - } else if !in_call_1.All { - t.Errorf("All flag not set in message %+v", in_call_1) - } else if !bytes.Equal(in_call_1.InCall, []byte("7")) { - t.Errorf("Expected inCall flag 7, got %s", string(in_call_1.InCall)) + if msg2_a, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + if in_call_1, err := checkMessageParticipantsInCall(msg2_a); assert.NoError(err) { + assert.True(in_call_1.All, "All flag not set in message %+v", in_call_1) + assert.Equal("7", string(in_call_1.InCall)) + } } - if !room1.IsSessionInCall(session1) { - t.Errorf("Session %s should be in room %s", session1.PublicId(), room1.Id()) - } - if !room2.IsSessionInCall(session2) { - t.Errorf("Session %s should be in room %s", session2.PublicId(), room2.Id()) - } + assert.True(room1.IsSessionInCall(session1), "Session %s should be in room %s", session1.PublicId(), room1.Id()) + assert.True(room2.IsSessionInCall(session2), "Session %s should be in room %s", session2.PublicId(), room2.Id()) ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if message, err := client1.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client1.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel3() - if message, err := client2.RunUntilMessage(ctx3); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client2.RunUntilMessage(ctx3); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } wg.Add(1) @@ -1551,23 +1260,17 @@ func TestBackendServer_InCallAll(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } res, err := performBackendRequest(server1.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) }() wg.Wait() @@ -1575,49 +1278,39 @@ func TestBackendServer_InCallAll(t *testing.T) { return } - if msg1_a, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if in_call_1, err := checkMessageParticipantsInCall(msg1_a); err != nil { - t.Error(err) - } else if !in_call_1.All { - t.Errorf("All flag not set in message %+v", in_call_1) - } else if !bytes.Equal(in_call_1.InCall, []byte("0")) { - t.Errorf("Expected inCall flag 0, got %s", string(in_call_1.InCall)) + if msg1_a, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + if in_call_1, err := checkMessageParticipantsInCall(msg1_a); assert.NoError(err) { + assert.True(in_call_1.All, "All flag not set in message %+v", in_call_1) + assert.Equal("0", string(in_call_1.InCall)) + } } - if msg2_a, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if in_call_1, err := checkMessageParticipantsInCall(msg2_a); err != nil { - t.Error(err) - } else if !in_call_1.All { - t.Errorf("All flag not set in message %+v", in_call_1) - } else if !bytes.Equal(in_call_1.InCall, []byte("0")) { - t.Errorf("Expected inCall flag 0, got %s", string(in_call_1.InCall)) + if msg2_a, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + if in_call_1, err := checkMessageParticipantsInCall(msg2_a); assert.NoError(err) { + assert.True(in_call_1.All, "All flag not set in message %+v", in_call_1) + assert.Equal("0", string(in_call_1.InCall)) + } } - if room1.IsSessionInCall(session1) { - t.Errorf("Session %s should not be in room %s", session1.PublicId(), room1.Id()) - } - if room2.IsSessionInCall(session2) { - t.Errorf("Session %s should not be in room %s", session2.PublicId(), room2.Id()) - } + assert.False(room1.IsSessionInCall(session1), "Session %s should not be in room %s", session1.PublicId(), room1.Id()) + assert.False(room2.IsSessionInCall(session2), "Session %s should not be in room %s", session2.PublicId(), room2.Id()) ctx4, cancel4 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel4() - if message, err := client1.RunUntilMessage(ctx4); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client1.RunUntilMessage(ctx4); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } ctx5, cancel5 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel5() - if message, err := client2.RunUntilMessage(ctx5); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client2.RunUntilMessage(ctx5); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } }) } @@ -1626,34 +1319,28 @@ func TestBackendServer_InCallAll(t *testing.T) { func TestBackendServer_RoomMessage(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId + "1")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() _, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Ignore "join" events. - if err := client.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.DrainMessages(ctx)) messageData := json.RawMessage("{\"foo\":\"bar\"}") msg := &BackendServerRoomRequest{ @@ -1664,75 +1351,49 @@ func TestBackendServer_RoomMessage(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) - message, err := client.RunUntilRoomMessage(ctx) - if err != nil { - t.Error(err) - } else if message.RoomId != roomId { - t.Errorf("Expected message for room %s, got %s", roomId, message.RoomId) - } else if !bytes.Equal(messageData, message.Data) { - t.Errorf("Expected message data %s, got %s", string(messageData), string(message.Data)) + if message, err := client.RunUntilRoomMessage(ctx); assert.NoError(err) { + assert.Equal(roomId, message.RoomId) + assert.Equal(string(messageData), string(message.Data)) } } func TestBackendServer_TurnCredentials(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTestWithTurn(t) q := make(url.Values) q.Set("service", "turn") q.Set("api", turnApiKey) request, err := http.NewRequest("GET", server.URL+"/turn/credentials?"+q.Encode(), nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client := &http.Client{} res, err := client.Do(request) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) var cred TurnCredentials - if err := json.Unmarshal(body, &cred); err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal(body, &cred)) m := hmac.New(sha1.New, []byte(turnSecret)) m.Write([]byte(cred.Username)) // nolint password := base64.StdEncoding.EncodeToString(m.Sum(nil)) - if cred.Password != password { - t.Errorf("Expected password %s, got %s", password, cred.Password) - } - if cred.TTL != int64((24 * time.Hour).Seconds()) { - t.Errorf("Expected a TTL of %d, got %d", int64((24 * time.Hour).Seconds()), cred.TTL) - } - if !reflect.DeepEqual(cred.URIs, turnServers) { - t.Errorf("Expected the list of servers as %s, got %s", turnServers, cred.URIs) - } + assert.Equal(password, cred.Password) + assert.EqualValues((24 * time.Hour).Seconds(), cred.TTL) + assert.Equal(turnServers, cred.URIs) } func TestBackendServer_StatsAllowedIps(t *testing.T) { @@ -1758,12 +1419,11 @@ func TestBackendServer_StatsAllowedIps(t *testing.T) { addr := addr t.Run(addr, func(t *testing.T) { t.Parallel() + assert := assert.New(t) r1 := &http.Request{ RemoteAddr: addr, } - if !backend.allowStatsAccess(r1) { - t.Errorf("should allow %s", addr) - } + assert.True(backend.allowStatsAccess(r1), "should allow %s", addr) if host, _, err := net.SplitHostPort(addr); err == nil { addr = host @@ -1775,9 +1435,7 @@ func TestBackendServer_StatsAllowedIps(t *testing.T) { textproto.CanonicalMIMEHeaderKey("x-real-ip"): []string{addr}, }, } - if !backend.allowStatsAccess(r2) { - t.Errorf("should allow %s", addr) - } + assert.True(backend.allowStatsAccess(r2), "should allow %s", addr) r3 := &http.Request{ RemoteAddr: "1.2.3.4:12345", @@ -1785,9 +1443,7 @@ func TestBackendServer_StatsAllowedIps(t *testing.T) { textproto.CanonicalMIMEHeaderKey("x-forwarded-for"): []string{addr}, }, } - if !backend.allowStatsAccess(r3) { - t.Errorf("should allow %s", addr) - } + assert.True(backend.allowStatsAccess(r3), "should allow %s", addr) r4 := &http.Request{ RemoteAddr: "1.2.3.4:12345", @@ -1795,9 +1451,7 @@ func TestBackendServer_StatsAllowedIps(t *testing.T) { textproto.CanonicalMIMEHeaderKey("x-forwarded-for"): []string{addr + ", 1.2.3.4:23456"}, }, } - if !backend.allowStatsAccess(r4) { - t.Errorf("should allow %s", addr) - } + assert.True(backend.allowStatsAccess(r4), "should allow %s", addr) }) } @@ -1808,9 +1462,7 @@ func TestBackendServer_StatsAllowedIps(t *testing.T) { r := &http.Request{ RemoteAddr: addr, } - if backend.allowStatsAccess(r) { - t.Errorf("should not allow %s", addr) - } + assert.False(t, backend.allowStatsAccess(r), "should not allow %s", addr) }) } } @@ -1834,35 +1486,29 @@ func Test_IsNumeric(t *testing.T) { "a1", } for _, s := range numeric { - if !isNumeric(s) { - t.Errorf("%s should be numeric", s) - } + assert.True(t, isNumeric(s), "%s should be numeric", s) } for _, s := range nonNumeric { - if isNumeric(s) { - t.Errorf("%s should not be numeric", s) - } + assert.False(t, isNumeric(s), "%s should not be numeric", s) } } func TestBackendServer_DialoutNoSipBridge(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloInternal(); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloInternal()) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() _, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "12345" msg := &BackendServerRoomRequest{ @@ -1873,56 +1519,40 @@ func TestBackendServer_DialoutNoSipBridge(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusNotFound { - t.Fatalf("Expected error %d, got %s: %s", http.StatusNotFound, res.Status, string(body)) - } + assert.NoError(err) + require.Equal(http.StatusNotFound, res.StatusCode, "Expected error, got %s", string(body)) var response BackendServerRoomResponse - if err := json.Unmarshal(body, &response); err != nil { - t.Fatal(err) - } - - if response.Type != "dialout" || response.Dialout == nil { - t.Fatalf("expected type dialout, got %s", string(body)) - } - if response.Dialout.Error == nil { - t.Fatalf("expected dialout error, got %s", string(body)) - } - if expected := "no_client_available"; response.Dialout.Error.Code != expected { - t.Errorf("expected error code %s, got %s", expected, string(body)) + if assert.NoError(json.Unmarshal(body, &response)) { + assert.Equal("dialout", response.Type) + if assert.NotNil(response.Dialout) && + assert.NotNil(response.Dialout.Error) { + assert.Equal("no_client_available", response.Dialout.Error.Code) + } } } func TestBackendServer_DialoutAccepted(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloInternalWithFeatures([]string{"start-dialout"}); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloInternalWithFeatures([]string{"start-dialout"})) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() _, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "12345" callId := "call-123" @@ -1932,22 +1562,19 @@ func TestBackendServer_DialoutAccepted(t *testing.T) { defer close(stopped) msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } - if msg.Type != "internal" || msg.Internal.Type != "dialout" { - t.Errorf("expected internal dialout message, got %+v", msg) + if !assert.Equal("internal", msg.Type) || + !assert.NotNil(msg.Internal) || + !assert.Equal("dialout", msg.Internal.Type) || + !assert.NotNil(msg.Internal.Dialout) { return } - if msg.Internal.Dialout.RoomId != roomId { - t.Errorf("expected room id %s, got %+v", roomId, msg) - } - if url := server.URL + "/"; msg.Internal.Dialout.Backend != url { - t.Errorf("expected backend %s, got %+v", url, msg) - } + assert.Equal(roomId, msg.Internal.Dialout.RoomId) + assert.Equal(server.URL+"/", msg.Internal.Dialout.Backend) response := &ClientMessage{ Id: msg.Id, @@ -1964,9 +1591,7 @@ func TestBackendServer_DialoutAccepted(t *testing.T) { }, }, } - if err := client.WriteJSON(response); err != nil { - t.Error(err) - } + assert.NoError(client.WriteJSON(response)) }() defer func() { @@ -1981,56 +1606,40 @@ func TestBackendServer_DialoutAccepted(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusOK { - t.Fatalf("Expected error %d, got %s: %s", http.StatusOK, res.Status, string(body)) - } + assert.NoError(err) + require.Equal(http.StatusOK, res.StatusCode, "Expected success, got %s", string(body)) var response BackendServerRoomResponse - if err := json.Unmarshal(body, &response); err != nil { - t.Fatal(err) - } - - if response.Type != "dialout" || response.Dialout == nil { - t.Fatalf("expected type dialout, got %s", string(body)) - } - if response.Dialout.Error != nil { - t.Fatalf("expected dialout success, got %s", string(body)) - } - if response.Dialout.CallId != callId { - t.Errorf("expected call id %s, got %s", callId, string(body)) + if err := json.Unmarshal(body, &response); assert.NoError(err) { + assert.Equal("dialout", response.Type) + if assert.NotNil(response.Dialout) { + assert.Nil(response.Dialout.Error, "expected dialout success, got %s", string(body)) + assert.Equal(callId, response.Dialout.CallId) + } } } func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloInternalWithFeatures([]string{"start-dialout"}); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloInternalWithFeatures([]string{"start-dialout"})) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() _, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "12345" callId := "call-123" @@ -2040,22 +1649,19 @@ func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { defer close(stopped) msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } - if msg.Type != "internal" || msg.Internal.Type != "dialout" { - t.Errorf("expected internal dialout message, got %+v", msg) + if !assert.Equal("internal", msg.Type) || + !assert.NotNil(msg.Internal) || + !assert.Equal("dialout", msg.Internal.Type) || + !assert.NotNil(msg.Internal.Dialout) { return } - if msg.Internal.Dialout.RoomId != roomId { - t.Errorf("expected room id %s, got %+v", roomId, msg) - } - if url := server.URL + "/"; msg.Internal.Dialout.Backend != url { - t.Errorf("expected backend %s, got %+v", url, msg) - } + assert.Equal(roomId, msg.Internal.Dialout.RoomId) + assert.Equal(server.URL+"/", msg.Internal.Dialout.Backend) response := &ClientMessage{ Id: msg.Id, @@ -2072,9 +1678,7 @@ func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { }, }, } - if err := client.WriteJSON(response); err != nil { - t.Error(err) - } + assert.NoError(client.WriteJSON(response)) }() defer func() { @@ -2089,56 +1693,40 @@ func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusOK { - t.Fatalf("Expected error %d, got %s: %s", http.StatusOK, res.Status, string(body)) - } + assert.NoError(err) + require.Equal(http.StatusOK, res.StatusCode, "Expected success, got %s", string(body)) var response BackendServerRoomResponse - if err := json.Unmarshal(body, &response); err != nil { - t.Fatal(err) - } - - if response.Type != "dialout" || response.Dialout == nil { - t.Fatalf("expected type dialout, got %s", string(body)) - } - if response.Dialout.Error != nil { - t.Fatalf("expected dialout success, got %s", string(body)) - } - if response.Dialout.CallId != callId { - t.Errorf("expected call id %s, got %s", callId, string(body)) + if err := json.Unmarshal(body, &response); assert.NoError(err) { + assert.Equal("dialout", response.Type) + if assert.NotNil(response.Dialout) { + assert.Nil(response.Dialout.Error, "expected dialout success, got %s", string(body)) + assert.Equal(callId, response.Dialout.CallId) + } } } func TestBackendServer_DialoutRejected(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloInternalWithFeatures([]string{"start-dialout"}); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloInternalWithFeatures([]string{"start-dialout"})) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() _, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "12345" errorCode := "error-code" @@ -2149,22 +1737,19 @@ func TestBackendServer_DialoutRejected(t *testing.T) { defer close(stopped) msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } - if msg.Type != "internal" || msg.Internal.Type != "dialout" { - t.Errorf("expected internal dialout message, got %+v", msg) + if !assert.Equal("internal", msg.Type) || + !assert.NotNil(msg.Internal) || + !assert.Equal("dialout", msg.Internal.Type) || + !assert.NotNil(msg.Internal.Dialout) { return } - if msg.Internal.Dialout.RoomId != roomId { - t.Errorf("expected room id %s, got %+v", roomId, msg) - } - if url := server.URL + "/"; msg.Internal.Dialout.Backend != url { - t.Errorf("expected backend %s, got %+v", url, msg) - } + assert.Equal(roomId, msg.Internal.Dialout.RoomId) + assert.Equal(server.URL+"/", msg.Internal.Dialout.Backend) response := &ClientMessage{ Id: msg.Id, @@ -2177,9 +1762,7 @@ func TestBackendServer_DialoutRejected(t *testing.T) { }, }, } - if err := client.WriteJSON(response); err != nil { - t.Error(err) - } + assert.NoError(client.WriteJSON(response)) }() defer func() { @@ -2194,37 +1777,21 @@ func TestBackendServer_DialoutRejected(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusBadGateway { - t.Fatalf("Expected error %d, got %s: %s", http.StatusBadGateway, res.Status, string(body)) - } + assert.NoError(err) + require.Equal(http.StatusBadGateway, res.StatusCode, "Expected error, got %s", string(body)) var response BackendServerRoomResponse - if err := json.Unmarshal(body, &response); err != nil { - t.Fatal(err) - } - - if response.Type != "dialout" || response.Dialout == nil { - t.Fatalf("expected type dialout, got %s", string(body)) - } - if response.Dialout.Error == nil { - t.Fatalf("expected dialout error, got %s", string(body)) - } - if response.Dialout.Error.Code != errorCode { - t.Errorf("expected error code %s, got %s", errorCode, string(body)) - } - if response.Dialout.Error.Message != errorMessage { - t.Errorf("expected error message %s, got %s", errorMessage, string(body)) + if json.Unmarshal(body, &response); assert.NoError(err) { + assert.Equal("dialout", response.Type) + if assert.NotNil(response.Dialout) && + assert.NotNil(response.Dialout.Error, "expected dialout error, got %s", string(body)) { + assert.Equal(errorCode, response.Dialout.Error.Code) + assert.Equal(errorMessage, response.Dialout.Error.Message) + } } } diff --git a/backend_storage_etcd_test.go b/backend_storage_etcd_test.go index 6fc02504..d9bd77c8 100644 --- a/backend_storage_etcd_test.go +++ b/backend_storage_etcd_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/dlintw/goconf" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/server/v3/embed" ) @@ -67,9 +68,7 @@ func Test_BackendStorageEtcdNoLeak(t *testing.T) { config.AddOption("backend", "backendprefix", "/backends") cfg, err := NewBackendConfiguration(config, client) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) <-tl.closed cfg.Close() diff --git a/backoff_test.go b/backoff_test.go index 2bbc55e2..87de0489 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -25,17 +25,20 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBackoff_Exponential(t *testing.T) { t.Parallel() - backoff, err := NewExponentialBackoff(100*time.Millisecond, 500*time.Millisecond) - if err != nil { - t.Fatal(err) - } + assert := assert.New(t) + minWait := 100 * time.Millisecond + backoff, err := NewExponentialBackoff(minWait, 500*time.Millisecond) + require.NoError(t, err) waitTimes := []time.Duration{ - 100 * time.Millisecond, + minWait, 200 * time.Millisecond, 400 * time.Millisecond, 500 * time.Millisecond, @@ -43,23 +46,17 @@ func TestBackoff_Exponential(t *testing.T) { } for _, wait := range waitTimes { - if backoff.NextWait() != wait { - t.Errorf("Wait time should be %s, got %s", wait, backoff.NextWait()) - } + assert.Equal(wait, backoff.NextWait()) a := time.Now() backoff.Wait(context.Background()) b := time.Now() - if b.Sub(a) < wait { - t.Errorf("Should have waited %s, got %s", wait, b.Sub(a)) - } + assert.GreaterOrEqual(b.Sub(a), wait) } backoff.Reset() a := time.Now() backoff.Wait(context.Background()) b := time.Now() - if b.Sub(a) < 100*time.Millisecond { - t.Errorf("Should have waited %s, got %s", 100*time.Millisecond, b.Sub(a)) - } + assert.GreaterOrEqual(b.Sub(a), minWait) } diff --git a/capabilities_test.go b/capabilities_test.go index 65aa2f3b..fddefd0a 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -37,17 +37,16 @@ import ( "time" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse, http.ResponseWriter) error) (*url.URL, *Capabilities) { + require := require.New(t) pool, err := NewHttpClientPool(1, false) - if err != nil { - t.Fatal(err) - } + require.NoError(err) capabilities, err := NewCapabilities("0.0", pool) - if err != nil { - t.Fatal(err) - } + require.NoError(err) r := mux.NewRouter() server := httptest.NewServer(r) @@ -56,9 +55,7 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie }) u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) handleCapabilitiesFunc := func(w http.ResponseWriter, r *http.Request) { features := []string{ @@ -91,9 +88,7 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie } data, err := json.Marshal(response) - if err != nil { - t.Errorf("Could not marshal %+v: %s", response, err) - } + assert.NoError(t, err, "Could not marshal %+v", response) var ocs OcsResponse ocs.Ocs = &OcsBody{ @@ -104,9 +99,9 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie }, Data: data, } - if data, err = json.Marshal(ocs); err != nil { - t.Fatal(err) - } + data, err = json.Marshal(ocs) + require.NoError(err) + var cc []string if !strings.Contains(t.Name(), "NoCache") { if strings.Contains(t.Name(), "ShortCache") { @@ -177,70 +172,50 @@ func SetCapabilitiesGetNow(t *testing.T, capabilities *Capabilities, f func() ti func TestCapabilities(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) url, capabilities := NewCapabilitiesForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if !capabilities.HasCapabilityFeature(ctx, url, "foo") { - t.Error("should have capability \"foo\"") - } - if capabilities.HasCapabilityFeature(ctx, url, "lala") { - t.Error("should not have capability \"lala\"") - } + assert.True(capabilities.HasCapabilityFeature(ctx, url, "foo")) + assert.False(capabilities.HasCapabilityFeature(ctx, url, "lala")) expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.True(cached) } - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { - t.Errorf("should not have found value for \"baz\", got %s", value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); assert.False(found, "should not have found value for \"baz\", got %s", value) { + assert.True(cached) } - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { - t.Errorf("should not have found value for \"invalid\", got %s", value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); assert.False(found, "should not have found value for \"invalid\", got %s", value) { + assert.True(cached) } - if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { - t.Errorf("should not have found value for \"baz\", got %s", value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); assert.False(found, "should not have found value for \"baz\", got %s", value) { + assert.True(cached) } expectedInt := 42 - if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { - t.Error("could not find value for \"baz\"") - } else if value != expectedInt { - t.Errorf("expected value %d, got %d", expectedInt, value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); assert.True(found) { + assert.Equal(expectedInt, value) + assert.True(cached) } - if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { - t.Errorf("should not have found value for \"foo\", got %d", value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); assert.False(found, "should not have found value for \"foo\", got %d", value) { + assert.True(cached) } - if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { - t.Errorf("should not have found value for \"invalid\", got %d", value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); assert.False(found, "should not have found value for \"invalid\", got %d", value) { + assert.True(cached) } - if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { - t.Errorf("should not have found value for \"baz\", got %d", value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); assert.False(found, "should not have found value for \"baz\", got %d", value) { + assert.True(cached) } } func TestInvalidateCapabilities(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { called.Add(1) @@ -251,47 +226,35 @@ func TestInvalidateCapabilities(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) // Invalidating will cause the capabilities to be reloaded. capabilities.InvalidateCapabilities(url) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) // Invalidating is throttled to about once per minute. capabilities.InvalidateCapabilities(url) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.True(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) // At a later time, invalidating can be done again. SetCapabilitiesGetNow(t, capabilities, func() time.Time { @@ -300,22 +263,19 @@ func TestInvalidateCapabilities(t *testing.T) { capabilities.InvalidateCapabilities(url) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 3 { - t.Errorf("expected called %d, got %d", 3, value) - } + value = called.Load() + assert.EqualValues(3, value) } func TestCapabilitiesNoCache(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { called.Add(1) @@ -326,51 +286,40 @@ func TestCapabilitiesNoCache(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) // Capabilities are cached for some time if no "Cache-Control" header is set. - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.True(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value = called.Load() + assert.EqualValues(1, value) SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(minCapabilitiesCacheDuration) }) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) } func TestCapabilitiesShortCache(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { called.Add(1) @@ -381,76 +330,58 @@ func TestCapabilitiesShortCache(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) // Capabilities are cached for some time if no "Cache-Control" header is set. - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.True(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value = called.Load() + assert.EqualValues(1, value) // The capabilities are cached for a minumum duration. SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(minCapabilitiesCacheDuration / 2) }) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if !cached { - t.Errorf("expected cached response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.True(cached) } SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(minCapabilitiesCacheDuration) }) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) } func TestCapabilitiesNoCacheETag(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { ct := w.Header().Get("Content-Type") switch called.Add(1) { case 1: - if ct == "" { - t.Error("expected content-type on first request") - } + assert.NotEmpty(ct, "expected content-type on first request") case 2: - if ct != "" { - t.Errorf("expected no content-type on second request, got %s", ct) - } + assert.Empty(ct, "expected no content-type on second request") } return nil }) @@ -459,38 +390,31 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(minCapabilitiesCacheDuration) }) - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) } func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { if called.Add(1) == 2 { @@ -504,17 +428,13 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(time.Minute) @@ -522,22 +442,19 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { // Expired capabilities can still be used even in case of update errors if // "must-revalidate" is not set. - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) } func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { if called.Add(1) == 2 { @@ -551,17 +468,13 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(minCapabilitiesCacheDuration) @@ -569,22 +482,19 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { // Expired capabilities can still be used even in case of update errors if // "must-revalidate" is not set. - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) } func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { if called.Add(1) == 2 { @@ -598,17 +508,13 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { defer cancel() expectedString := "bar" - if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { - t.Error("could not find value for \"foo\"") - } else if value != expectedString { - t.Errorf("expected value %s, got %s", expectedString, value) - } else if cached { - t.Errorf("expected direct response") + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) } - if value := called.Load(); value != 1 { - t.Errorf("expected called %d, got %d", 1, value) - } + value := called.Load() + assert.EqualValues(1, value) SetCapabilitiesGetNow(t, capabilities, func() time.Time { return time.Now().Add(minCapabilitiesCacheDuration) @@ -616,11 +522,9 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { // Capabilities will be cleared if "must-revalidate" is set and an error // occurs while fetching the updated data. - if value, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); found { - t.Errorf("should not have found value for \"foo\", got %s", value) - } + capaValue, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo") + assert.False(found, "should not have found value for \"foo\", got %s", capaValue) - if value := called.Load(); value != 2 { - t.Errorf("expected called %d, got %d", 2, value) - } + value = called.Load() + assert.EqualValues(2, value) } diff --git a/channel_waiter_test.go b/channel_waiter_test.go index e401ae87..91416421 100644 --- a/channel_waiter_test.go +++ b/channel_waiter_test.go @@ -23,6 +23,8 @@ package signaling import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestChannelWaiters(t *testing.T) { @@ -42,9 +44,9 @@ func TestChannelWaiters(t *testing.T) { select { case <-ch1: - t.Error("should have not received another event") + assert.Fail(t, "should have not received another event") case <-ch2: - t.Error("should have not received another event") + assert.Fail(t, "should have not received another event") default: } @@ -60,7 +62,7 @@ func TestChannelWaiters(t *testing.T) { <-ch2 select { case <-ch3: - t.Error("should have not received another event") + assert.Fail(t, "should have not received another event") default: } } diff --git a/clientsession_test.go b/clientsession_test.go index 6d3b9a42..5840c040 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -26,6 +26,9 @@ import ( "net/url" "strconv" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -119,9 +122,7 @@ func Test_permissionsEqual(t *testing.T) { t.Run(strconv.Itoa(idx), func(t *testing.T) { t.Parallel() equal := permissionsEqual(test.a, test.b) - if equal != test.equal { - t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal]) - } + assert.Equal(t, test.equal, equal, "Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal]) }) } } @@ -129,17 +130,16 @@ func Test_permissionsEqual(t *testing.T) { func TestBandwidth_Client(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -147,31 +147,23 @@ func TestBandwidth_Client(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Client may not send an offer with audio and video. bitrate := 10000 - if err := client.SendMessage(MessageClientMessageRecipient{ + require.NoError(client.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello.Hello.SessionId, }, MessageClientMessageData{ @@ -182,22 +174,13 @@ func TestBandwidth_Client(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) pub := mcu.GetPublisher(hello.Hello.SessionId) - if pub == nil { - t.Fatal("Could not find publisher") - } - - if pub.bitrate != bitrate { - t.Errorf("Expected bitrate %d, got %d", bitrate, pub.bitrate) - } + require.NotNil(pub) + assert.Equal(bitrate, pub.bitrate) } func TestBandwidth_Backend(t *testing.T) { @@ -206,13 +189,9 @@ func TestBandwidth_Backend(t *testing.T) { hub, _, _, server := CreateHubWithMultipleBackendsForTest(t) u, err := url.Parse(server.URL + "/one") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) backend := hub.backend.GetBackend(u) - if backend == nil { - t.Fatal("Could not get backend") - } + require.NotNil(t, backend, "Could not get backend") backend.maxScreenBitrate = 1000 backend.maxStreamBitrate = 2000 @@ -221,11 +200,8 @@ func TestBandwidth_Backend(t *testing.T) { defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.NoError(t, mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -237,37 +213,31 @@ func TestBandwidth_Backend(t *testing.T) { for _, streamType := range streamTypes { t.Run(string(streamType), func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() params := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params)) hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + require.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Client may not send an offer with audio and video. bitrate := 10000 - if err := client.SendMessage(MessageClientMessageRecipient{ + require.NoError(client.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello.Hello.SessionId, }, MessageClientMessageData{ @@ -278,18 +248,12 @@ func TestBandwidth_Backend(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) pub := mcu.GetPublisher(hello.Hello.SessionId) - if pub == nil { - t.Fatal("Could not find publisher") - } + require.NotNil(pub, "Could not find publisher") var expectBitrate int if streamType == StreamTypeVideo { @@ -297,9 +261,7 @@ func TestBandwidth_Backend(t *testing.T) { } else { expectBitrate = backend.maxScreenBitrate } - if pub.bitrate != expectBitrate { - t.Errorf("Expected bitrate %d, got %d", expectBitrate, pub.bitrate) - } + assert.Equal(expectBitrate, pub.bitrate) }) } } diff --git a/closer_test.go b/closer_test.go index 1a14ea15..ab23621a 100644 --- a/closer_test.go +++ b/closer_test.go @@ -24,6 +24,8 @@ package signaling import ( "sync" "testing" + + "github.com/stretchr/testify/assert" ) func TestCloserMulti(t *testing.T) { @@ -39,24 +41,16 @@ func TestCloserMulti(t *testing.T) { }() } - if closer.IsClosed() { - t.Error("should not be closed") - } + assert.False(t, closer.IsClosed()) closer.Close() - if !closer.IsClosed() { - t.Error("should be closed") - } + assert.True(t, closer.IsClosed()) wg.Wait() } func TestCloserCloseBeforeWait(t *testing.T) { closer := NewCloser() closer.Close() - if !closer.IsClosed() { - t.Error("should be closed") - } + assert.True(t, closer.IsClosed()) <-closer.C - if !closer.IsClosed() { - t.Error("should be closed") - } + assert.True(t, closer.IsClosed()) } diff --git a/concurrentmap_test.go b/concurrentmap_test.go index 0ae9a3f7..cca1d29c 100644 --- a/concurrentmap_test.go +++ b/concurrentmap_test.go @@ -25,75 +25,52 @@ import ( "strconv" "sync" "testing" + + "github.com/stretchr/testify/assert" ) func TestConcurrentStringStringMap(t *testing.T) { + assert := assert.New(t) var m ConcurrentStringStringMap - if m.Len() != 0 { - t.Errorf("Expected %d entries, got %d", 0, m.Len()) - } - if v, found := m.Get("foo"); found { - t.Errorf("Expected missing entry, got %s", v) - } + assert.Equal(0, m.Len()) + v, found := m.Get("foo") + assert.False(found, "Expected missing entry, got %s", v) m.Set("foo", "bar") - if m.Len() != 1 { - t.Errorf("Expected %d entries, got %d", 1, m.Len()) - } - if v, found := m.Get("foo"); !found { - t.Errorf("Expected entry") - } else if v != "bar" { - t.Errorf("Expected bar, got %s", v) + assert.Equal(1, m.Len()) + if v, found := m.Get("foo"); assert.True(found) { + assert.Equal("bar", v) } m.Set("foo", "baz") - if m.Len() != 1 { - t.Errorf("Expected %d entries, got %d", 1, m.Len()) - } - if v, found := m.Get("foo"); !found { - t.Errorf("Expected entry") - } else if v != "baz" { - t.Errorf("Expected baz, got %s", v) + assert.Equal(1, m.Len()) + if v, found := m.Get("foo"); assert.True(found) { + assert.Equal("baz", v) } m.Set("lala", "lolo") - if m.Len() != 2 { - t.Errorf("Expected %d entries, got %d", 2, m.Len()) - } - if v, found := m.Get("lala"); !found { - t.Errorf("Expected entry") - } else if v != "lolo" { - t.Errorf("Expected lolo, got %s", v) + assert.Equal(2, m.Len()) + if v, found := m.Get("lala"); assert.True(found) { + assert.Equal("lolo", v) } // Deleting missing entries doesn't do anything. m.Del("xyz") - if m.Len() != 2 { - t.Errorf("Expected %d entries, got %d", 2, m.Len()) + assert.Equal(2, m.Len()) + if v, found := m.Get("foo"); assert.True(found) { + assert.Equal("baz", v) } - if v, found := m.Get("foo"); !found { - t.Errorf("Expected entry") - } else if v != "baz" { - t.Errorf("Expected baz, got %s", v) - } - if v, found := m.Get("lala"); !found { - t.Errorf("Expected entry") - } else if v != "lolo" { - t.Errorf("Expected lolo, got %s", v) + if v, found := m.Get("lala"); assert.True(found) { + assert.Equal("lolo", v) } m.Del("lala") - if m.Len() != 1 { - t.Errorf("Expected %d entries, got %d", 2, m.Len()) - } - if v, found := m.Get("foo"); !found { - t.Errorf("Expected entry") - } else if v != "baz" { - t.Errorf("Expected baz, got %s", v) - } - if v, found := m.Get("lala"); found { - t.Errorf("Expected missing entry, got %s", v) + assert.Equal(1, m.Len()) + if v, found := m.Get("foo"); assert.True(found) { + assert.Equal("baz", v) } + v, found = m.Get("lala") + assert.False(found, "Expected missing entry, got %s", v) m.Clear() var wg sync.WaitGroup @@ -108,18 +85,13 @@ func TestConcurrentStringStringMap(t *testing.T) { for y := 0; y < count; y = y + 1 { value := newRandomString(32) m.Set(key, value) - if v, found := m.Get(key); !found { - t.Errorf("Expected entry for key %s", key) - return - } else if v != value { - t.Errorf("Expected value %s for key %s, got %s", value, key, v) + if v, found := m.Get(key); !assert.True(found, "Expected entry for key %s", key) || + !assert.Equal(value, v, "Unexpected value for key %s", key) { return } } }(x) } wg.Wait() - if m.Len() != concurrency { - t.Errorf("Expected %d entries, got %d", concurrency, m.Len()) - } + assert.Equal(concurrency, m.Len()) } diff --git a/config_test.go b/config_test.go index 1408020b..f0cc6197 100644 --- a/config_test.go +++ b/config_test.go @@ -22,10 +22,11 @@ package signaling import ( - "reflect" "testing" "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStringOptions(t *testing.T) { @@ -46,13 +47,8 @@ func TestStringOptions(t *testing.T) { config.AddOption("default", "three", "3") options, err := GetStringOptions(config, "foo", false) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(expected, options) { - t.Errorf("expected %+v, got %+v", expected, options) - } + require.NoError(t, err) + assert.Equal(t, expected, options) } func TestStringOptionWithEnv(t *testing.T) { @@ -82,10 +78,8 @@ func TestStringOptionWithEnv(t *testing.T) { } for k, v := range expected { value, err := GetStringOptionWithEnv(config, "test", k) - if err != nil { - t.Errorf("expected value for %s, got %s", k, err) - } else if value != v { - t.Errorf("expected value %s for %s, got %s", v, k, value) + if assert.NoError(t, err, "expected value for %s", k) { + assert.Equal(t, v, value, "unexpected value for %s", k) } } diff --git a/deferred_executor_test.go b/deferred_executor_test.go index 3a3659a4..8fa96ce4 100644 --- a/deferred_executor_test.go +++ b/deferred_executor_test.go @@ -24,6 +24,8 @@ package signaling import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestDeferredExecutor_MultiClose(t *testing.T) { @@ -53,9 +55,7 @@ func TestDeferredExecutor_QueueSize(t *testing.T) { b := time.Now() delta := b.Sub(a) // Allow one millisecond less delay to account for time variance on CI runners. - if delta+time.Millisecond < delay { - t.Errorf("Expected a delay of %s, got %s", delay, delta) - } + assert.GreaterOrEqual(t, delta+time.Millisecond, delay) } func TestDeferredExecutor_Order(t *testing.T) { @@ -81,9 +81,7 @@ func TestDeferredExecutor_Order(t *testing.T) { <-done for x := 0; x < 10; x++ { - if entries[x] != x { - t.Errorf("Expected %d at position %d, got %d", x, x, entries[x]) - } + assert.Equal(t, entries[x], x, "Unexpected at position %d", x) } } @@ -108,7 +106,7 @@ func TestDeferredExecutor_DeferAfterClose(t *testing.T) { e.Close() e.Execute(func() { - t.Error("method should not have been called") + assert.Fail(t, "method should not have been called") }) } diff --git a/dns_monitor_test.go b/dns_monitor_test.go index 6f0964b6..ee107722 100644 --- a/dns_monitor_test.go +++ b/dns_monitor_test.go @@ -30,6 +30,9 @@ import ( "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type mockDnsLookup struct { @@ -82,20 +85,16 @@ func (m *mockDnsLookup) lookup(host string) ([]net.IP, error) { func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor { t.Helper() + require := require.New(t) monitor, err := NewDnsMonitor(interval) - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { monitor.Stop() }) - if err := monitor.Start(); err != nil { - t.Fatal(err) - } - + require.NoError(monitor.Start()) return monitor } @@ -148,20 +147,18 @@ func (r *dnsMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all, add, keep, re expected := r.expected r.expected = nil if expected == expectNone { - r.t.Errorf("expected no event, got %v", received) + assert.Fail(r.t, "expected no event, got %v", received) return } if expected == nil { if r.received != nil && !r.received.Equal(received) { - r.t.Errorf("already received %v, got %v", r.received, received) + assert.Fail(r.t, "already received %v, got %v", r.received, received) } return } - if !expected.Equal(received) { - r.t.Errorf("expected %v, got %v", expected, received) - } + assert.True(r.t, expected.Equal(received), "expected %v, got %v", expected, received) r.received = nil r.expected = nil } @@ -178,7 +175,7 @@ func (r *dnsMonitorReceiver) WaitForExpected(ctx context.Context) { select { case <-ticker.C: case <-ctx.Done(): - r.t.Error(ctx.Err()) + assert.NoError(r.t, ctx.Err()) abort = true } r.Lock() @@ -191,7 +188,7 @@ func (r *dnsMonitorReceiver) Expect(all, add, keep, remove []net.IP) { defer r.Unlock() if r.expected != nil && r.expected != expectNone { - r.t.Errorf("didn't get previously expected %v", r.expected) + assert.Fail(r.t, "didn't get previously expected %v", r.expected) } expected := &dnsMonitorReceiverRecord{ @@ -214,7 +211,7 @@ func (r *dnsMonitorReceiver) ExpectNone() { defer r.Unlock() if r.expected != nil && r.expected != expectNone { - r.t.Errorf("didn't get previously expected %v", r.expected) + assert.Fail(r.t, "didn't get previously expected %v", r.expected) } r.expected = expectNone @@ -241,9 +238,7 @@ func TestDnsMonitor(t *testing.T) { rec1.Expect(ips1, ips1, nil, nil) entry1, err := monitor.Add("https://foo:12345", rec1.OnLookup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer monitor.Remove(entry1) rec1.WaitForExpected(ctx) @@ -306,9 +301,7 @@ func TestDnsMonitorIP(t *testing.T) { rec1.Expect(ips, ips, nil, nil) entry, err := monitor.Add(ip+":12345", rec1.OnLookup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer monitor.Remove(entry) rec1.WaitForExpected(ctx) @@ -328,9 +321,7 @@ func TestDnsMonitorNoLookupIfEmpty(t *testing.T) { } time.Sleep(10 * interval) - if checked.Load() { - t.Error("should not have checked hostnames") - } + assert.False(t, checked.Load(), "should not have checked hostnames") } type deadlockMonitorReceiver struct { @@ -355,8 +346,7 @@ func newDeadlockMonitorReceiver(t *testing.T, monitor *DnsMonitor) *deadlockMoni } func (r *deadlockMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) { - if r.closed.Load() { - r.t.Error("received lookup after closed") + if !assert.False(r.t, r.closed.Load(), "received lookup after closed") { return } @@ -385,8 +375,7 @@ func (r *deadlockMonitorReceiver) Start() { defer r.mu.Unlock() entry, err := r.monitor.Add("foo", r.OnLookup) - if err != nil { - r.t.Errorf("error adding listener: %s", err) + if !assert.NoError(r.t, err) { return } @@ -422,7 +411,5 @@ func TestDnsMonitorDeadlock(t *testing.T) { time.Sleep(10 * interval) monitor.mu.Lock() defer monitor.mu.Unlock() - if len(monitor.hostnames) > 0 { - t.Errorf("should have cleared hostnames, got %+v", monitor.hostnames) - } + assert.Empty(t, monitor.hostnames) } diff --git a/etcd_client_test.go b/etcd_client_test.go index 6de0edcd..7986e2ec 100644 --- a/etcd_client_test.go +++ b/etcd_client_test.go @@ -34,6 +34,8 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/api/v3/mvccpb" clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/server/v3/embed" @@ -66,15 +68,14 @@ func isErrorAddressAlreadyInUse(err error) bool { } func NewEtcdForTest(t *testing.T) *embed.Etcd { + require := require.New(t) cfg := embed.NewConfig() cfg.Dir = t.TempDir() os.Chmod(cfg.Dir, 0700) // nolint cfg.LogLevel = "warn" u, err := url.Parse(etcdListenUrl) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Find a free port to bind the server to. var etcd *embed.Etcd @@ -94,14 +95,12 @@ func NewEtcdForTest(t *testing.T) *embed.Etcd { etcd, err = embed.StartEtcd(cfg) if isErrorAddressAlreadyInUse(err) { continue - } else if err != nil { - t.Fatal(err) } + + require.NoError(err) break } - if etcd == nil { - t.Fatal("could not find free port") - } + require.NotNil(etcd, "could not find free port") t.Cleanup(func() { etcd.Close() @@ -121,13 +120,9 @@ func NewEtcdClientForTest(t *testing.T) (*embed.Etcd, *EtcdClient) { config.AddOption("etcd", "loglevel", "error") client, err := NewEtcdClient(config, "") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { - if err := client.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, client.Close()) }) return etcd, client } @@ -149,54 +144,44 @@ func DeleteEtcdValue(etcd *embed.Etcd, key string) { func Test_EtcdClient_Get(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) - if response, err := client.Get(context.Background(), "foo"); err != nil { - t.Error(err) - } else if response.Count != 0 { - t.Errorf("expected 0 response, got %d", response.Count) + if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) { + assert.EqualValues(0, response.Count) } SetEtcdValue(etcd, "foo", []byte("bar")) - if response, err := client.Get(context.Background(), "foo"); err != nil { - t.Error(err) - } else if response.Count != 1 { - t.Errorf("expected 1 responses, got %d", response.Count) - } else if string(response.Kvs[0].Key) != "foo" { - t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key)) - } else if string(response.Kvs[0].Value) != "bar" { - t.Errorf("expected value \"bar\", got \"%s\"", string(response.Kvs[0].Value)) + if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) { + if assert.EqualValues(1, response.Count) { + assert.Equal("foo", string(response.Kvs[0].Key)) + assert.Equal("bar", string(response.Kvs[0].Value)) + } } } func Test_EtcdClient_GetPrefix(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) - if response, err := client.Get(context.Background(), "foo"); err != nil { - t.Error(err) - } else if response.Count != 0 { - t.Errorf("expected 0 response, got %d", response.Count) + if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) { + assert.EqualValues(0, response.Count) } SetEtcdValue(etcd, "foo", []byte("1")) SetEtcdValue(etcd, "foo/lala", []byte("2")) SetEtcdValue(etcd, "lala/foo", []byte("3")) - if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); err != nil { - t.Error(err) - } else if response.Count != 2 { - t.Errorf("expected 2 responses, got %d", response.Count) - } else if string(response.Kvs[0].Key) != "foo" { - t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key)) - } else if string(response.Kvs[0].Value) != "1" { - t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value)) - } else if string(response.Kvs[1].Key) != "foo/lala" { - t.Errorf("expected key \"foo/lala\", got \"%s\"", string(response.Kvs[1].Key)) - } else if string(response.Kvs[1].Value) != "2" { - t.Errorf("expected value \"2\", got \"%s\"", string(response.Kvs[1].Value)) + if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); assert.NoError(err) { + if assert.EqualValues(2, response.Count) { + assert.Equal("foo", string(response.Kvs[0].Key)) + assert.Equal("1", string(response.Kvs[0].Value)) + assert.Equal("foo/lala", string(response.Kvs[1].Key)) + assert.Equal("2", string(response.Kvs[1].Value)) + } } } @@ -237,8 +222,8 @@ func (l *EtcdClientTestListener) Close() { func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { go func() { - if err := client.WaitForConnection(l.ctx); err != nil { - l.t.Errorf("error waiting for connection: %s", err) + assert := assert.New(l.t) + if err := client.WaitForConnection(l.ctx); !assert.NoError(err) { return } @@ -246,23 +231,17 @@ func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { defer cancel() response, err := client.Get(ctx, "foo", clientv3.WithPrefix()) - if err != nil { - l.t.Error(err) - } else if response.Count != 1 { - l.t.Errorf("expected 1 responses, got %d", response.Count) - } else if string(response.Kvs[0].Key) != "foo/a" { - l.t.Errorf("expected key \"foo/a\", got \"%s\"", string(response.Kvs[0].Key)) - } else if string(response.Kvs[0].Value) != "1" { - l.t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value)) + if assert.NoError(err) && assert.EqualValues(1, response.Count) { + assert.Equal("foo/a", string(response.Kvs[0].Key)) + assert.Equal("1", string(response.Kvs[0].Value)) } close(l.initial) nextRevision := response.Header.Revision + 1 for l.ctx.Err() == nil { var err error - if nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix()); err != nil { - l.t.Error(err) - } + nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix()) + assert.NoError(err) } }() } @@ -296,6 +275,7 @@ func (l *EtcdClientTestListener) EtcdKeyDeleted(client *EtcdClient, key string, func Test_EtcdClient_Watch(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) SetEtcdValue(etcd, "foo/a", []byte("1")) @@ -310,31 +290,19 @@ func Test_EtcdClient_Watch(t *testing.T) { SetEtcdValue(etcd, "foo/b", []byte("2")) event := <-listener.events - if event.t != clientv3.EventTypePut { - t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t) - } else if event.key != "foo/b" { - t.Errorf("expected key %s, got %s", "foo/b", event.key) - } else if event.value != "2" { - t.Errorf("expected value %s, got %s", "2", event.value) - } + assert.Equal(clientv3.EventTypePut, event.t) + assert.Equal("foo/b", event.key) + assert.Equal("2", event.value) SetEtcdValue(etcd, "foo/a", []byte("3")) event = <-listener.events - if event.t != clientv3.EventTypePut { - t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t) - } else if event.key != "foo/a" { - t.Errorf("expected key %s, got %s", "foo/a", event.key) - } else if event.value != "3" { - t.Errorf("expected value %s, got %s", "3", event.value) - } + assert.Equal(clientv3.EventTypePut, event.t) + assert.Equal("foo/a", event.key) + assert.Equal("3", event.value) DeleteEtcdValue(etcd, "foo/a") event = <-listener.events - if event.t != clientv3.EventTypeDelete { - t.Errorf("expected type %d, got %d", clientv3.EventTypeDelete, event.t) - } else if event.key != "foo/a" { - t.Errorf("expected key %s, got %s", "foo/a", event.key) - } else if event.prevValue != "3" { - t.Errorf("expected previous value %s, got %s", "3", event.prevValue) - } + assert.Equal(clientv3.EventTypeDelete, event.t) + assert.Equal("foo/a", event.key) + assert.Equal("3", event.prevValue) } diff --git a/federation_test.go b/federation_test.go index f6724570..ea4a5044 100644 --- a/federation_test.go +++ b/federation_test.go @@ -330,10 +330,10 @@ func Test_Federation(t *testing.T) { ctx2, cancel2 := context.WithTimeout(ctx, 200*time.Millisecond) defer cancel2() - if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else { - assert.Nil(message) + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } } @@ -345,10 +345,10 @@ func Test_Federation(t *testing.T) { ctx2, cancel2 := context.WithTimeout(ctx, 200*time.Millisecond) defer cancel2() - if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else { - assert.Nil(message) + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } } diff --git a/file_watcher_test.go b/file_watcher_test.go index 175844f7..14dec157 100644 --- a/file_watcher_test.go +++ b/file_watcher_test.go @@ -23,10 +23,12 @@ package signaling import ( "context" - "errors" "os" "path" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -34,38 +36,31 @@ var ( ) func TestFileWatcher_NotExist(t *testing.T) { + assert := assert.New(t) tmpdir := t.TempDir() - w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {}) - if err == nil { - t.Error("should not be able to watch non-existing files") - if err := w.Close(); err != nil { - t.Error(err) + if w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {}); !assert.ErrorIs(err, os.ErrNotExist) { + if w != nil { + assert.NoError(w.Close()) } - } else if !errors.Is(err, os.ErrNotExist) { - t.Error(err) } } func TestFileWatcher_File(t *testing.T) { ensureNoGoroutinesLeak(t, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) tmpdir := t.TempDir() filename := path.Join(tmpdir, "test.txt") - if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644)) modified := make(chan struct{}) w, err := NewFileWatcher(filename, func(filename string) { modified <- struct{}{} }) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer w.Close() - if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(filename, []byte("Updated"), 0644)) <-modified ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) @@ -73,13 +68,11 @@ func TestFileWatcher_File(t *testing.T) { select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } - if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(filename, []byte("Updated"), 0644)) <-modified ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout) @@ -87,45 +80,39 @@ func TestFileWatcher_File(t *testing.T) { select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } }) } func TestFileWatcher_Rename(t *testing.T) { + require := require.New(t) + assert := assert.New(t) tmpdir := t.TempDir() filename := path.Join(tmpdir, "test.txt") - if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644)) modified := make(chan struct{}) w, err := NewFileWatcher(filename, func(filename string) { modified <- struct{}{} }) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer w.Close() filename2 := path.Join(tmpdir, "test.txt.tmp") - if err := os.WriteFile(filename2, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(filename2, []byte("Updated"), 0644)) ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) defer cancel() select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } - if err := os.Rename(filename2, filename); err != nil { - t.Fatal(err) - } + require.NoError(os.Rename(filename2, filename)) <-modified ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout) @@ -133,35 +120,29 @@ func TestFileWatcher_Rename(t *testing.T) { select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } } func TestFileWatcher_Symlink(t *testing.T) { + require := require.New(t) + assert := assert.New(t) tmpdir := t.TempDir() sourceFilename := path.Join(tmpdir, "test1.txt") - if err := os.WriteFile(sourceFilename, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename, []byte("Hello world!"), 0644)) filename := path.Join(tmpdir, "symlink.txt") - if err := os.Symlink(sourceFilename, filename); err != nil { - t.Fatal(err) - } + require.NoError(os.Symlink(sourceFilename, filename)) modified := make(chan struct{}) w, err := NewFileWatcher(filename, func(filename string) { modified <- struct{}{} }) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer w.Close() - if err := os.WriteFile(sourceFilename, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename, []byte("Updated"), 0644)) <-modified ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) @@ -169,44 +150,34 @@ func TestFileWatcher_Symlink(t *testing.T) { select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } } func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) { + require := require.New(t) + assert := assert.New(t) tmpdir := t.TempDir() sourceFilename1 := path.Join(tmpdir, "test1.txt") - if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644)) sourceFilename2 := path.Join(tmpdir, "test2.txt") - if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644)) filename := path.Join(tmpdir, "symlink.txt") - if err := os.Symlink(sourceFilename1, filename); err != nil { - t.Fatal(err) - } + require.NoError(os.Symlink(sourceFilename1, filename)) modified := make(chan struct{}) w, err := NewFileWatcher(filename, func(filename string) { modified <- struct{}{} }) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer w.Close() // Replace symlink by creating new one and rename it to the original target. - if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil { - t.Fatal(err) - } - if err := os.Rename(filename+".tmp", filename); err != nil { - t.Fatal(err) - } + require.NoError(os.Symlink(sourceFilename2, filename+".tmp")) + require.NoError(os.Rename(filename+".tmp", filename)) <-modified ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) @@ -214,89 +185,73 @@ func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) { select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } } func TestFileWatcher_OtherSymlink(t *testing.T) { + require := require.New(t) + assert := assert.New(t) tmpdir := t.TempDir() sourceFilename1 := path.Join(tmpdir, "test1.txt") - if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644)) sourceFilename2 := path.Join(tmpdir, "test2.txt") - if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644)) filename := path.Join(tmpdir, "symlink.txt") - if err := os.Symlink(sourceFilename1, filename); err != nil { - t.Fatal(err) - } + require.NoError(os.Symlink(sourceFilename1, filename)) modified := make(chan struct{}) w, err := NewFileWatcher(filename, func(filename string) { modified <- struct{}{} }) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer w.Close() - if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil { - t.Fatal(err) - } + require.NoError(os.Symlink(sourceFilename2, filename+".tmp")) ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) defer cancel() select { case <-modified: - t.Error("should not have received event for other symlink") + assert.Fail("should not have received event for other symlink") case <-ctxTimeout.Done(): } } func TestFileWatcher_RenameSymlinkTarget(t *testing.T) { + require := require.New(t) + assert := assert.New(t) tmpdir := t.TempDir() sourceFilename1 := path.Join(tmpdir, "test1.txt") - if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644)) filename := path.Join(tmpdir, "test.txt") - if err := os.Symlink(sourceFilename1, filename); err != nil { - t.Fatal(err) - } + require.NoError(os.Symlink(sourceFilename1, filename)) modified := make(chan struct{}) w, err := NewFileWatcher(filename, func(filename string) { modified <- struct{}{} }) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer w.Close() sourceFilename2 := path.Join(tmpdir, "test1.txt.tmp") - if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } + require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644)) ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) defer cancel() select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } - if err := os.Rename(sourceFilename2, sourceFilename1); err != nil { - t.Fatal(err) - } + require.NoError(os.Rename(sourceFilename2, sourceFilename1)) <-modified ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout) @@ -304,7 +259,7 @@ func TestFileWatcher_RenameSymlinkTarget(t *testing.T) { select { case <-modified: - t.Error("should not have received another event") + assert.Fail("should not have received another event") case <-ctxTimeout.Done(): } } diff --git a/flags_test.go b/flags_test.go index 9a00ee9b..16651675 100644 --- a/flags_test.go +++ b/flags_test.go @@ -25,55 +25,28 @@ import ( "sync" "sync/atomic" "testing" + + "github.com/stretchr/testify/assert" ) func TestFlags(t *testing.T) { + assert := assert.New(t) var f Flags - if f.Get() != 0 { - t.Fatalf("Expected flags 0, got %d", f.Get()) - } - if !f.Add(1) { - t.Error("expected true") - } - if f.Get() != 1 { - t.Fatalf("Expected flags 1, got %d", f.Get()) - } - if f.Add(1) { - t.Error("expected false") - } - if f.Get() != 1 { - t.Fatalf("Expected flags 1, got %d", f.Get()) - } - if !f.Add(2) { - t.Error("expected true") - } - if f.Get() != 3 { - t.Fatalf("Expected flags 3, got %d", f.Get()) - } - if !f.Remove(1) { - t.Error("expected true") - } - if f.Get() != 2 { - t.Fatalf("Expected flags 2, got %d", f.Get()) - } - if f.Remove(1) { - t.Error("expected false") - } - if f.Get() != 2 { - t.Fatalf("Expected flags 2, got %d", f.Get()) - } - if !f.Add(3) { - t.Error("expected true") - } - if f.Get() != 3 { - t.Fatalf("Expected flags 3, got %d", f.Get()) - } - if !f.Remove(1) { - t.Error("expected true") - } - if f.Get() != 2 { - t.Fatalf("Expected flags 2, got %d", f.Get()) - } + assert.EqualValues(0, f.Get()) + assert.True(f.Add(1)) + assert.EqualValues(1, f.Get()) + assert.False(f.Add(1)) + assert.EqualValues(1, f.Get()) + assert.True(f.Add(2)) + assert.EqualValues(3, f.Get()) + assert.True(f.Remove(1)) + assert.EqualValues(2, f.Get()) + assert.False(f.Remove(1)) + assert.EqualValues(2, f.Get()) + assert.True(f.Add(3)) + assert.EqualValues(3, f.Get()) + assert.True(f.Remove(1)) + assert.EqualValues(2, f.Get()) } func runConcurrentFlags(t *testing.T, count int, f func()) { @@ -106,9 +79,7 @@ func TestFlagsConcurrentAdd(t *testing.T) { added.Add(1) } }) - if added.Load() != 1 { - t.Errorf("expected only one successfull attempt, got %d", added.Load()) - } + assert.EqualValues(t, 1, added.Load(), "expected only one successfull attempt") } func TestFlagsConcurrentRemove(t *testing.T) { @@ -122,9 +93,7 @@ func TestFlagsConcurrentRemove(t *testing.T) { removed.Add(1) } }) - if removed.Load() != 1 { - t.Errorf("expected only one successfull attempt, got %d", removed.Load()) - } + assert.EqualValues(t, 1, removed.Load(), "expected only one successfull attempt") } func TestFlagsConcurrentSet(t *testing.T) { @@ -137,7 +106,5 @@ func TestFlagsConcurrentSet(t *testing.T) { set.Add(1) } }) - if set.Load() != 1 { - t.Errorf("expected only one successfull attempt, got %d", set.Load()) - } + assert.EqualValues(t, 1, set.Load(), "expected only one successfull attempt") } diff --git a/geoip_test.go b/geoip_test.go index bfcbb329..4d1a1e14 100644 --- a/geoip_test.go +++ b/geoip_test.go @@ -32,6 +32,9 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func testGeoLookupReader(t *testing.T, reader *GeoLookup) { @@ -47,14 +50,11 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) { expected := expected t.Run(ip, func(t *testing.T) { country, err := reader.LookupCountry(net.ParseIP(ip)) - if err != nil { - t.Errorf("Could not lookup %s: %s", ip, err) + if !assert.NoError(t, err, "Could not lookup %s", ip) { return } - if country != expected { - t.Errorf("Expected %s for %s, got %s", expected, ip, country) - } + assert.Equal(t, expected, country, "Unexpected country for %s", ip) }) } } @@ -79,36 +79,28 @@ func GetGeoIpUrlForTest(t *testing.T) string { func TestGeoLookup(t *testing.T) { CatchLogForTest(t) + require := require.New(t) reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t)) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer reader.Close() - if err := reader.Update(); err != nil { - t.Fatal(err) - } + require.NoError(reader.Update()) testGeoLookupReader(t, reader) } func TestGeoLookupCaching(t *testing.T) { CatchLogForTest(t) + require := require.New(t) reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t)) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer reader.Close() - if err := reader.Update(); err != nil { - t.Fatal(err) - } + require.NoError(reader.Update()) // Updating the second time will most likely return a "304 Not Modified". // Make sure this doesn't trigger an error. - if err := reader.Update(); err != nil { - t.Fatal(err) - } + require.NoError(reader.Update()) } func TestGeoLookupContinent(t *testing.T) { @@ -125,13 +117,11 @@ func TestGeoLookupContinent(t *testing.T) { expected := expected t.Run(country, func(t *testing.T) { continents := LookupContinents(country) - if len(continents) != len(expected) { - t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected) + if !assert.Equal(t, len(expected), len(continents), "Continents didn't match for %s: got %s, expected %s", country, continents, expected) { return } for idx, c := range expected { - if continents[idx] != c { - t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected) + if !assert.Equal(t, c, continents[idx], "Continents didn't match for %s: got %s, expected %s", country, continents, expected) { break } } @@ -142,36 +132,29 @@ func TestGeoLookupContinent(t *testing.T) { func TestGeoLookupCloseEmpty(t *testing.T) { CatchLogForTest(t) reader, err := NewGeoLookupFromUrl("ignore-url") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) reader.Close() } func TestGeoLookupFromFile(t *testing.T) { CatchLogForTest(t) + require := require.New(t) geoIpUrl := GetGeoIpUrlForTest(t) resp, err := http.Get(geoIpUrl) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer resp.Body.Close() body := resp.Body url := geoIpUrl if strings.HasSuffix(geoIpUrl, ".gz") { body, err = gzip.NewReader(body) - if err != nil { - t.Fatal(err) - } + require.NoError(err) url = strings.TrimSuffix(url, ".gz") } tmpfile, err := os.CreateTemp("", "geoipdb") - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { os.Remove(tmpfile.Name()) }) @@ -183,9 +166,8 @@ func TestGeoLookupFromFile(t *testing.T) { header, err := tarfile.Next() if err == io.EOF { break - } else if err != nil { - t.Fatal(err) } + require.NoError(err) if !strings.HasSuffix(header.Name, ".mmdb") { continue @@ -193,33 +175,25 @@ func TestGeoLookupFromFile(t *testing.T) { if _, err := io.Copy(tmpfile, tarfile); err != nil { tmpfile.Close() - t.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - t.Fatal(err) + require.NoError(err) } + require.NoError(tmpfile.Close()) foundDatabase = true break } } else { if _, err := io.Copy(tmpfile, body); err != nil { tmpfile.Close() - t.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - t.Fatal(err) + require.NoError(err) } + require.NoError(tmpfile.Close()) foundDatabase = true } - if !foundDatabase { - t.Fatalf("Did not find GeoIP database in download from %s", geoIpUrl) - } + require.True(foundDatabase, "Did not find GeoIP database in download from %s", geoIpUrl) reader, err := NewGeoLookupFromFile(tmpfile.Name()) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer reader.Close() testGeoLookupReader(t, reader) @@ -228,9 +202,7 @@ func TestGeoLookupFromFile(t *testing.T) { func TestIsValidContinent(t *testing.T) { for country, continents := range ContinentMap { for _, continent := range continents { - if !IsValidContinent(continent) { - t.Errorf("Continent %s of country %s is not valid", continent, country) - } + assert.True(t, IsValidContinent(continent), "Continent %s of country %s is not valid", continent, country) } } } diff --git a/grpc_client_test.go b/grpc_client_test.go index 1714b1bb..3606aec2 100644 --- a/grpc_client_test.go +++ b/grpc_client_test.go @@ -33,6 +33,8 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/server/v3/embed" ) @@ -52,9 +54,7 @@ func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} { func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, *DnsMonitor) { dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually client, err := NewGrpcClients(config, etcdClient, dnsMonitor) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { client.Close() }) @@ -78,13 +78,9 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients config.AddOption("grpc", "targetprefix", "/grpctargets") etcdClient, err := NewEtcdClient(config, "") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { - if err := etcdClient.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, etcdClient.Close()) }) return NewGrpcClientsForTestWithConfig(t, config, etcdClient) @@ -107,7 +103,7 @@ func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) { case <-ch: return case <-ctx.Done(): - t.Error("timeout waiting for event") + assert.Fail(t, "timeout waiting for event") } } @@ -125,19 +121,17 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) { client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := client.WaitForInitialized(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, client.WaitForInitialized(ctx)) - if clients := client.GetClients(); len(clients) != 2 { - t.Errorf("Expected two clients, got %+v", clients) - } + clients := client.GetClients() + assert.Len(t, clients, 2, "Expected two clients, got %+v", clients) }) } func Test_GrpcClients_EtcdUpdate(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) etcd := NewEtcdForTest(t) client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) ch := client.getWakeupChannelForTesting() @@ -145,55 +139,45 @@ func Test_GrpcClients_EtcdUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if clients := client.GetClients(); len(clients) != 0 { - t.Errorf("Expected no clients, got %+v", clients) - } + assert.Empty(client.GetClients()) drainWakeupChannel(ch) _, addr1 := NewGrpcServerForTest(t) SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != addr1 { - t.Errorf("Expected target %s, got %s", addr1, clients[0].Target()) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(addr1, clients[0].Target()) } drainWakeupChannel(ch) _, addr2 := NewGrpcServerForTest(t) SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 2 { - t.Errorf("Expected two clients, got %+v", clients) - } else if clients[0].Target() != addr1 { - t.Errorf("Expected target %s, got %s", addr1, clients[0].Target()) - } else if clients[1].Target() != addr2 { - t.Errorf("Expected target %s, got %s", addr2, clients[1].Target()) + if clients := client.GetClients(); assert.Len(clients, 2) { + assert.Equal(addr1, clients[0].Target()) + assert.Equal(addr2, clients[1].Target()) } drainWakeupChannel(ch) DeleteEtcdValue(etcd, "/grpctargets/one") waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != addr2 { - t.Errorf("Expected target %s, got %s", addr2, clients[0].Target()) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(addr2, clients[0].Target()) } drainWakeupChannel(ch) _, addr3 := NewGrpcServerForTest(t) SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr3+"\"}")) waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != addr3 { - t.Errorf("Expected target %s, got %s", addr3, clients[0].Target()) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(addr3, clients[0].Target()) } } func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) etcd := NewEtcdForTest(t) client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) ch := client.getWakeupChannelForTesting() @@ -201,18 +185,14 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if clients := client.GetClients(); len(clients) != 0 { - t.Errorf("Expected no clients, got %+v", clients) - } + assert.Empty(client.GetClients()) drainWakeupChannel(ch) _, addr1 := NewGrpcServerForTest(t) SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != addr1 { - t.Errorf("Expected target %s, got %s", addr1, clients[0].Target()) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(addr1, clients[0].Target()) } drainWakeupChannel(ch) @@ -221,25 +201,22 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) waitForEvent(ctx, t, ch) client.selfCheckWaitGroup.Wait() - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != addr1 { - t.Errorf("Expected target %s, got %s", addr1, clients[0].Target()) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(addr1, clients[0].Target()) } drainWakeupChannel(ch) DeleteEtcdValue(etcd, "/grpctargets/two") waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != addr1 { - t.Errorf("Expected target %s, got %s", addr1, clients[0].Target()) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(addr1, clients[0].Target()) } } func Test_GrpcClients_DnsDiscovery(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { + assert := assert.New(t) lookup := newMockDnsLookupForTest(t) target := "testgrpc:12345" ip1 := net.ParseIP("192.168.0.1") @@ -254,12 +231,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { defer cancel() dnsMonitor.checkHostnames() - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != targetWithIp1 { - t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) - } else if !clients[0].ip.Equal(ip1) { - t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(targetWithIp1, clients[0].Target()) + assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip) } lookup.Set("testgrpc", []net.IP{ip1, ip2}) @@ -267,16 +241,11 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { dnsMonitor.checkHostnames() waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 2 { - t.Errorf("Expected two client, got %+v", clients) - } else if clients[0].Target() != targetWithIp1 { - t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) - } else if !clients[0].ip.Equal(ip1) { - t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) - } else if clients[1].Target() != targetWithIp2 { - t.Errorf("Expected target %s, got %s", targetWithIp2, clients[1].Target()) - } else if !clients[1].ip.Equal(ip2) { - t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip) + if clients := client.GetClients(); assert.Len(clients, 2) { + assert.Equal(targetWithIp1, clients[0].Target()) + assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip) + assert.Equal(targetWithIp2, clients[1].Target()) + assert.True(clients[1].ip.Equal(ip2), "Expected IP %s, got %s", ip2, clients[1].ip) } lookup.Set("testgrpc", []net.IP{ip2}) @@ -284,12 +253,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { dnsMonitor.checkHostnames() waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != targetWithIp2 { - t.Errorf("Expected target %s, got %s", targetWithIp2, clients[0].Target()) - } else if !clients[0].ip.Equal(ip2) { - t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(targetWithIp2, clients[0].Target()) + assert.True(clients[0].ip.Equal(ip2), "Expected IP %s, got %s", ip2, clients[0].ip) } }) } @@ -297,6 +263,7 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) lookup := newMockDnsLookupForTest(t) target := "testgrpc:12345" ip1 := net.ParseIP("192.168.0.1") @@ -309,39 +276,29 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := client.WaitForInitialized(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, client.WaitForInitialized(ctx)) - if clients := client.GetClients(); len(clients) != 0 { - t.Errorf("Expected no client, got %+v", clients) - } + assert.Empty(client.GetClients()) lookup.Set("testgrpc", []net.IP{ip1}) drainWakeupChannel(ch) dnsMonitor.checkHostnames() waitForEvent(testCtx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != targetWithIp1 { - t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) - } else if !clients[0].ip.Equal(ip1) { - t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) + if clients := client.GetClients(); assert.Len(clients, 1) { + assert.Equal(targetWithIp1, clients[0].Target()) + assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip) } } func Test_GrpcClients_Encryption(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { + require := require.New(t) serverKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) clientKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey) clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey) @@ -376,14 +333,11 @@ func Test_GrpcClients_Encryption(t *testing.T) { ctx, cancel1 := context.WithTimeout(context.Background(), time.Second) defer cancel1() - if err := clients.WaitForInitialized(ctx); err != nil { - t.Fatal(err) - } + require.NoError(clients.WaitForInitialized(ctx)) for _, client := range clients.GetClients() { - if _, err := client.GetServerId(ctx); err != nil { - t.Fatal(err) - } + _, err := client.GetServerId(ctx) + require.NoError(err) } }) } diff --git a/grpc_common_test.go b/grpc_common_test.go index f9bed4a0..878efd0b 100644 --- a/grpc_common_test.go +++ b/grpc_common_test.go @@ -35,6 +35,8 @@ import ( "os" "testing" "time" + + "github.com/stretchr/testify/require" ) func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error { @@ -72,9 +74,7 @@ func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organizatio } data, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) data = pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", @@ -108,23 +108,15 @@ func WritePublicKey(key *rsa.PublicKey, filename string) error { func replaceFile(t *testing.T, filename string, data []byte, perm fs.FileMode) { t.Helper() + require := require.New(t) oldStat, err := os.Stat(filename) - if err != nil { - t.Fatalf("can't stat old file %s: %s", filename, err) - return - } + require.NoError(err, "can't stat old file %s", filename) for { - if err := os.WriteFile(filename, data, perm); err != nil { - t.Fatalf("can't write file %s: %s", filename, err) - return - } + require.NoError(os.WriteFile(filename, data, perm), "can't write file %s", filename) newStat, err := os.Stat(filename) - if err != nil { - t.Fatalf("can't stat new file %s: %s", filename, err) - return - } + require.NoError(err, "can't stat new file %s", filename) // We need different modification times. if !newStat.ModTime().Equal(oldStat.ModTime()) { diff --git a/grpc_server_test.go b/grpc_server_test.go index a83f5279..ffa6ceb9 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -37,6 +37,8 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -67,23 +69,19 @@ func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (se server, err = NewGrpcServer(config) if isErrorAddressAlreadyInUse(err) { continue - } else if err != nil { - t.Fatal(err) } + + require.NoError(t, err) break } - if server == nil { - t.Fatal("could not find free port") - } + require.NotNil(t, server, "could not find free port") // Don't match with own server id by default. server.serverId = "dont-match" go func() { - if err := server.Run(); err != nil { - t.Errorf("could not start GRPC server: %s", err) - } + assert.NoError(t, server.Run(), "could not start GRPC server") }() t.Cleanup(func() { @@ -99,10 +97,10 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) { func Test_GrpcServer_ReloadCerts(t *testing.T) { CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) org1 := "Testing certificate" cert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, key) @@ -124,24 +122,20 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { cp1 := x509.NewCertPool() if !cp1.AppendCertsFromPEM(cert1) { - t.Fatalf("could not add certificate") + require.Fail("could not add certificate") } cfg1 := &tls.Config{ RootCAs: cp1, } conn1, err := tls.Dial("tcp", addr, cfg1) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer conn1.Close() // nolint state1 := conn1.ConnectionState() - if certs := state1.PeerCertificates; len(certs) == 0 { - t.Errorf("expected certificates, got %+v", state1) - } else if len(certs[0].Subject.Organization) == 0 { - t.Errorf("expected organization, got %s", certs[0].Subject) - } else if certs[0].Subject.Organization[0] != org1 { - t.Errorf("expected organization %s, got %s", org1, certs[0].Subject) + if certs := state1.PeerCertificates; assert.NotEmpty(certs) { + if assert.NotEmpty(certs[0].Subject.Organization) { + assert.Equal(org1, certs[0].Subject.Organization[0]) + } } org2 := "Updated certificate" @@ -151,43 +145,34 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := server.WaitForCertificateReload(ctx); err != nil { - t.Fatal(err) - } + require.NoError(server.WaitForCertificateReload(ctx)) cp2 := x509.NewCertPool() if !cp2.AppendCertsFromPEM(cert2) { - t.Fatalf("could not add certificate") + require.Fail("could not add certificate") } cfg2 := &tls.Config{ RootCAs: cp2, } conn2, err := tls.Dial("tcp", addr, cfg2) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer conn2.Close() // nolint state2 := conn2.ConnectionState() - if certs := state2.PeerCertificates; len(certs) == 0 { - t.Errorf("expected certificates, got %+v", state2) - } else if len(certs[0].Subject.Organization) == 0 { - t.Errorf("expected organization, got %s", certs[0].Subject) - } else if certs[0].Subject.Organization[0] != org2 { - t.Errorf("expected organization %s, got %s", org2, certs[0].Subject) + if certs := state2.PeerCertificates; assert.NotEmpty(certs) { + if assert.NotEmpty(certs[0].Subject.Organization) { + assert.Equal(org2, certs[0].Subject.Organization[0]) + } } } func Test_GrpcServer_ReloadCA(t *testing.T) { CatchLogForTest(t) + require := require.New(t) serverKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) clientKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey) org1 := "Testing client" @@ -213,65 +198,53 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(serverCert) { - t.Fatalf("could not add certificate") + require.Fail("could not add certificate") } pair1, err := tls.X509KeyPair(clientCert1, pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey), })) - if err != nil { - t.Fatal(err) - } + require.NoError(err) cfg1 := &tls.Config{ RootCAs: pool, Certificates: []tls.Certificate{pair1}, } client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1))) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer client1.Close() // nolint ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) defer cancel1() - if _, err := client1.GetServerId(ctx1); err != nil { - t.Fatal(err) - } + _, err = client1.GetServerId(ctx1) + require.NoError(err) org2 := "Updated client" clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) replaceFile(t, caFile, clientCert2, 0755) - if err := server.WaitForCertPoolReload(ctx1); err != nil { - t.Fatal(err) - } + require.NoError(server.WaitForCertPoolReload(ctx1)) pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey), })) - if err != nil { - t.Fatal(err) - } + require.NoError(err) cfg2 := &tls.Config{ RootCAs: pool, Certificates: []tls.Certificate{pair2}, } client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2))) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer client2.Close() // nolint ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) defer cancel2() // This will fail if the CA certificate has not been reloaded by the server. - if _, err := client2.GetServerId(ctx2); err != nil { - t.Fatal(err) - } + _, err = client2.GetServerId(ctx2) + require.NoError(err) } diff --git a/http_client_pool_test.go b/http_client_pool_test.go index b4347896..dff3866f 100644 --- a/http_client_pool_test.go +++ b/http_client_pool_test.go @@ -26,52 +26,42 @@ import ( "net/url" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHttpClientPool(t *testing.T) { t.Parallel() - if _, err := NewHttpClientPool(0, false); err == nil { - t.Error("should not be possible to create empty pool") - } + require := require.New(t) + assert := assert.New(t) + _, err := NewHttpClientPool(0, false) + assert.Error(err) pool, err := NewHttpClientPool(1, false) - if err != nil { - t.Fatal(err) - } + require.NoError(err) u, err := url.Parse("http://localhost/foo/bar") - if err != nil { - t.Fatal(err) - } + require.NoError(err) ctx := context.Background() - if _, _, err := pool.Get(ctx, u); err != nil { - t.Fatal(err) - } + _, _, err = pool.Get(ctx, u) + require.NoError(err) ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond) defer cancel() - if _, _, err := pool.Get(ctx2, u); err == nil { - t.Error("fetching from empty pool should have timed out") - } else if err != context.DeadlineExceeded { - t.Errorf("fetching from empty pool should have timed out, got %s", err) - } + _, _, err = pool.Get(ctx2, u) + assert.ErrorIs(err, context.DeadlineExceeded) // Pools are separated by hostname, so can get client for different host. u2, err := url.Parse("http://local.host/foo/bar") - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if _, _, err := pool.Get(ctx, u2); err != nil { - t.Fatal(err) - } + _, _, err = pool.Get(ctx, u2) + require.NoError(err) ctx3, cancel2 := context.WithTimeout(ctx, 10*time.Millisecond) defer cancel2() - if _, _, err := pool.Get(ctx3, u2); err == nil { - t.Error("fetching from empty pool should have timed out") - } else if err != context.DeadlineExceeded { - t.Errorf("fetching from empty pool should have timed out, got %s", err) - } + _, _, err = pool.Get(ctx3, u2) + assert.ErrorIs(err, context.DeadlineExceeded) } diff --git a/hub_test.go b/hub_test.go index ee68e973..c0dafe3b 100644 --- a/hub_test.go +++ b/hub_test.go @@ -22,7 +22,6 @@ package signaling import ( - "bytes" "context" "crypto/ecdsa" "crypto/ed25519" @@ -39,7 +38,6 @@ import ( "net/http/httptest" "net/url" "os" - "reflect" "strings" "sync" "testing" @@ -49,6 +47,8 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -131,6 +131,7 @@ func getTestConfigWithMultipleBackends(server *httptest.Server) (*goconf.ConfigF } func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, AsyncEvents, *mux.Router, *httptest.Server) { + require := require.New(t) r := mux.NewRouter() registerBackendHandler(t, r) @@ -141,20 +142,12 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve events := getAsyncEventsForTest(t) config, err := getConfigFunc(server) - if err != nil { - t.Fatal(err) - } + require.NoError(err) h, err := NewHub(config, events, nil, nil, nil, r, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) b, err := NewBackendServer(config, h, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b.Start(r); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b.Start(r)) go h.Run() @@ -180,6 +173,7 @@ func CreateHubWithMultipleBackendsForTest(t *testing.T) (*Hub, AsyncEvents, *mux } func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *mux.Router, *mux.Router, *httptest.Server, *httptest.Server) { + require := require.New(t) r1 := mux.NewRouter() registerBackendHandler(t, r1) @@ -212,51 +206,31 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http } events1, err := NewAsyncEvents(nats1) - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { events1.Close() }) config1, err := getConfigFunc(server1) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client1, _ := NewGrpcClientsForTest(t, addr2) h1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) b1, err := NewBackendServer(config1, h1, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) events2, err := NewAsyncEvents(nats2) - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { events2.Close() }) config2, err := getConfigFunc(server2) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2, _ := NewGrpcClientsForTest(t, addr1) h2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version") - if err != nil { - t.Fatal(err) - } + require.NoError(err) b2, err := NewBackendServer(config2, h2, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b1.Start(r1); err != nil { - t.Fatal(err) - } - if err := b2.Start(r2); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b1.Start(r1)) + require.NoError(b2.Start(r2)) go h1.Run() go h2.Run() @@ -306,7 +280,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { h.mu.Lock() h.ru.Lock() dumpGoroutines("", os.Stderr) - t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / remoteSessions %v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, h.remoteSessions, readActive, writeActive, ctx.Err()) + assert.Fail(t, "Error waiting for clients %+v / rooms %+v / sessions %+v / remoteSessions %v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, h.remoteSessions, readActive, writeActive, ctx.Err()) h.ru.Unlock() h.mu.Unlock() return @@ -318,25 +292,22 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { func validateBackendChecksum(t *testing.T, f func(http.ResponseWriter, *http.Request, *BackendClientRequest) *BackendClientResponse) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + require := require.New(t) body, err := io.ReadAll(r.Body) - if err != nil { - t.Fatal("Error reading body: ", err) - } + require.NoError(err) rnd := r.Header.Get(HeaderBackendSignalingRandom) checksum := r.Header.Get(HeaderBackendSignalingChecksum) if rnd == "" || checksum == "" { - t.Fatalf("No checksum headers found in request to %s", r.URL) + require.Fail("No checksum headers found in request to %s", r.URL) } if verify := CalculateBackendChecksum(rnd, body, testBackendSecret); verify != checksum { - t.Fatalf("Backend checksum verification failed for request to %s", r.URL) + require.Fail("Backend checksum verification failed for request to %s", r.URL) } var request BackendClientRequest - if err := json.Unmarshal(body, &request); err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal(body, &request)) response := f(w, r, &request) if response == nil { @@ -345,9 +316,7 @@ func validateBackendChecksum(t *testing.T, f func(http.ResponseWriter, *http.Req } data, err := json.Marshal(response) - if err != nil { - t.Fatal(err) - } + require.NoError(err) if r.Header.Get("OCS-APIRequest") != "" { var ocs OcsResponse @@ -359,9 +328,8 @@ func validateBackendChecksum(t *testing.T, f func(http.ResponseWriter, *http.Req }, Data: data, } - if data, err = json.Marshal(ocs); err != nil { - t.Fatal(err) - } + data, err = json.Marshal(ocs) + require.NoError(err) } w.Header().Set("Content-Type", "application/json") @@ -371,15 +339,14 @@ func validateBackendChecksum(t *testing.T, f func(http.ResponseWriter, *http.Req } func processAuthRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { + require := require.New(t) if request.Type != "auth" || request.Auth == nil { - t.Fatalf("Expected an auth backend request, got %+v", request) + require.Fail("Expected an auth backend request, got %+v", request) } var params TestBackendClientAuthParams if len(request.Auth.Params) > 0 { - if err := json.Unmarshal(request.Auth.Params, ¶ms); err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal(request.Auth.Params, ¶ms)) } if params.UserId == "" { params.UserId = testDefaultUserId @@ -397,17 +364,17 @@ func processAuthRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re userdata := map[string]string{ "displayname": "Displayname " + params.UserId, } - if data, err := json.Marshal(userdata); err != nil { - t.Fatal(err) - } else { - response.Auth.User = data - } + data, err := json.Marshal(userdata) + require.NoError(err) + response.Auth.User = data return response } func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { + require := require.New(t) + assert := assert.New(t) if request.Type != "room" || request.Room == nil { - t.Fatalf("Expected an room backend request, got %+v", request) + require.Fail("Expected an room backend request, got %+v", request) } switch request.Room.RoomId { @@ -416,26 +383,18 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re case "test-room-takeover-room-session": // Additional checks for testcase "TestClientTakeoverRoomSession" if request.Room.Action == "leave" && request.Room.UserId == "test-userid1" { - t.Errorf("Should not receive \"leave\" event for first user, received %+v", request.Room) + assert.Fail("Should not receive \"leave\" event for first user, received %+v", request.Room) } } if strings.Contains(t.Name(), "Federation") { // Check additional fields present for federated sessions. if strings.Contains(request.Room.SessionId, "@federated") { - if actorType := request.Room.ActorType; actorType != ActorTypeFederatedUsers { - t.Errorf("expected actorType %s, received %+v", ActorTypeFederatedUsers, request.Room) - } - if actorId := request.Room.ActorId; actorId == "" { - t.Errorf("expected actorId, received %+v", request.Room) - } + assert.Equal(ActorTypeFederatedUsers, request.Room.ActorType) + assert.NotEmpty(request.Room.ActorId) } else { - if actorType := request.Room.ActorType; actorType != "" { - t.Errorf("expected empty actorType, received %+v", request.Room) - } - if actorId := request.Room.ActorId; actorId != "" { - t.Errorf("expected empty actorId, received %+v", request.Room) - } + assert.Empty(request.Room.ActorType) + assert.Empty(request.Room.ActorId) } } @@ -454,9 +413,7 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re "userid": "userid-from-sessiondata", } tmp, err := json.Marshal(data) - if err != nil { - t.Fatalf("Could not marshal %+v: %s", data, err) - } + require.NoError(err) response.Room.Session = tmp case "test-room-initial-permissions": permissions := []Permission{PERMISSION_MAY_PUBLISH_AUDIO} @@ -498,7 +455,7 @@ func clearSessionRequestHandler(t *testing.T) { // nolint func processSessionRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { if request.Type != "session" || request.Session == nil { - t.Fatalf("Expected an session backend request, got %+v", request) + require.Fail(t, "Expected an session backend request, got %+v", request) } sessionRequestHander.Lock() @@ -545,16 +502,12 @@ func storePingRequest(t *testing.T, request *BackendClientRequest) { func processPingRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { if request.Type != "ping" || request.Ping == nil { - t.Fatalf("Expected an ping backend request, got %+v", request) + require.Fail(t, "Expected an ping backend request, got %+v", request) } if request.Ping.RoomId == "test-room-with-sessiondata" { - if entries := request.Ping.Entries; len(entries) != 1 { - t.Errorf("Expected one entry, got %+v", entries) - } else { - if entries[0].UserId != "" { - t.Errorf("Expected empty userid, got %+v", entries[0]) - } + if entries := request.Ping.Entries; assert.Len(t, entries, 1) { + assert.Empty(t, entries[0].UserId) } } @@ -571,12 +524,11 @@ func processPingRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re } func ensureAuthTokens(t *testing.T) (string, string) { + require := require.New(t) if privateKey := os.Getenv("PRIVATE_AUTH_TOKEN_" + t.Name()); privateKey != "" { publicKey := os.Getenv("PUBLIC_AUTH_TOKEN_" + t.Name()) - if publicKey == "" { - // should not happen, always both keys are created - t.Fatal("public key is empty") - } + // should not happen, always both keys are created + require.NotEmpty(publicKey, "public key is empty") return privateKey, publicKey } @@ -585,55 +537,41 @@ func ensureAuthTokens(t *testing.T) (string, string) { if strings.Contains(t.Name(), "ECDSA") { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } + require.NoError(err) private, err = x509.MarshalECPrivateKey(key) - if err != nil { - t.Fatal(err) - } + require.NoError(err) private = pem.EncodeToMemory(&pem.Block{ Type: "ECDSA PRIVATE KEY", Bytes: private, }) public, err = x509.MarshalPKIXPublicKey(&key.PublicKey) - if err != nil { - t.Fatal(err) - } + require.NoError(err) public = pem.EncodeToMemory(&pem.Block{ Type: "ECDSA PUBLIC KEY", Bytes: public, }) } else if strings.Contains(t.Name(), "Ed25519") { publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + require.NoError(err) private, err = x509.MarshalPKCS8PrivateKey(privateKey) - if err != nil { - t.Fatal(err) - } + require.NoError(err) private = pem.EncodeToMemory(&pem.Block{ Type: "Ed25519 PRIVATE KEY", Bytes: private, }) public, err = x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - t.Fatal(err) - } + require.NoError(err) public = pem.EncodeToMemory(&pem.Block{ Type: "Ed25519 PUBLIC KEY", Bytes: public, }) } else { key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) private = pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", @@ -641,9 +579,7 @@ func ensureAuthTokens(t *testing.T) (string, string) { }) public, err = x509.MarshalPKIXPublicKey(&key.PublicKey) - if err != nil { - t.Fatal(err) - } + require.NoError(err) public = pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: public, @@ -660,9 +596,7 @@ func ensureAuthTokens(t *testing.T) (string, string) { func getPrivateAuthToken(t *testing.T) (key interface{}) { private, _ := ensureAuthTokens(t) data, err := base64.StdEncoding.DecodeString(private) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if strings.Contains(t.Name(), "ECDSA") { key, err = jwt.ParseECPrivateKeyFromPEM(data) } else if strings.Contains(t.Name(), "Ed25519") { @@ -670,18 +604,14 @@ func getPrivateAuthToken(t *testing.T) (key interface{}) { } else { key, err = jwt.ParseRSAPrivateKeyFromPEM(data) } - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return key } func getPublicAuthToken(t *testing.T) (key interface{}) { _, public := ensureAuthTokens(t) data, err := base64.StdEncoding.DecodeString(public) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if strings.Contains(t.Name(), "ECDSA") { key, err = jwt.ParseECPublicKeyFromPEM(data) } else if strings.Contains(t.Name(), "Ed25519") { @@ -689,9 +619,7 @@ func getPublicAuthToken(t *testing.T) (key interface{}) { } else { key, err = jwt.ParseRSAPublicKeyFromPEM(data) } - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return key } @@ -711,7 +639,7 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { case "ping": return processPingRequest(t, w, r, request) default: - t.Fatalf("Unsupported request received: %+v", request) + require.Fail(t, "Unsupported request received: %+v", request) return nil } }) @@ -749,9 +677,7 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { if (strings.Contains(t.Name(), "V2") && useV2) || strings.Contains(t.Name(), "Federation") { key := getPublicAuthToken(t) public, err := x509.MarshalPKIXPublicKey(key) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) var pemType string if strings.Contains(t.Name(), "ECDSA") { pemType = "ECDSA PUBLIC KEY" @@ -787,9 +713,7 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { } data, err := json.Marshal(response) - if err != nil { - t.Errorf("Could not marshal %+v: %s", response, err) - } + assert.NoError(t, err, "Could not marshal %+v", response) var ocs OcsResponse ocs.Ocs = &OcsBody{ @@ -800,9 +724,8 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { }, Data: data, } - if data, err = json.Marshal(ocs); err != nil { - t.Fatal(err) - } + data, err = json.Marshal(ocs) + require.NoError(t, err) w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(data) // nolint @@ -829,46 +752,43 @@ func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup { func TestWebsocketFeatures(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() conn, response, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer conn.Close() // nolint - if server := response.Header.Get("Server"); !strings.HasPrefix(server, "nextcloud-spreed-signaling/") { - t.Errorf("expected valid server header, got \"%s\"", server) - } + serverHeader := response.Header.Get("Server") + assert.True(strings.HasPrefix(serverHeader, "nextcloud-spreed-signaling/"), "expected valid server header, got \"%s\"", serverHeader) features := response.Header.Get("X-Spreed-Signaling-Features") featuresList := make(map[string]bool) for _, f := range strings.Split(features, ",") { f = strings.TrimSpace(f) if f != "" { - if _, found := featuresList[f]; found { - t.Errorf("duplicate feature id \"%s\" in \"%s\"", f, features) - } + _, found := featuresList[f] + assert.False(found, "duplicate feature id \"%s\" in \"%s\"", f, features) featuresList[f] = true } } if len(featuresList) <= 1 { - t.Errorf("expected valid features header, got \"%s\"", features) - } - if _, found := featuresList["hello-v2"]; !found { - t.Errorf("expected feature \"hello-v2\", got \"%s\"", features) + assert.Fail("expected valid features header, got \"%s\"", features) } + _, found := featuresList["hello-v2"] + assert.True(found, "expected feature \"hello-v2\", got \"%s\"", features) - if err := conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}); err != nil { - t.Errorf("could not write close message: %s", err) - } + assert.NoError(conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{})) } func TestInitialWelcome(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -878,22 +798,20 @@ func TestInitialWelcome(t *testing.T) { defer client.CloseWithBye() msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if msg.Type != "welcome" { - t.Errorf("Expected \"welcome\" message, got %+v", msg) - } else if msg.Welcome.Version == "" { - t.Errorf("Expected welcome version, got %+v", msg) - } else if len(msg.Welcome.Features) == 0 { - t.Errorf("Expected welcome features, got %+v", msg) + assert.Equal("welcome", msg.Type, "%+v", msg) + if assert.NotNil(msg.Welcome, "%+v", msg) { + assert.NotEmpty(msg.Welcome.Version, "%+v", msg) + assert.NotEmpty(msg.Welcome.Features, "%+v", msg) } } func TestExpectClientHello(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) // The server will send an error and close the connection if no "Hello" @@ -908,28 +826,24 @@ func TestExpectClientHello(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() message, err := client.RunUntilMessage(ctx) - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } + require.NoError(checkUnexpectedClose(err)) message2, err := client.RunUntilMessage(ctx) if message2 != nil { - t.Fatalf("Received multiple messages, already have %+v, also got %+v", message, message2) - } - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) + require.Fail("Received multiple messages, already have %+v, also got %+v", message, message2) } + require.NoError(checkUnexpectedClose(err)) - if err := checkMessageType(message, "bye"); err != nil { - t.Error(err) - } else if message.Bye.Reason != "hello_timeout" { - t.Errorf("Expected \"hello_timeout\" reason, got %+v", message.Bye) + if err := checkMessageType(message, "bye"); assert.NoError(err) { + assert.Equal("hello_timeout", message.Bye.Reason, "%+v", message.Bye) } } func TestExpectClientHelloUnsupportedVersion(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -938,49 +852,37 @@ func TestExpectClientHelloUnsupportedVersion(t *testing.T) { params := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server.URL, "0.0", "", nil, params); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloParams(server.URL, "0.0", "", nil, params)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() message, err := client.RunUntilMessage(ctx) - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } + require.NoError(checkUnexpectedClose(err)) - if err := checkMessageType(message, "error"); err != nil { - t.Error(err) - } else if message.Error.Code != "invalid_hello_version" { - t.Errorf("Expected \"invalid_hello_version\" reason, got %+v", message.Error) + if err := checkMessageType(message, "error"); assert.NoError(err) { + assert.Equal("invalid_hello_version", message.Error.Code) } } func TestClientHelloV1(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) } } @@ -988,49 +890,35 @@ func TestClientHelloV2(t *testing.T) { CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloV2(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloV2(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName) - if data == nil { - t.Fatalf("Could not decode session id: %s", hello.Hello.SessionId) - } + require.NotNil(data, "Could not decode session id: %s", hello.Hello.SessionId) hub.mu.RLock() session := hub.sessions[data.Sid] hub.mu.RUnlock() - if session == nil { - t.Fatalf("Could not get session for id %+v", data) - } + require.NotNil(session, "Could not get session for id %+v", data) var userdata map[string]string - if err := json.Unmarshal(session.UserData(), &userdata); err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal(session.UserData(), &userdata)) - if expected := "Displayname " + testDefaultUserId; userdata["displayname"] != expected { - t.Errorf("Expected displayname %s, got %s", expected, userdata["displayname"]) - } + assert.Equal("Displayname "+testDefaultUserId, userdata["displayname"]) }) } } @@ -1039,6 +927,8 @@ func TestClientHelloV2_IssuedInFuture(t *testing.T) { CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -1046,22 +936,16 @@ func TestClientHelloV2_IssuedInFuture(t *testing.T) { issuedAt := time.Now().Add(time.Minute) expiresAt := issuedAt.Add(time.Second) - if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() message, err := client.RunUntilMessage(ctx) - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } + require.NoError(checkUnexpectedClose(err)) - if err := checkMessageType(message, "error"); err != nil { - t.Error(err) - } else if message.Error.Code != "token_not_valid_yet" { - t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error) + if err := checkMessageType(message, "error"); assert.NoError(err) { + assert.Equal("token_not_valid_yet", message.Error.Code, "%+v", message) } }) } @@ -1071,28 +955,24 @@ func TestClientHelloV2_Expired(t *testing.T) { CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() issuedAt := time.Now().Add(-time.Minute) - if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, issuedAt.Add(time.Second)); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, issuedAt.Add(time.Second))) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() message, err := client.RunUntilMessage(ctx) - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } + require.NoError(checkUnexpectedClose(err)) - if err := checkMessageType(message, "error"); err != nil { - t.Error(err) - } else if message.Error.Code != "token_expired" { - t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error) + if err := checkMessageType(message, "error"); assert.NoError(err) { + assert.Equal("token_expired", message.Error.Code, "%+v", message) } }) } @@ -1102,6 +982,8 @@ func TestClientHelloV2_IssuedAtMissing(t *testing.T) { CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -1109,22 +991,16 @@ func TestClientHelloV2_IssuedAtMissing(t *testing.T) { var issuedAt time.Time expiresAt := time.Now().Add(time.Minute) - if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() message, err := client.RunUntilMessage(ctx) - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } + require.NoError(checkUnexpectedClose(err)) - if err := checkMessageType(message, "error"); err != nil { - t.Error(err) - } else if message.Error.Code != "token_not_valid_yet" { - t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error) + if err := checkMessageType(message, "error"); assert.NoError(err) { + assert.Equal("token_not_valid_yet", message.Error.Code, "%+v", message) } }) } @@ -1134,6 +1010,8 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -1141,22 +1019,16 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { issuedAt := time.Now().Add(-time.Minute) var expiresAt time.Time - if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() message, err := client.RunUntilMessage(ctx) - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } + require.NoError(checkUnexpectedClose(err)) - if err := checkMessageType(message, "error"); err != nil { - t.Error(err) - } else if message.Error.Code != "token_expired" { - t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error) + if err := checkMessageType(message, "error"); assert.NoError(err) { + assert.Equal("token_expired", message.Error.Code, "%+v", message) } }) } @@ -1166,6 +1038,8 @@ func TestClientHelloV2_CachedCapabilities(t *testing.T) { CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -1177,20 +1051,12 @@ func TestClientHelloV2_CachedCapabilities(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHelloV1(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } - if hello1.Hello.UserId != testDefaultUserId+"1" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello) - } - if hello1.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello1.Hello) - } + require.NoError(err) + assert.Equal(testDefaultUserId+"1", hello1.Hello.UserId, "%+v", hello1.Hello) + assert.NotEmpty(hello1.Hello.SessionId, "%+v", hello1.Hello) // Simulate updated Nextcloud with capabilities for Hello V2. t.Setenv("SKIP_V2_CAPABILITIES", "") @@ -1198,20 +1064,12 @@ func TestClientHelloV2_CachedCapabilities(t *testing.T) { client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloV2(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } - if hello2.Hello.UserId != testDefaultUserId+"2" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello) - } - if hello2.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello2.Hello) - } + require.NoError(err) + assert.Equal(testDefaultUserId+"2", hello2.Hello.UserId, "%+v", hello2.Hello) + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2.Hello) }) } } @@ -1219,34 +1077,30 @@ func TestClientHelloV2_CachedCapabilities(t *testing.T) { func TestClientHelloWithSpaces(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() userId := "test user with spaces" - if err := client.SendHello(userId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(userId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != userId { - t.Errorf("Expected \"%s\", got %+v", userId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(userId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) } } func TestClientHelloAllowAll(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { config, err := getTestConfig(server) if err != nil { @@ -1261,22 +1115,14 @@ func TestClientHelloAllowAll(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) } } @@ -1285,6 +1131,8 @@ func TestClientHelloSessionLimit(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -1358,22 +1206,14 @@ func TestClientHelloSessionLimit(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", nil, params1); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", nil, params1)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) } // The second client can't connect as it would exceed the session limit. @@ -1383,42 +1223,26 @@ func TestClientHelloSessionLimit(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: testDefaultUserId + "2", } - if err := client2.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", nil, params2); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", nil, params2)) - msg, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "session_limit_exceeded" { - t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("session_limit_exceeded", msg.Error.Code, "%+v", msg) } } // The client can connect to a different backend. - if err := client2.SendHelloParams(server1.URL+"/two", HelloVersionV1, "client", nil, params2); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloParams(server1.URL+"/two", HelloVersionV1, "client", nil, params2)) - if hello, err := client2.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId+"2" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + if hello, err := client2.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId+"2", hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) } // If the first client disconnects (and releases the session), a new one can connect. client.CloseWithBye() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) client3 := NewTestClient(t, server2, hub2) defer client3.CloseWithBye() @@ -1426,19 +1250,11 @@ func TestClientHelloSessionLimit(t *testing.T) { params3 := TestBackendClientAuthParams{ UserId: testDefaultUserId + "3", } - if err := client3.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", nil, params3); err != nil { - t.Fatal(err) - } + require.NoError(client3.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", nil, params3)) - if hello, err := client3.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId+"3" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + if hello, err := client3.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId+"3", hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) } }) } @@ -1447,6 +1263,8 @@ func TestClientHelloSessionLimit(t *testing.T) { func TestSessionIdsUnordered(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) publicSessionIds := make([]string, 0) @@ -1454,36 +1272,24 @@ func TestSessionIdsUnordered(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - break - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - break - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName) - if data == nil { - t.Errorf("Could not decode session id: %s", hello.Hello.SessionId) + if !assert.NotNil(data, "Could not decode session id: %s", hello.Hello.SessionId) { break } hub.mu.RLock() session := hub.sessions[data.Sid] hub.mu.RUnlock() - if session == nil { - t.Errorf("Could not get session for id %+v", data) + if !assert.NotNil(session, "Could not get session for id %+v", data) { break } @@ -1491,9 +1297,7 @@ func TestSessionIdsUnordered(t *testing.T) { } } - if len(publicSessionIds) == 0 { - t.Fatal("no session ids decoded") - } + require.NotEmpty(publicSessionIds, "no session ids decoded") larger := 0 smaller := 0 @@ -1505,80 +1309,57 @@ func TestSessionIdsUnordered(t *testing.T) { } else if sid < prevSid { smaller-- } else { - t.Error("should not have received the same session id twice") + assert.Fail("should not have received the same session id twice") } } prevSid = sid } // Public session ids should not be ordered. - if len(publicSessionIds) == larger { - t.Error("the session ids are all larger than the previous ones") - } else if len(publicSessionIds) == smaller { - t.Error("the session ids are all smaller than the previous ones") - } + assert.NotEqual(larger, len(publicSessionIds), "the session ids are all larger than the previous ones") + assert.NotEqual(smaller, len(publicSessionIds), "the session ids are all smaller than the previous ones") } func TestClientHelloResume(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + require.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client.Close() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) client = NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } - hello2, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } + require.NoError(client.SendHelloResume(hello.Hello.ResumeId)) + if hello2, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) } } func TestClientHelloResumeThrottle(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) timing := &throttlerTiming{ @@ -1597,46 +1378,28 @@ func TestClientHelloResumeThrottle(t *testing.T) { defer client.CloseWithBye() timing.expectedSleep = 100 * time.Millisecond - if err := client.SendHelloResume("this-is-invalid"); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloResume("this-is-invalid")) - if msg, err := client.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", msg.Error.Code) + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } client = NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId) + assert.NotEmpty(hello.Hello.SessionId) + assert.NotEmpty(hello.Hello.ResumeId) client.Close() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) // Perform housekeeping in the future, this will cause the session to be // cleaned up after it is expired. @@ -1647,16 +1410,11 @@ func TestClientHelloResumeThrottle(t *testing.T) { // Valid but expired resume ids will not be throttled. timing.expectedSleep = 0 * time.Millisecond - if err := client.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } - if msg, err := client.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", msg.Error.Code) + require.NoError(client.SendHelloResume(hello.Hello.ResumeId)) + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } @@ -1664,17 +1422,12 @@ func TestClientHelloResumeThrottle(t *testing.T) { defer client.CloseWithBye() timing.expectedSleep = 200 * time.Millisecond - if err := client.SendHelloResume("this-is-invalid"); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloResume("this-is-invalid")) - if msg, err := client.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", msg.Error.Code) + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } } @@ -1682,37 +1435,26 @@ func TestClientHelloResumeThrottle(t *testing.T) { func TestClientHelloResumeExpired(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client.Close() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) // Perform housekeeping in the future, this will cause the session to be // cleaned up after it is expired. @@ -1721,17 +1463,11 @@ func TestClientHelloResumeExpired(t *testing.T) { client = NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } - msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", msg.Error.Code) + require.NoError(client.SendHelloResume(hello.Hello.ResumeId)) + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } } @@ -1739,107 +1475,72 @@ func TestClientHelloResumeExpired(t *testing.T) { func TestClientHelloResumeTakeover(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client1.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + require.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloResume(hello.Hello.ResumeId)) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) // The first client got disconnected with a reason in a "Bye" message. - msg, err := client1.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "bye" || msg.Bye == nil { - t.Errorf("Expected bye message, got %+v", msg) - } else if msg.Bye.Reason != "session_resumed" { - t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("bye", msg.Type, "%+v", msg) + if assert.NotNil(msg.Bye, "%+v", msg) { + assert.Equal("session_resumed", msg.Bye.Reason, "%+v", msg) } } if msg, err := client1.RunUntilMessage(ctx); err == nil { - t.Errorf("Expected error but received %+v", msg) + assert.Fail("Expected error but received %+v", msg) } else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - t.Errorf("Expected close error but received %+v", err) + assert.Fail("Expected close error but received %+v", err) } } func TestClientHelloResumeOtherHub(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client.Close() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) // Simulate a restart of the hub. hub.sid.Store(0) @@ -1855,46 +1556,28 @@ func TestClientHelloResumeOtherHub(t *testing.T) { hub.mu.Lock() count := len(hub.sessions) hub.mu.Unlock() - if count > 0 { - t.Errorf("Should have removed all sessions (still has %d)", count) - } + assert.Equal(0, count, "Should have removed all sessions") // The new client will get the same (internal) sid for his session. newClient := NewTestClient(t, server, hub) defer newClient.CloseWithBye() - if err := newClient.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(newClient.SendHello(testDefaultUserId)) - if hello, err := newClient.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + hello2, err := newClient.RunUntilHello(ctx) + require.NoError(err) + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.NotEmpty(hello2.Hello.ResumeId, "%+v", hello2.Hello) // The previous session (which had the same internal sid) can't be resumed. client = NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } - msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", msg.Error.Code) + require.NoError(client.SendHelloResume(hello.Hello.ResumeId)) + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } @@ -1905,35 +1588,27 @@ func TestClientHelloResumeOtherHub(t *testing.T) { func TestClientHelloResumePublicId(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) // Test that a client can't resume a "public" session of another user. hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) recipient2 := MessageClientMessageRecipient{ Type: "session", @@ -1945,32 +1620,22 @@ func TestClientHelloResumePublicId(t *testing.T) { var payload string var sender *MessageServerMessageSender - if err := checkReceiveClientMessageWithSender(ctx, client2, "session", hello1.Hello, &payload, &sender); err != nil { - t.Error(err) - } else if payload != data { - t.Errorf("Expected payload %s, got %s", data, payload) + if err := checkReceiveClientMessageWithSender(ctx, client2, "session", hello1.Hello, &payload, &sender); assert.NoError(err) { + assert.Equal(data, payload) } client1.Close() - if err := client1.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client1.WaitForClientRemoved(ctx)) client1 = NewTestClient(t, server, hub) defer client1.CloseWithBye() // Can't resume a session with the id received from messages of a client. - if err := client1.SendHelloResume(sender.SessionId); err != nil { - t.Fatal(err) - } - msg, err := client1.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", msg.Error.Code) + require.NoError(client1.SendHelloResume(sender.SessionId)) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } @@ -1981,66 +1646,41 @@ func TestClientHelloResumePublicId(t *testing.T) { func TestClientHelloByeResume(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) - if err := client.SendBye(); err != nil { - t.Fatal(err) - } - if message, err := client.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { - if err := checkMessageType(message, "bye"); err != nil { - t.Error(err) - } + require.NoError(client.SendBye()) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageType(message, "bye")) } client.Close() - if err := client.WaitForSessionRemoved(ctx, hello.Hello.SessionId); err != nil { - t.Error(err) - } - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForSessionRemoved(ctx, hello.Hello.SessionId)) + assert.NoError(client.WaitForClientRemoved(ctx)) client = NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } - msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected \"error\", got %+v", *msg) - } else if msg.Error.Code != "no_such_session" { - t.Errorf("Expected error \"no_such_session\", got %+v", *msg) + require.NoError(client.SendHelloResume(hello.Hello.ResumeId)) + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("error", msg.Type, "%+v", msg) + if assert.NotNil(msg.Error, "%+v", msg) { + assert.Equal("no_such_session", msg.Error.Code, "%+v", msg) } } } @@ -2048,66 +1688,42 @@ func TestClientHelloByeResume(t *testing.T) { func TestClientHelloResumeAndJoin(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client.Close() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) client = NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloResume(hello.Hello.ResumeId)) hello2, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) } func runGrpcProxyTest(t *testing.T, f func(hub1, hub2 *Hub, server1, server2 *httptest.Server)) { @@ -2152,76 +1768,48 @@ func TestClientHelloResumeProxy(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { + require := require.New(t) + assert := assert.New(t) client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client1.Close() - if err := client1.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client1.WaitForClientRemoved(ctx)) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloResume(hello.Hello.ResumeId)) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) // Join room by id. roomId := "test-room" - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client2.RunUntilJoined(ctx, hello.Hello)) - if room := hub1.getRoom(roomId); room == nil { - t.Fatalf("Could not find room %s", roomId) - } - if room := hub2.getRoom(roomId); room != nil { - t.Fatalf("Should not have gotten room %s, got %+v", roomId, room) - } + room := hub1.getRoom(roomId) + require.NotNil(room, "Could not find room %s", roomId) + room2 := hub2.getRoom(roomId) + require.Nil(room2, "Should not have gotten room %s", roomId) users := []map[string]interface{}{ { @@ -2229,14 +1817,8 @@ func TestClientHelloResumeProxy(t *testing.T) { "inCall": 1, }, } - room := hub1.getRoom(roomId) - if room == nil { - t.Fatalf("Could not find room %s", roomId) - } room.PublishUsersInCallChanged(users, users) - if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil { - t.Error(err) - } + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) }) }) } @@ -2245,105 +1827,68 @@ func TestClientHelloResumeProxy_Takeover(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { + require := require.New(t) + assert := assert.New(t) client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + require.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloResume(hello.Hello.ResumeId)) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) // The first client got disconnected with a reason in a "Bye" message. - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { - if msg.Type != "bye" || msg.Bye == nil { - t.Errorf("Expected bye message, got %+v", msg) - } else if msg.Bye.Reason != "session_resumed" { - t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("bye", msg.Type, "%+v", msg) + if assert.NotNil(msg.Bye, "%+v", msg) { + assert.Equal("session_resumed", msg.Bye.Reason, "%+v", msg) } } if msg, err := client1.RunUntilMessage(ctx); err == nil { - t.Errorf("Expected error but received %+v", msg) + assert.Fail("Expected error but received %+v", msg) } else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - t.Errorf("Expected close error but received %+v", err) + assert.Fail("Expected close error but received %+v", err) } client3 := NewTestClient(t, server1, hub1) defer client3.CloseWithBye() - if err := client3.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } + require.NoError(client3.SendHelloResume(hello.Hello.ResumeId)) hello3, err := client3.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello3.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello3.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello3.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello3.Hello.UserId, "%+v", hello3.Hello) + assert.Equal(hello.Hello.SessionId, hello3.Hello.SessionId, "%+v", hello3.Hello) + assert.Equal(hello.Hello.ResumeId, hello3.Hello.ResumeId, "%+v", hello3.Hello) // The second client got disconnected with a reason in a "Bye" message. - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { - if msg.Type != "bye" || msg.Bye == nil { - t.Errorf("Expected bye message, got %+v", msg) - } else if msg.Bye.Reason != "session_resumed" { - t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("bye", msg.Type, "%+v", msg) + if assert.NotNil(msg.Bye, "%+v", msg) { + assert.Equal("session_resumed", msg.Bye.Reason, "%+v", msg) } } if msg, err := client2.RunUntilMessage(ctx); err == nil { - t.Errorf("Expected error but received %+v", msg) + assert.Fail("Expected error but received %+v", msg) } else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - t.Errorf("Expected close error but received %+v", err) + assert.Fail("Expected close error but received %+v", err) } }) }) @@ -2353,63 +1898,39 @@ func TestClientHelloResumeProxy_Disconnect(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { + require := require.New(t) + assert := assert.New(t) client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) client1.Close() - if err := client1.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client1.WaitForClientRemoved(ctx)) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloResume(hello.Hello.ResumeId)) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } - } + require.NoError(err) + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) // Simulate unclean shutdown of second instance. hub2.rpcServer.conn.Stop() - if err := client2.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client2.WaitForClientRemoved(ctx)) }) }) } @@ -2417,36 +1938,30 @@ func TestClientHelloResumeProxy_Disconnect(t *testing.T) { func TestClientHelloClient(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloClient(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloClient(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) } } func TestClientHelloClient_V3Api(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -2457,55 +1972,37 @@ func TestClientHelloClient_V3Api(t *testing.T) { } // The "/api/v1/signaling/" URL will be changed to use "v3" as the "signaling-v3" // feature is returned by the capabilities endpoint. - if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", HelloVersionV1, "client", nil, params); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", HelloVersionV1, "client", nil, params)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) } } func TestClientHelloInternal(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHelloInternal(); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHelloInternal()) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != "" { - t.Errorf("Expected empty user id, got %+v", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Empty(hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) } } @@ -2514,6 +2011,8 @@ func TestClientMessageToSessionId(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -2529,45 +2028,31 @@ func TestClientMessageToSessionId(t *testing.T) { } mcu1, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } + require.NoError(err) hub1.SetMcu(mcu1) if hub1 != hub2 { mcu2, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } + require.NoError(err) hub2.SetMcu(mcu2) } client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) recipient1 := MessageClientMessageRecipient{ Type: "session", @@ -2587,16 +2072,12 @@ func TestClientMessageToSessionId(t *testing.T) { client2.SendMessage(recipient1, data2) // nolint var payload1 string - if err := checkReceiveClientMessage(ctx, client1, "session", hello2.Hello, &payload1); err != nil { - t.Error(err) - } else if payload1 != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload1) + if err := checkReceiveClientMessage(ctx, client1, "session", hello2.Hello, &payload1); assert.NoError(err) { + assert.Equal(data2, payload1) } var payload2 map[string]interface{} - if err := checkReceiveClientMessage(ctx, client2, "session", hello1.Hello, &payload2); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(data1, payload2) { - t.Errorf("Expected payload %+v, got %+v", data1, payload2) + if err := checkReceiveClientMessage(ctx, client2, "session", hello1.Hello, &payload2); assert.NoError(err) { + assert.Equal(data1, payload2) } }) } @@ -2607,6 +2088,8 @@ func TestClientControlToSessionId(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -2623,30 +2106,20 @@ func TestClientControlToSessionId(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) recipient1 := MessageClientMessageRecipient{ Type: "session", @@ -2663,15 +2136,11 @@ func TestClientControlToSessionId(t *testing.T) { client2.SendControl(recipient1, data2) // nolint var payload string - if err := checkReceiveClientControl(ctx, client1, "session", hello2.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload) + if err := checkReceiveClientControl(ctx, client1, "session", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) } - if err := checkReceiveClientControl(ctx, client2, "session", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientControl(ctx, client2, "session", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } }) } @@ -2680,43 +2149,31 @@ func TestClientControlToSessionId(t *testing.T) { func TestClientControlMissingPermissions(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) - if session2 == nil { - t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) - } + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) // Client 1 may not send control messages (will be ignored). session1.SetPermissions([]Permission{ @@ -2745,57 +2202,44 @@ func TestClientControlMissingPermissions(t *testing.T) { client2.SendControl(recipient1, data2) // nolint var payload string - if err := checkReceiveClientControl(ctx, client1, "session", hello2.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload) + if err := checkReceiveClientControl(ctx, client1, "session", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) } ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if err := checkReceiveClientMessage(ctx2, client2, "session", hello1.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx2, client2, "session", hello1.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } } func TestClientMessageToUserId(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } else if hello1.Hello.UserId == hello2.Hello.UserId { - t.Fatalf("Expected different user ids, got %s twice", hello1.Hello.UserId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) recipient1 := MessageClientMessageRecipient{ Type: "user", @@ -2812,52 +2256,39 @@ func TestClientMessageToUserId(t *testing.T) { client2.SendMessage(recipient1, data2) // nolint var payload string - if err := checkReceiveClientMessage(ctx, client1, "user", hello2.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload) + if err := checkReceiveClientMessage(ctx, client1, "user", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) } - if err := checkReceiveClientMessage(ctx, client2, "user", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientMessage(ctx, client2, "user", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } } func TestClientControlToUserId(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } else if hello1.Hello.UserId == hello2.Hello.UserId { - t.Fatalf("Expected different user ids, got %s twice", hello1.Hello.UserId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) recipient1 := MessageClientMessageRecipient{ Type: "user", @@ -2874,70 +2305,49 @@ func TestClientControlToUserId(t *testing.T) { client2.SendControl(recipient1, data2) // nolint var payload string - if err := checkReceiveClientControl(ctx, client1, "user", hello2.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload) + if err := checkReceiveClientControl(ctx, client1, "user", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) } - if err := checkReceiveClientControl(ctx, client2, "user", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientControl(ctx, client2, "user", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } } func TestClientMessageToUserIdMultipleSessions(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2a := NewTestClient(t, server, hub) defer client2a.CloseWithBye() - if err := client2a.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2a.SendHello(testDefaultUserId + "2")) client2b := NewTestClient(t, server, hub) defer client2b.CloseWithBye() - if err := client2b.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2b.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2a, err := client2a.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2b, err := client2b.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2a.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } else if hello1.Hello.SessionId == hello2b.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } else if hello2a.Hello.SessionId == hello2b.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello2a.Hello.SessionId) - } - if hello1.Hello.UserId == hello2a.Hello.UserId { - t.Fatalf("Expected different user ids, got %s twice", hello1.Hello.UserId) - } else if hello1.Hello.UserId == hello2b.Hello.UserId { - t.Fatalf("Expected different user ids, got %s twice", hello1.Hello.UserId) - } else if hello2a.Hello.UserId != hello2b.Hello.UserId { - t.Fatalf("Expected the same user ids, got %s and %s", hello2a.Hello.UserId, hello2b.Hello.UserId) - } + require.NotEqual(hello1.Hello.SessionId, hello2a.Hello.SessionId) + require.NotEqual(hello1.Hello.SessionId, hello2b.Hello.SessionId) + require.NotEqual(hello2a.Hello.SessionId, hello2b.Hello.SessionId) + + require.NotEqual(hello1.Hello.UserId, hello2a.Hello.UserId) + require.NotEqual(hello1.Hello.UserId, hello2b.Hello.UserId) + require.Equal(hello2a.Hello.UserId, hello2b.Hello.UserId) recipient := MessageClientMessageRecipient{ Type: "user", @@ -2949,27 +2359,19 @@ func TestClientMessageToUserIdMultipleSessions(t *testing.T) { // Both clients will receive the message as it was sent to the user. var payload string - if err := checkReceiveClientMessage(ctx, client2a, "user", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientMessage(ctx, client2a, "user", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } - if err := checkReceiveClientMessage(ctx, client2b, "user", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientMessage(ctx, client2b, "user", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } } func WaitForUsersJoined(ctx context.Context, t *testing.T, client1 *TestClient, hello1 *ServerMessage, client2 *TestClient, hello2 *ServerMessage) { // We will receive "joined" events for all clients. The ordering is not // defined as messages are processed and sent by asynchronous event handlers. - if err := client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } - if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(t, client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) + assert.NoError(t, client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) } func TestClientMessageToRoom(t *testing.T) { @@ -2977,6 +2379,8 @@ func TestClientMessageToRoom(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -2996,46 +2400,31 @@ func TestClientMessageToRoom(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } else if hello1.Hello.UserId == hello2.Hello.UserId { - t.Fatalf("Expected different user ids, got %s twice", hello1.Hello.UserId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) @@ -3049,16 +2438,12 @@ func TestClientMessageToRoom(t *testing.T) { client2.SendMessage(recipient, data2) // nolint var payload string - if err := checkReceiveClientMessage(ctx, client1, "room", hello2.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload) + if err := checkReceiveClientMessage(ctx, client1, "room", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) } - if err := checkReceiveClientMessage(ctx, client2, "room", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientMessage(ctx, client2, "room", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } }) } @@ -3069,6 +2454,8 @@ func TestClientControlToRoom(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -3088,46 +2475,31 @@ func TestClientControlToRoom(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } else if hello1.Hello.UserId == hello2.Hello.UserId { - t.Fatalf("Expected different user ids, got %s twice", hello1.Hello.UserId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) @@ -3141,16 +2513,12 @@ func TestClientControlToRoom(t *testing.T) { client2.SendControl(recipient, data2) // nolint var payload string - if err := checkReceiveClientControl(ctx, client1, "room", hello2.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data2 { - t.Errorf("Expected payload %s, got %s", data2, payload) + if err := checkReceiveClientControl(ctx, client1, "room", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) } - if err := checkReceiveClientControl(ctx, client2, "room", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if payload != data1 { - t.Errorf("Expected payload %s, got %s", data1, payload) + if err := checkReceiveClientControl(ctx, client2, "room", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } }) } @@ -3159,78 +2527,63 @@ func TestClientControlToRoom(t *testing.T) { func TestJoinRoom(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Leave room. - if room, err := client.JoinRoom(ctx, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != "" { - t.Fatalf("Expected empty room, got %s", room.Room.RoomId) - } + roomMsg, err = client.JoinRoom(ctx, "") + require.NoError(err) + require.Equal("", roomMsg.Room.RoomId) } func TestJoinRoomTwice(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } else if !bytes.Equal(testRoomProperties, room.Room.Properties) { - t.Fatalf("Expected room properties %s, got %s", string(testRoomProperties), string(room.Room.Properties)) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + require.Equal(string(testRoomProperties), string(roomMsg.Room.Properties)) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) msg := &ClientMessage{ Id: "ABCD", @@ -3240,138 +2593,94 @@ func TestJoinRoomTwice(t *testing.T) { SessionId: roomId + "-" + client.publicId + "-2", }, } - if err := client.WriteJSON(msg); err != nil { - t.Fatal(err) - } + require.NoError(client.WriteJSON(msg)) message, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } - - if msg.Id != message.Id { - t.Errorf("expected message id %s, got %s", msg.Id, message.Id) - } else if err := checkMessageType(message, "error"); err != nil { - t.Fatal(err) - } else if expected := "already_joined"; message.Error.Code != expected { - t.Errorf("expected error %s, got %s", expected, message.Error.Code) - } else if message.Error.Details == nil { - t.Fatal("expected error details") - } + require.NoError(err) + require.NoError(checkUnexpectedClose(err)) - var roomMsg RoomErrorDetails - if err := json.Unmarshal(message.Error.Details, &roomMsg); err != nil { - t.Fatal(err) - } else if roomMsg.Room == nil { - t.Fatalf("expected room details, got %+v", message) + assert.Equal(msg.Id, message.Id) + if assert.NoError(checkMessageType(message, "error")) { + assert.Equal("already_joined", message.Error.Code) + assert.NotNil(message.Error.Details) } - if roomMsg.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %+v", roomId, roomMsg.Room) - } else if !bytes.Equal(testRoomProperties, roomMsg.Room.Properties) { - t.Fatalf("Expected room properties %s, got %s", string(testRoomProperties), string(roomMsg.Room.Properties)) + var roomDetails RoomErrorDetails + if json.Unmarshal(message.Error.Details, &roomDetails); assert.NoError(err) { + if assert.NotNil(roomDetails.Room, "%+v", message) { + assert.Equal(roomId, roomDetails.Room.RoomId) + assert.Equal(string(testRoomProperties), string(roomDetails.Room.Properties)) + } } } func TestExpectAnonymousJoinRoom(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(authAnonymousUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(authAnonymousUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != "" { - t.Errorf("Expected an anonymous user, got %+v", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } + if assert.NoError(err) { + assert.Empty(hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) } // Perform housekeeping in the future, this will cause the connection to // be terminated because the anonymous client didn't join a room. performHousekeeping(hub, time.Now().Add(anonmyousJoinRoomTimeout+time.Second)) - message, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - - if err := checkMessageType(message, "bye"); err != nil { - t.Error(err) - } else if message.Bye.Reason != "room_join_timeout" { - t.Errorf("Expected \"room_join_timeout\" reason, got %+v", message.Bye) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + if assert.NoError(checkMessageType(message, "bye")) { + assert.Equal("room_join_timeout", message.Bye.Reason, "%+v", message.Bye) + } } // Both the client and the session get removed from the hub. - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } - if err := client.WaitForSessionRemoved(ctx, hello.Hello.SessionId); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) + assert.NoError(client.WaitForSessionRemoved(ctx, hello.Hello.SessionId)) } func TestExpectAnonymousJoinRoomAfterLeave(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(authAnonymousUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(authAnonymousUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != "" { - t.Errorf("Expected an anonymous user, got %+v", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } + if assert.NoError(err) { + assert.Empty(hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.ResumeId, "%+v", hello.Hello) } // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Perform housekeeping in the future, this will keep the connection as the // session joined a room. @@ -3381,257 +2690,185 @@ func TestExpectAnonymousJoinRoomAfterLeave(t *testing.T) { ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel2() - if message, err := client.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } // Leave room - if room, err := client.JoinRoom(ctx, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != "" { - t.Fatalf("Expected room %s, got %s", "", room.Room.RoomId) - } + roomMsg, err = client.JoinRoom(ctx, "") + require.NoError(err) + require.Equal("", roomMsg.Room.RoomId) // Perform housekeeping in the future, this will cause the connection to // be terminated because the anonymous client didn't join a room. performHousekeeping(hub, time.Now().Add(anonmyousJoinRoomTimeout+time.Second)) - message, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - - if err := checkMessageType(message, "bye"); err != nil { - t.Error(err) - } else if message.Bye.Reason != "room_join_timeout" { - t.Errorf("Expected \"room_join_timeout\" reason, got %+v", message.Bye) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + if assert.NoError(checkMessageType(message, "bye")) { + assert.Equal("room_join_timeout", message.Bye.Reason, "%+v", message.Bye) + } } // Both the client and the session get removed from the hub. - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } - if err := client.WaitForSessionRemoved(ctx, hello.Hello.SessionId); err != nil { - t.Error(err) - } + assert.NoError(client.WaitForClientRemoved(ctx)) + assert.NoError(client.WaitForSessionRemoved(ctx, hello.Hello.SessionId)) } func TestJoinRoomChange(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Change room. roomId = "other-test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Leave room. - if room, err := client.JoinRoom(ctx, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != "" { - t.Fatalf("Expected empty room, got %s", room.Room.RoomId) - } + roomMsg, err = client.JoinRoom(ctx, "") + require.NoError(err) + require.Equal("", roomMsg.Room.RoomId) } func TestJoinMultiple(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) // Join room by id (first client). roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) // Join room by id (second client). - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event for the first and the second client. - if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) // The first client will also receive a "joined" event from the second client. - if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello2.Hello)) // Leave room. - if room, err := client1.JoinRoom(ctx, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != "" { - t.Fatalf("Expected empty room, got %s", room.Room.RoomId) - } + roomMsg, err = client1.JoinRoom(ctx, "") + require.NoError(err) + require.Equal("", roomMsg.Room.RoomId) // The second client will now receive a "left" event - if err := client2.RunUntilLeft(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client2.RunUntilLeft(ctx, hello1.Hello)) - if room, err := client2.JoinRoom(ctx, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != "" { - t.Fatalf("Expected empty room, got %s", room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, "") + require.NoError(err) + require.Equal("", roomMsg.Room.RoomId) } func TestJoinDisplaynamesPermission(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) - if session2 == nil { - t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) - } + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) // Client 2 may not receive display names. session2.SetPermissions([]Permission{PERMISSION_HIDE_DISPLAYNAMES}) // Join room by id (first client). roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) // Join room by id (second client). - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event for the first and the second client. - if events, unexpected, err := client2.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } else { - if len(unexpected) > 0 { - t.Errorf("Received unexpected messages: %+v", unexpected) - } else if len(events) != 2 { - t.Errorf("Expected two event, got %+v", events) - } else if events[0].User != nil { - t.Errorf("Expected empty userdata for first event, got %+v", events[0].User) - } else if events[1].User != nil { - t.Errorf("Expected empty userdata for second event, got %+v", events[1].User) + if events, unexpected, err := client2.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello); assert.NoError(err) { + assert.Empty(unexpected) + if assert.Len(events, 2) { + assert.Nil(events[0].User) + assert.Nil(events[1].User) } } // The first client will also receive a "joined" event from the second client. - if events, unexpected, err := client1.RunUntilJoinedAndReturn(ctx, hello2.Hello); err != nil { - t.Error(err) - } else { - if len(unexpected) > 0 { - t.Errorf("Received unexpected messages: %+v", unexpected) - } else if len(events) != 1 { - t.Errorf("Expected one event, got %+v", events) - } else if events[0].User == nil { - t.Errorf("Expected userdata for first event, got nothing") + if events, unexpected, err := client1.RunUntilJoinedAndReturn(ctx, hello2.Hello); assert.NoError(err) { + assert.Empty(unexpected) + if assert.Len(events, 1) { + assert.NotNil(events[0].User) } } } @@ -3639,6 +2876,8 @@ func TestJoinDisplaynamesPermission(t *testing.T) { func TestInitialRoomPermissions(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -3647,59 +2886,43 @@ func TestInitialRoomPermissions(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId + "1")) hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room-initial-permissions" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) session := hub.GetSessionByPublicId(hello.Hello.SessionId).(*ClientSession) - if session == nil { - t.Fatalf("Session %s does not exist", hello.Hello.SessionId) - } + require.NotNil(session, "Session %s does not exist", hello.Hello.SessionId) - if !session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO) { - t.Errorf("Session %s should have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_AUDIO, session.permissions) - } - if session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO) { - t.Errorf("Session %s should not have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_VIDEO, session.permissions) - } + assert.True(session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO), "Session %s should have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_AUDIO, session.permissions) + assert.False(session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO), "Session %s should not have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_VIDEO, session.permissions) } func TestJoinRoomSwitchClient(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room-slow" @@ -3711,64 +2934,37 @@ func TestJoinRoomSwitchClient(t *testing.T) { SessionId: roomId + "-" + hello.Hello.SessionId, }, } - if err := client.WriteJSON(msg); err != nil { - t.Fatal(err) - } + require.NoError(client.WriteJSON(msg)) // Wait a bit to make sure request is sent before closing client. time.Sleep(1 * time.Millisecond) client.Close() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Fatal(err) - } + require.NoError(client.WaitForClientRemoved(ctx)) // The client needs some time to reconnect. time.Sleep(200 * time.Millisecond) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { - t.Fatal(err) - } - hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello2.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) - } - if hello2.Hello.SessionId != hello.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) - } - if hello2.Hello.ResumeId != hello.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) - } + require.NoError(client2.SendHelloResume(hello.Hello.ResumeId)) + if hello2, err := client2.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello2.Hello.UserId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.SessionId, hello2.Hello.SessionId, "%+v", hello2.Hello) + assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello) } - room, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - if err := checkUnexpectedClose(err); err != nil { - t.Fatal(err) - } - if err := checkMessageType(room, "room"); err != nil { - t.Fatal(err) - } - if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client2.RunUntilMessage(ctx) + require.NoError(err) + require.NoError(checkUnexpectedClose(err)) + require.NoError(checkMessageType(roomMsg, "room")) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client2.RunUntilJoined(ctx, hello.Hello)) // Leave room. - if room, err := client2.JoinRoom(ctx, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != "" { - t.Fatalf("Expected empty room, got %s", room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, "") + require.NoError(err) + require.Equal("", roomMsg.Room.RoomId) } func TestGetRealUserIP(t *testing.T) { @@ -3986,56 +3182,43 @@ func TestGetRealUserIP(t *testing.T) { for _, tc := range testcases { trustedProxies, err := ParseAllowedIps(tc.trusted) - if err != nil { - t.Errorf("invalid trusted proxies in %+v: %s", tc, err) + if !assert.NoError(t, err, "invalid trusted proxies in %+v", tc) { continue } request := &http.Request{ RemoteAddr: tc.addr, Header: tc.headers, } - if ip := GetRealUserIP(request, trustedProxies); ip != tc.expected { - t.Errorf("Expected %s for %+v but got %s", tc.expected, tc, ip) - } + assert.Equal(t, tc.expected, GetRealUserIP(request, trustedProxies), "failed for %+v", tc) } } func TestClientMessageToSessionIdWhileDisconnected(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) client2.Close() - if err := client2.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client2.WaitForClientRemoved(ctx)) recipient2 := MessageClientMessageRecipient{ Type: "session", @@ -4045,9 +3228,7 @@ func TestClientMessageToSessionIdWhileDisconnected(t *testing.T) { // The two chat messages should get combined into one when receiving pending messages. chat_refresh := "{\"type\":\"chat\",\"chat\":{\"refresh\":true}}" var data1 map[string]interface{} - if err := json.Unmarshal([]byte(chat_refresh), &data1); err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal([]byte(chat_refresh), &data1)) client1.SendMessage(recipient2, data1) // nolint client1.SendMessage(recipient2, data1) // nolint @@ -4056,91 +3237,64 @@ func TestClientMessageToSessionIdWhileDisconnected(t *testing.T) { client2 = NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello2.Hello.ResumeId); err != nil { - t.Fatal(err) - } - hello3, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello3.Hello.UserId != testDefaultUserId+"2" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello3.Hello) - } - if hello3.Hello.SessionId != hello2.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, hello3.Hello) - } - if hello3.Hello.ResumeId != hello2.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello2.Hello.ResumeId, hello3.Hello) - } + require.NoError(client2.SendHelloResume(hello2.Hello.ResumeId)) + if hello3, err := client2.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId+"2", hello3.Hello.UserId, "%+v", hello3.Hello) + assert.Equal(hello2.Hello.SessionId, hello3.Hello.SessionId, "%+v", hello3.Hello) + assert.Equal(hello2.Hello.ResumeId, hello3.Hello.ResumeId, "%+v", hello3.Hello) } var payload map[string]interface{} - if err := checkReceiveClientMessage(ctx, client2, "session", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(payload, data1) { - t.Errorf("Expected payload %+v, got %+v", data1, payload) + if err := checkReceiveClientMessage(ctx, client2, "session", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if err := checkReceiveClientMessage(ctx2, client2, "session", hello1.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx2, client2, "session", hello1.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } } func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if hello1.Hello.SessionId == hello2.Hello.SessionId { - t.Fatalf("Expected different session ids, got %s twice", hello1.Hello.SessionId) - } + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) @@ -4152,18 +3306,12 @@ func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { }, } room := hub.getRoom(roomId) - if room == nil { - t.Fatalf("Could not find room %s", roomId) - } + require.NotNil(room, "Could not find room %s", roomId) room.PublishUsersInCallChanged(users, users) - if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil { - t.Error(err) - } + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) client2.Close() - if err := client2.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + assert.NoError(client2.WaitForClientRemoved(ctx)) room.PublishUsersInCallChanged(users, users) @@ -4177,53 +3325,34 @@ func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { chat_refresh := "{\"type\":\"chat\",\"chat\":{\"refresh\":true}}" var data1 map[string]interface{} - if err := json.Unmarshal([]byte(chat_refresh), &data1); err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal([]byte(chat_refresh), &data1)) client1.SendMessage(recipient2, data1) // nolint client2 = NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHelloResume(hello2.Hello.ResumeId); err != nil { - t.Fatal(err) - } - hello3, err := client2.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if hello3.Hello.UserId != testDefaultUserId+"2" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello3.Hello) - } - if hello3.Hello.SessionId != hello2.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, hello3.Hello) - } - if hello3.Hello.ResumeId != hello2.Hello.ResumeId { - t.Errorf("Expected resume id %s, got %+v", hello2.Hello.ResumeId, hello3.Hello) - } + require.NoError(client2.SendHelloResume(hello2.Hello.ResumeId)) + if hello3, err := client2.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId+"2", hello3.Hello.UserId, "%+v", hello3.Hello) + assert.Equal(hello2.Hello.SessionId, hello3.Hello.SessionId, "%+v", hello3.Hello) + assert.Equal(hello2.Hello.ResumeId, hello3.Hello.ResumeId, "%+v", hello3.Hello) } // The participants list update event is triggered again after the session resume. // TODO(jojo): Check contents of message and try with multiple users. - if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil { - t.Error(err) - } + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) var payload map[string]interface{} - if err := checkReceiveClientMessage(ctx, client2, "session", hello1.Hello, &payload); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(payload, data1) { - t.Errorf("Expected payload %+v, got %+v", data1, payload) + if err := checkReceiveClientMessage(ctx, client2, "session", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) } ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if err := checkReceiveClientMessage(ctx2, client2, "session", hello1.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx2, client2, "session", hello1.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } } @@ -4238,6 +3367,8 @@ func TestClientTakeoverRoomSession(t *testing.T) { } func RunTestClientTakeoverRoomSession(t *testing.T) { + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -4254,52 +3385,38 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room-takeover-room-session" roomSessionid := "room-session-id" - if room, err := client1.JoinRoomWithRoomSession(ctx, roomId, roomSessionid); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoomWithRoomSession(ctx, roomId, roomSessionid) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if hubRoom := hub1.getRoom(roomId); hubRoom == nil { - t.Fatalf("Room %s does not exist", roomId) - } + hubRoom := hub1.getRoom(roomId) + require.NotNil(hubRoom, "Room %s does not exist", roomId) - if session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId); session1 == nil { - t.Fatalf("There should be a session %s", hello1.Hello.SessionId) - } + session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId) + require.NotNil(session1, "There should be a session %s", hello1.Hello.SessionId) client3 := NewTestClient(t, server2, hub2) defer client3.CloseWithBye() - if err := client3.SendHello(testDefaultUserId + "3"); err != nil { - t.Fatal(err) - } + require.NoError(client3.SendHello(testDefaultUserId + "3")) hello3, err := client3.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if room, err := client3.JoinRoomWithRoomSession(ctx, roomId, roomSessionid+"other"); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client3.JoinRoomWithRoomSession(ctx, roomId, roomSessionid+"other") + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Wait until both users have joined. WaitForUsersJoined(ctx, t, client1, hello1, client3, hello3) @@ -4307,85 +3424,67 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if room, err := client2.JoinRoomWithRoomSession(ctx, roomId, roomSessionid); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoomWithRoomSession(ctx, roomId, roomSessionid) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // The first client got disconnected with a reason in a "Bye" message. - msg, err := client1.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "bye" || msg.Bye == nil { - t.Errorf("Expected bye message, got %+v", msg) - } else if msg.Bye.Reason != "room_session_reconnected" { - t.Errorf("Expected reason \"room_session_reconnected\", got %+v", msg.Bye.Reason) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("bye", msg.Type, "%+v", msg) + if assert.NotNil(msg.Bye, "%+v", msg) { + assert.Equal("room_session_reconnected", msg.Bye.Reason, "%+v", msg) } } if msg, err := client1.RunUntilMessage(ctx); err == nil { - t.Errorf("Expected error but received %+v", msg) + assert.Fail("Expected error but received %+v", msg) } else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - t.Errorf("Expected close error but received %+v", err) + assert.Fail("Expected close error but received %+v", err) } // The first session has been closed - if session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId); session1 != nil { - t.Errorf("The session %s should have been removed", hello1.Hello.SessionId) - } + session1 = hub1.GetSessionByPublicId(hello1.Hello.SessionId) + assert.Nil(session1, "The session %s should have been removed", hello1.Hello.SessionId) // The new client will receive "joined" events for the existing client3 and // himself. - if err := client2.RunUntilJoined(ctx, hello3.Hello, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client2.RunUntilJoined(ctx, hello3.Hello, hello2.Hello)) // No message about the closing is sent to the new connection. ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel2() - if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } // The permanently connected client will receive a "left" event from the // overridden session and a "joined" for the new session. In that order as // both were on the same server. - if err := client3.RunUntilLeft(ctx, hello1.Hello); err != nil { - t.Error(err) - } - if err := client3.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client3.RunUntilLeft(ctx, hello1.Hello)) + assert.NoError(client3.RunUntilJoined(ctx, hello2.Hello)) } func TestClientSendOfferPermissions(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -4393,54 +3492,38 @@ func TestClientSendOfferPermissions(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) - if session2 == nil { - t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) - } + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) // Client 1 is the moderator session1.SetPermissions([]Permission{PERMISSION_MAY_PUBLISH_MEDIA, PERMISSION_MAY_PUBLISH_SCREEN}) @@ -4448,26 +3531,20 @@ func TestClientSendOfferPermissions(t *testing.T) { session2.SetPermissions([]Permission{}) // Client 2 may not send an offer (he doesn't have the necessary permissions). - if err := client2.SendMessage(MessageClientMessageRecipient{ + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ Type: "sendoffer", Sid: "12345", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } - } + msg, err := client2.RunUntilMessage(ctx) + require.NoError(err) + require.NoError(checkMessageError(msg, "not_allowed")) - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -4477,56 +3554,47 @@ func TestClientSendOfferPermissions(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) // Client 1 may send an offer. - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello2.Hello.SessionId, }, MessageClientMessageData{ Type: "sendoffer", Sid: "54321", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) // The sender won't get a reply... ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel2() - if message, err := client1.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client1.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } // ...but the other peer will get an offer. - if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo)) } func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -4534,37 +3602,27 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotEmpty(session1, "Session %s does not exist", hello1.Hello.SessionId) // Client is allowed to send audio only. session1.SetPermissions([]Permission{PERMISSION_MAY_PUBLISH_AUDIO}) // Client may not send an offer with audio and video. - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -4574,20 +3632,14 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } - } + msg, err := client1.RunUntilMessage(ctx) + require.NoError(err) + require.NoError(checkMessageError(msg, "not_allowed")) // Client may send an offer (audio only). - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -4597,29 +3649,24 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioOnly, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioOnly); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioOnly)) } func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -4627,36 +3674,26 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) // Client is allowed to send audio and video. session1.SetPermissions([]Permission{PERMISSION_MAY_PUBLISH_AUDIO, PERMISSION_MAY_PUBLISH_VIDEO}) - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -4666,13 +3703,9 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) // Client is no longer allowed to send video, this will stop the publisher. msg := &BackendServerRoomRequest{ @@ -4694,34 +3727,25 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) ctx2, cancel2 := context.WithTimeout(ctx, time.Second) defer cancel2() pubs := mcu.GetPublishers() - if len(pubs) != 1 { - t.Fatalf("expected one publisher, got %+v", pubs) - } + require.Len(pubs, 1) loop: for { if err := ctx2.Err(); err != nil { - t.Errorf("publisher was not closed: %s", err) + assert.NoError(err, "publisher was not closed") + break } for _, pub := range pubs { @@ -4738,17 +3762,16 @@ loop: func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -4756,37 +3779,27 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) // Client is allowed to send audio and video. session1.SetPermissions([]Permission{PERMISSION_MAY_PUBLISH_MEDIA}) // Client may send an offer (audio and video). - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -4796,13 +3809,9 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) // Client is no longer allowed to send video, this will stop the publisher. msg := &BackendServerRoomRequest{ @@ -4824,42 +3833,31 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel2() pubs := mcu.GetPublishers() - if len(pubs) != 1 { - t.Fatalf("expected one publisher, got %+v", pubs) - } + require.Len(pubs, 1) loop: for { if err := ctx2.Err(); err != nil { if err != context.DeadlineExceeded { - t.Errorf("error while waiting for publisher: %s", err) + assert.Fail("error while waiting for publisher") } break } for _, pub := range pubs { - if pub.isClosed() { - t.Errorf("publisher was closed") + if !assert.False(pub.isClosed(), "publisher was closed") { break loop } } @@ -4874,6 +3872,8 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -4891,11 +3891,8 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub1.SetMcu(mcu) @@ -4904,41 +3901,29 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoomWithRoomSession(ctx, roomId, "roomsession1"); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoomWithRoomSession(ctx, roomId, "roomsession1") + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -4948,67 +3933,45 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) // Client 2 may not request an offer (he is not in the room yet). - if err := client2.SendMessage(MessageClientMessageRecipient{ + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ Type: "requestoffer", Sid: "12345", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } - } + msg, err := client2.RunUntilMessage(ctx) + require.NoError(err) + require.NoError(checkMessageError(msg, "not_allowed")) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } - if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } + require.NoError(client1.RunUntilJoined(ctx, hello2.Hello)) + require.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) // Client 2 may not request an offer (he is not in the call yet). - if err := client2.SendMessage(MessageClientMessageRecipient{ + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ Type: "requestoffer", Sid: "12345", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } - } + msg, err = client2.RunUntilMessage(ctx) + require.NoError(err) + require.NoError(checkMessageError(msg, "not_allowed")) // Simulate request from the backend that somebody joined the call. users1 := []map[string]interface{}{ @@ -5018,36 +3981,24 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { }, } room2 := hub2.getRoom(roomId) - if room2 == nil { - t.Fatalf("Could not find room %s", roomId) - } + require.NotNil(room2, "Could not find room %s", roomId) room2.PublishUsersInCallChanged(users1, users1) - if err := checkReceiveClientEvent(ctx, client1, "update", nil); err != nil { - t.Error(err) - } - if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil { - t.Error(err) - } + assert.NoError(checkReceiveClientEvent(ctx, client1, "update", nil)) + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) // Client 2 may not request an offer (recipient is not in the call yet). - if err := client2.SendMessage(MessageClientMessageRecipient{ + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ Type: "requestoffer", Sid: "12345", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } - } + msg, err = client2.RunUntilMessage(ctx) + require.NoError(err) + require.NoError(checkMessageError(msg, "not_allowed")) // Simulate request from the backend that somebody joined the call. users2 := []map[string]interface{}{ @@ -5057,34 +4008,24 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { }, } room1 := hub1.getRoom(roomId) - if room1 == nil { - t.Fatalf("Could not find room %s", roomId) - } + require.NotNil(room1, "Could not find room %s", roomId) room1.PublishUsersInCallChanged(users2, users2) - if err := checkReceiveClientEvent(ctx, client1, "update", nil); err != nil { - t.Error(err) - } - if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil { - t.Error(err) - } + assert.NoError(checkReceiveClientEvent(ctx, client1, "update", nil)) + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) // Client 2 may request an offer now (both are in the same room and call). - if err := client2.SendMessage(MessageClientMessageRecipient{ + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ Type: "requestoffer", Sid: "12345", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) - if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo)) - if err := client2.SendMessage(MessageClientMessageRecipient{ + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -5094,18 +4035,16 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpAnswerAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) // The sender won't get a reply... ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } }) } @@ -5114,6 +4053,8 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) // Clients can't send messages to sessions connected from other backends. hub, _, _, server := CreateHubWithMultipleBackendsForTest(t) @@ -5126,13 +4067,9 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: "user1", } - if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params1); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params1)) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() @@ -5140,13 +4077,9 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: "user2", } - if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", nil, params2); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", nil, params2)) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) recipient1 := MessageClientMessageRecipient{ Type: "session", @@ -5165,28 +4098,26 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { var payload string ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if err := checkReceiveClientMessage(ctx2, client1, "session", hello2.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx2, client1, "session", hello2.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel3() - if err := checkReceiveClientMessage(ctx3, client2, "session", hello1.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx3, client2, "session", hello1.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } } func TestNoSameRoomOnDifferentBackends(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubWithMultipleBackendsForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -5198,13 +4129,9 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: "user1", } - if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params1); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params1)) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() @@ -5212,40 +4139,24 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: "user2", } - if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", nil, params2); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", nil, params2)) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } - msg1, err := client1.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if err := client1.checkMessageJoined(msg1, hello1.Hello); err != nil { - t.Error(err) + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + if msg1, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client1.checkMessageJoined(msg1, hello1.Hello)) } - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } - msg2, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if err := client2.checkMessageJoined(msg2, hello2.Hello); err != nil { - t.Error(err) + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + if msg2, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client2.checkMessageJoined(msg2, hello2.Hello)) } hub.ru.RLock() @@ -5256,12 +4167,10 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { } hub.ru.RUnlock() - if len(rooms) != 2 { - t.Errorf("Expected 2 rooms, got %+v", rooms) - } - - if rooms[0].IsEqual(rooms[1]) { - t.Errorf("Rooms should be different: %+v", rooms) + if assert.Len(rooms, 2) { + if rooms[0].IsEqual(rooms[1]) { + assert.Fail("Rooms should be different: %+v", rooms) + } } recipient := MessageClientMessageRecipient{ @@ -5276,22 +4185,18 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { var payload string ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if err := checkReceiveClientMessage(ctx2, client1, "session", hello2.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx2, client1, "session", hello2.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel3() - if err := checkReceiveClientMessage(ctx3, client2, "session", hello1.Hello, &payload); err != nil { - if err != ErrNoMessageReceived { - t.Error(err) - } + if err := checkReceiveClientMessage(ctx3, client2, "session", hello1.Hello, &payload); err == nil { + assert.Fail("Expected no payload, got %+v", payload) } else { - t.Errorf("Expected no payload, got %+v", payload) + assert.ErrorIs(err, ErrNoMessageReceived) } } @@ -5300,6 +4205,8 @@ func TestClientSendOffer(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -5317,11 +4224,8 @@ func TestClientSendOffer(t *testing.T) { defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub1.SetMcu(mcu) @@ -5330,47 +4234,35 @@ func TestClientSendOffer(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoomWithRoomSession(ctx, roomId, "roomsession1"); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoomWithRoomSession(ctx, roomId, "roomsession1") + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -5380,38 +4272,30 @@ func TestClientSendOffer(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo)) - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello2.Hello.SessionId, }, MessageClientMessageData{ Type: "sendoffer", RoomType: "video", - }); err != nil { - t.Fatal(err) - } + })) // The sender won't get a reply... ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel2() - if message, err := client1.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client1.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } // ...but the other peer will get an offer. - if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil { - t.Fatal(err) - } + require.NoError(client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo)) }) } } @@ -5419,17 +4303,16 @@ func TestClientSendOffer(t *testing.T) { func TestClientUnshareScreen(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() mcu, err := NewTestMCU() - if err != nil { - t.Fatal(err) - } else if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(mcu.Start(ctx)) defer mcu.Stop() hub.SetMcu(mcu) @@ -5437,33 +4320,23 @@ func TestClientUnshareScreen(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ @@ -5473,20 +4346,13 @@ func TestClientUnshareScreen(t *testing.T) { Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioOnly, }, - }); err != nil { - t.Fatal(err) - } + })) - if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioOnly); err != nil { - t.Fatal(err) - } + require.NoError(client1.RunUntilAnswer(ctx, MockSdpAnswerAudioOnly)) publisher := mcu.GetPublisher(hello1.Hello.SessionId) - if publisher == nil { - t.Fatalf("No publisher for %s found", hello1.Hello.SessionId) - } else if publisher.isClosed() { - t.Fatalf("Publisher %s should not be closed", hello1.Hello.SessionId) - } + require.NotNil(publisher, "No publisher for %s found", hello1.Hello.SessionId) + require.False(publisher.isClosed(), "Publisher %s should not be closed", hello1.Hello.SessionId) old := cleanupScreenPublisherDelay cleanupScreenPublisherDelay = time.Millisecond @@ -5494,22 +4360,18 @@ func TestClientUnshareScreen(t *testing.T) { cleanupScreenPublisherDelay = old }() - if err := client1.SendMessage(MessageClientMessageRecipient{ + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, }, MessageClientMessageData{ Type: "unshareScreen", Sid: "54321", RoomType: "screen", - }); err != nil { - t.Fatal(err) - } + })) time.Sleep(10 * time.Millisecond) - if !publisher.isClosed() { - t.Fatalf("Publisher %s should be closed", hello1.Hello.SessionId) - } + require.True(publisher.isClosed(), "Publisher %s should be closed", hello1.Hello.SessionId) } func TestVirtualClientSessions(t *testing.T) { @@ -5517,6 +4379,8 @@ func TestVirtualClientSessions(t *testing.T) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -5536,89 +4400,58 @@ func TestVirtualClientSessions(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId)) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "test-room" - if _, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } + _, err = client1.JoinRoom(ctx, roomId) + require.NoError(err) - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHelloInternal(); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHelloInternal()) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) - if session2 == nil { - t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) - } + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) - if _, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } + _, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) - if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello2.Hello)) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if msg, err := checkMessageParticipantsInCall(msg); err != nil { - t.Error(err) - } else if len(msg.Users) != 1 { - t.Errorf("Expected one user, got %+v", msg) - } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { - t.Errorf("Expected internal flag, got %+v", msg) - } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) - } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != 3 { - t.Errorf("Expected inCall flag 3, got %+v", msg) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + if msg, err := checkMessageParticipantsInCall(msg); assert.NoError(err) { + if assert.Len(msg.Users, 1) { + assert.Equal(true, msg.Users[0]["internal"], "%+v", msg) + assert.Equal(hello2.Hello.SessionId, msg.Users[0]["sessionId"], "%+v", msg) + assert.EqualValues(3, msg.Users[0]["inCall"], "%+v", msg) + } + } } _, unexpected, err := client2.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello) - if err != nil { - t.Error(err) - } + assert.NoError(err) if len(unexpected) == 0 { - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else { + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { unexpected = append(unexpected, msg) } } - if len(unexpected) != 1 { - t.Fatalf("expected one message, got %+v", unexpected) - } - - if msg, err := checkMessageParticipantsInCall(unexpected[0]); err != nil { - t.Error(err) - } else if len(msg.Users) != 1 { - t.Errorf("Expected one user, got %+v", msg) - } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { - t.Errorf("Expected internal flag, got %+v", msg) - } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) - } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio { - t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg) + require.Len(unexpected, 1) + if msg, err := checkMessageParticipantsInCall(unexpected[0]); assert.NoError(err) { + if assert.Len(msg.Users, 1) { + assert.Equal(true, msg.Users[0]["internal"]) + assert.Equal(hello2.Hello.SessionId, msg.Users[0]["sessionId"]) + assert.EqualValues(FlagInCall|FlagWithAudio, msg.Users[0]["inCall"]) + } } calledCtx, calledCancel := context.WithTimeout(ctx, time.Second) @@ -5629,30 +4462,23 @@ func TestVirtualClientSessions(t *testing.T) { setSessionRequestHandler(t, func(request *BackendClientSessionRequest) { defer calledCancel() - if request.Action != "add" { - t.Errorf("Expected action add, got %+v", request) - } else if request.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, request) - } else if request.SessionId == generatedSessionId { - t.Errorf("Expected generated session id %s, got %+v", generatedSessionId, request) - } else if request.UserId != virtualUserId { - t.Errorf("Expected session id %s, got %+v", virtualUserId, request) - } + assert.Equal("add", request.Action, "%+v", request) + assert.Equal(roomId, request.RoomId, "%+v", request) + assert.NotEqual(generatedSessionId, request.SessionId, "%+v", request) + assert.Equal(virtualUserId, request.UserId, "%+v", request) }) - if err := client2.SendInternalAddSession(&AddSessionInternalClientMessage{ + require.NoError(client2.SendInternalAddSession(&AddSessionInternalClientMessage{ CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ SessionId: virtualSessionId, RoomId: roomId, }, UserId: virtualUserId, Flags: FLAG_MUTED_SPEAKING, - }); err != nil { - t.Fatal(err) - } + })) <-calledCtx.Done() - if err := calledCtx.Err(); err != nil && !errors.Is(err, context.Canceled) { - t.Fatal(err) + if err := calledCtx.Err(); err != nil { + require.ErrorIs(err, context.Canceled) } virtualSessions := session2.GetVirtualSessions() @@ -5662,131 +4488,92 @@ func TestVirtualClientSessions(t *testing.T) { } virtualSession := virtualSessions[0] - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if err := client1.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId); err != nil { - t.Error(err) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client1.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId)) } - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if msg, err := checkMessageParticipantsInCall(msg); err != nil { - t.Error(err) - } else if len(msg.Users) != 2 { - t.Errorf("Expected two users, got %+v", msg) - } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { - t.Errorf("Expected internal flag, got %+v", msg) - } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) - } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio { - t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg) - } else if v, ok := msg.Users[1]["virtual"].(bool); !ok || !v { - t.Errorf("Expected virtual flag, got %+v", msg) - } else if v, ok := msg.Users[1]["sessionId"].(string); !ok || v != virtualSession.PublicId() { - t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) - } else if v, ok := msg.Users[1]["inCall"].(float64); !ok || v != FlagInCall|FlagWithPhone { - t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithPhone, msg) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + if msg, err := checkMessageParticipantsInCall(msg); assert.NoError(err) { + if assert.Len(msg.Users, 2) { + assert.Equal(true, msg.Users[0]["internal"], "%+v", msg) + assert.Equal(hello2.Hello.SessionId, msg.Users[0]["sessionId"], "%+v", msg) + assert.EqualValues(FlagInCall|FlagWithAudio, msg.Users[0]["inCall"], "%+v", msg) + + assert.Equal(true, msg.Users[1]["virtual"], "%+v", msg) + assert.Equal(virtualSession.PublicId(), msg.Users[1]["sessionId"], "%+v", msg) + assert.EqualValues(FlagInCall|FlagWithPhone, msg.Users[1]["inCall"], "%+v", msg) + } + } } - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if flags, err := checkMessageParticipantFlags(msg); err != nil { - t.Error(err) - } else if flags.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, msg) - } else if flags.SessionId != virtualSession.PublicId() { - t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) - } else if flags.Flags != FLAG_MUTED_SPEAKING { - t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, msg) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + if flags, err := checkMessageParticipantFlags(msg); assert.NoError(err) { + assert.Equal(roomId, flags.RoomId) + assert.Equal(virtualSession.PublicId(), flags.SessionId) + assert.EqualValues(FLAG_MUTED_SPEAKING, flags.Flags) + } } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if err := client2.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId); err != nil { - t.Error(err) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client2.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId)) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if msg, err := checkMessageParticipantsInCall(msg); err != nil { - t.Error(err) - } else if len(msg.Users) != 2 { - t.Errorf("Expected two users, got %+v", msg) - } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { - t.Errorf("Expected internal flag, got %+v", msg) - } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { - t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) - } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio { - t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg) - } else if v, ok := msg.Users[1]["virtual"].(bool); !ok || !v { - t.Errorf("Expected virtual flag, got %+v", msg) - } else if v, ok := msg.Users[1]["sessionId"].(string); !ok || v != virtualSession.PublicId() { - t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) - } else if v, ok := msg.Users[1]["inCall"].(float64); !ok || v != FlagInCall|FlagWithPhone { - t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithPhone, msg) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + if msg, err := checkMessageParticipantsInCall(msg); assert.NoError(err) { + if assert.Len(msg.Users, 2) { + assert.Equal(true, msg.Users[0]["internal"], "%+v", msg) + assert.Equal(hello2.Hello.SessionId, msg.Users[0]["sessionId"], "%+v", msg) + assert.EqualValues(FlagInCall|FlagWithAudio, msg.Users[0]["inCall"], "%+v", msg) + + assert.Equal(true, msg.Users[1]["virtual"], "%+v", msg) + assert.Equal(virtualSession.PublicId(), msg.Users[1]["sessionId"], "%+v", msg) + assert.EqualValues(FlagInCall|FlagWithPhone, msg.Users[1]["inCall"], "%+v", msg) + } + } } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if flags, err := checkMessageParticipantFlags(msg); err != nil { - t.Error(err) - } else if flags.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, msg) - } else if flags.SessionId != virtualSession.PublicId() { - t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) - } else if flags.Flags != FLAG_MUTED_SPEAKING { - t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, msg) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + if flags, err := checkMessageParticipantFlags(msg); assert.NoError(err) { + assert.Equal(roomId, flags.RoomId) + assert.Equal(virtualSession.PublicId(), flags.SessionId) + assert.EqualValues(FLAG_MUTED_SPEAKING, flags.Flags) + } } updatedFlags := uint32(0) - if err := client2.SendInternalUpdateSession(&UpdateSessionInternalClientMessage{ + require.NoError(client2.SendInternalUpdateSession(&UpdateSessionInternalClientMessage{ CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ SessionId: virtualSessionId, RoomId: roomId, }, Flags: &updatedFlags, - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if flags, err := checkMessageParticipantFlags(msg); err != nil { - t.Error(err) - } else if flags.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, msg) - } else if flags.SessionId != virtualSession.PublicId() { - t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) - } else if flags.Flags != 0 { - t.Errorf("Expected flags %d, got %+v", 0, msg) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + if flags, err := checkMessageParticipantFlags(msg); assert.NoError(err) { + assert.Equal(roomId, flags.RoomId) + assert.Equal(virtualSession.PublicId(), flags.SessionId) + assert.EqualValues(0, flags.Flags) + } } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if flags, err := checkMessageParticipantFlags(msg); err != nil { - t.Error(err) - } else if flags.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, msg) - } else if flags.SessionId != virtualSession.PublicId() { - t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) - } else if flags.Flags != 0 { - t.Errorf("Expected flags %d, got %+v", 0, msg) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + if flags, err := checkMessageParticipantFlags(msg); assert.NoError(err) { + assert.Equal(roomId, flags.RoomId) + assert.Equal(virtualSession.PublicId(), flags.SessionId) + assert.EqualValues(0, flags.Flags) + } } calledCtx, calledCancel = context.WithTimeout(ctx, time.Second) setSessionRequestHandler(t, func(request *BackendClientSessionRequest) { defer calledCancel() - if request.Action != "remove" { - t.Errorf("Expected action remove, got %+v", request) - } else if request.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, request) - } else if request.SessionId == generatedSessionId { - t.Errorf("Expected generated session id %s, got %+v", generatedSessionId, request) - } else if request.UserId != virtualUserId { - t.Errorf("Expected user id %s, got %+v", virtualUserId, request) - } + assert.Equal("remove", request.Action, "%+v", request) + assert.Equal(roomId, request.RoomId, "%+v", request) + assert.NotEqual(generatedSessionId, request.SessionId, "%+v", request) + assert.Equal(virtualUserId, request.UserId, "%+v", request) }) // Messages to virtual sessions are sent to the associated client session. @@ -5801,52 +4588,39 @@ func TestVirtualClientSessions(t *testing.T) { var payload string var sender *MessageServerMessageSender var recipient *MessageClientMessageRecipient - if err := checkReceiveClientMessageWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); err != nil { - t.Error(err) - } else if recipient.SessionId != virtualSessionId { - t.Errorf("Expected session id %s, got %+v", virtualSessionId, recipient) - } else if payload != data { - t.Errorf("Expected payload %s, got %s", data, payload) + if err := checkReceiveClientMessageWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); assert.NoError(err) { + assert.Equal(virtualSessionId, recipient.SessionId, "%+v", recipient) + assert.Equal(data, payload) } data = "control-to-virtual" client1.SendControl(virtualRecipient, data) // nolint - if err := checkReceiveClientControlWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); err != nil { - t.Error(err) - } else if recipient.SessionId != virtualSessionId { - t.Errorf("Expected session id %s, got %+v", virtualSessionId, recipient) - } else if payload != data { - t.Errorf("Expected payload %s, got %s", data, payload) + if err := checkReceiveClientControlWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); assert.NoError(err) { + assert.Equal(virtualSessionId, recipient.SessionId, "%+v", recipient) + assert.Equal(data, payload) } - if err := client2.SendInternalRemoveSession(&RemoveSessionInternalClientMessage{ + require.NoError(client2.SendInternalRemoveSession(&RemoveSessionInternalClientMessage{ CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ SessionId: virtualSessionId, RoomId: roomId, }, UserId: virtualUserId, - }); err != nil { - t.Fatal(err) - } + })) <-calledCtx.Done() if err := calledCtx.Err(); err != nil && !errors.Is(err, context.Canceled) { - t.Fatal(err) + require.NoError(err) } - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if err := client1.checkMessageRoomLeaveSession(msg, virtualSession.PublicId()); err != nil { - t.Error(err) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client1.checkMessageRoomLeaveSession(msg, virtualSession.PublicId())) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if err := client2.checkMessageRoomLeaveSession(msg, virtualSession.PublicId()); err != nil { - t.Error(err) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client1.checkMessageRoomLeaveSession(msg, virtualSession.PublicId())) } - }) } } @@ -5856,6 +4630,8 @@ func DoTestSwitchToOne(t *testing.T, details map[string]interface{}) { for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -5872,63 +4648,45 @@ func DoTestSwitchToOne(t *testing.T, details map[string]interface{}) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomSessionId1 := "roomsession1" roomId1 := "test-room" - if room, err := client1.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId1); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId1 { - t.Fatalf("Expected room %s, got %s", roomId1, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId1) + require.NoError(err) + require.Equal(roomId1, roomMsg.Room.RoomId) roomSessionId2 := "roomsession2" - if room, err := client2.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId2); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId1 { - t.Fatalf("Expected room %s, got %s", roomId1, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId2) + require.NoError(err) + require.Equal(roomId1, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } - if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) + assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) roomId2 := "test-room-2" var sessions json.RawMessage if details != nil { - if sessions, err = json.Marshal(map[string]interface{}{ + sessions, err = json.Marshal(map[string]interface{}{ roomSessionId1: details, - }); err != nil { - t.Fatal(err) - } + }) + require.NoError(err) } else { - if sessions, err = json.Marshal([]string{ + sessions, err = json.Marshal([]string{ roomSessionId1, - }); err != nil { - t.Fatal(err) - } + }) + require.NoError(err) } // Notify first client to switch to different room. @@ -5941,40 +4699,30 @@ func DoTestSwitchToOne(t *testing.T, details map[string]interface{}) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server2.URL+"/api/v1/room/"+roomId1, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) var detailsData json.RawMessage if details != nil { - if detailsData, err = json.Marshal(details); err != nil { - t.Fatal(err) - } - } - if _, err := client1.RunUntilSwitchTo(ctx, roomId2, detailsData); err != nil { - t.Error(err) + detailsData, err = json.Marshal(details) + require.NoError(err) } + _, err = client1.RunUntilSwitchTo(ctx, roomId2, detailsData) + assert.NoError(err) // The other client will not receive a message. ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel2() - if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message, got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) } }) } @@ -5995,6 +4743,8 @@ func DoTestSwitchToMultiple(t *testing.T, details1 map[string]interface{}, detai for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + require := require.New(t) + assert := assert.New(t) var hub1 *Hub var hub2 *Hub var server1 *httptest.Server @@ -6011,65 +4761,47 @@ func DoTestSwitchToMultiple(t *testing.T, details1 map[string]interface{}, detai client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomSessionId1 := "roomsession1" roomId1 := "test-room" - if room, err := client1.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId1); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId1 { - t.Fatalf("Expected room %s, got %s", roomId1, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId1) + require.NoError(err) + require.Equal(roomId1, roomMsg.Room.RoomId) roomSessionId2 := "roomsession2" - if room, err := client2.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId2); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId1 { - t.Fatalf("Expected room %s, got %s", roomId1, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoomWithRoomSession(ctx, roomId1, roomSessionId2) + require.NoError(err) + require.Equal(roomId1, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } - if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) + assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) roomId2 := "test-room-2" var sessions json.RawMessage if details1 != nil || details2 != nil { - if sessions, err = json.Marshal(map[string]interface{}{ + sessions, err = json.Marshal(map[string]interface{}{ roomSessionId1: details1, roomSessionId2: details2, - }); err != nil { - t.Fatal(err) - } + }) + require.NoError(err) } else { - if sessions, err = json.Marshal([]string{ + sessions, err = json.Marshal([]string{ roomSessionId1, roomSessionId2, - }); err != nil { - t.Fatal(err) - } + }) + require.NoError(err) } msg := &BackendServerRoomRequest{ @@ -6081,41 +4813,29 @@ func DoTestSwitchToMultiple(t *testing.T, details1 map[string]interface{}, detai } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server2.URL+"/api/v1/room/"+roomId1, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) var detailsData1 json.RawMessage if details1 != nil { - if detailsData1, err = json.Marshal(details1); err != nil { - t.Fatal(err) - } - } - if _, err := client1.RunUntilSwitchTo(ctx, roomId2, detailsData1); err != nil { - t.Error(err) + detailsData1, err = json.Marshal(details1) + require.NoError(err) } + _, err = client1.RunUntilSwitchTo(ctx, roomId2, detailsData1) + assert.NoError(err) var detailsData2 json.RawMessage if details2 != nil { - if detailsData2, err = json.Marshal(details2); err != nil { - t.Fatal(err) - } - } - if _, err := client2.RunUntilSwitchTo(ctx, roomId2, detailsData2); err != nil { - t.Error(err) + detailsData2, err = json.Marshal(details2) + require.NoError(err) } + _, err = client2.RunUntilSwitchTo(ctx, roomId2, detailsData2) + assert.NoError(err) }) } } @@ -6141,6 +4861,7 @@ func TestSwitchToMultipleMixed(t *testing.T) { func TestGeoipOverrides(t *testing.T) { t.Parallel() CatchLogForTest(t) + assert := assert.New(t) country1 := "DE" country2 := "IT" country3 := "site1" @@ -6156,66 +4877,43 @@ func TestGeoipOverrides(t *testing.T) { return conf, err }) - if country := hub.OnLookupCountry(&Client{addr: "127.0.0.1"}); country != loopback { - t.Errorf("expected country %s, got %s", loopback, country) - } - - if country := hub.OnLookupCountry(&Client{addr: "8.8.8.8"}); country != unknownCountry { - t.Errorf("expected country %s, got %s", unknownCountry, country) - } - - if country := hub.OnLookupCountry(&Client{addr: "10.1.1.2"}); country != country1 { - t.Errorf("expected country %s, got %s", country1, country) - } - - if country := hub.OnLookupCountry(&Client{addr: "10.2.1.2"}); country != country2 { - t.Errorf("expected country %s, got %s", country2, country) - } - - if country := hub.OnLookupCountry(&Client{addr: "192.168.10.20"}); country != strings.ToUpper(country3) { - t.Errorf("expected country %s, got %s", strings.ToUpper(country3), country) - } + assert.Equal(loopback, hub.OnLookupCountry(&Client{addr: "127.0.0.1"})) + assert.Equal(unknownCountry, hub.OnLookupCountry(&Client{addr: "8.8.8.8"})) + assert.Equal(country1, hub.OnLookupCountry(&Client{addr: "10.1.1.2"})) + assert.Equal(country2, hub.OnLookupCountry(&Client{addr: "10.2.1.2"})) + assert.Equal(strings.ToUpper(country3), hub.OnLookupCountry(&Client{addr: "192.168.10.20"})) } func TestDialoutStatus(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) internalClient := NewTestClient(t, server, hub) defer internalClient.CloseWithBye() - if err := internalClient.SendHelloInternalWithFeatures([]string{"start-dialout"}); err != nil { - t.Fatal(err) - } + require.NoError(internalClient.SendHelloInternalWithFeatures([]string{"start-dialout"})) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() _, err := internalClient.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) roomId := "12345" client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if _, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } + _, err = client.JoinRoom(ctx, roomId) + require.NoError(err) - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) callId := "call-123" @@ -6224,19 +4922,18 @@ func TestDialoutStatus(t *testing.T) { defer close(stopped) msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) + if !assert.NoError(err) { return } - if msg.Type != "internal" || msg.Internal.Type != "dialout" { - t.Errorf("expected internal dialout message, got %+v", msg) + if !assert.Equal("internal", msg.Type, "%+v", msg) || + !assert.NotNil(msg.Internal, "%+v", msg) || + !assert.Equal("dialout", msg.Internal.Type, "%+v", msg) || + !assert.NotNil(msg.Internal.Dialout, "%+v", msg) { return } - if msg.Internal.Dialout.RoomId != roomId { - t.Errorf("expected room id %s, got %+v", roomId, msg) - } + assert.Equal(roomId, msg.Internal.Dialout.RoomId) response := &ClientMessage{ Id: msg.Id, @@ -6253,9 +4950,7 @@ func TestDialoutStatus(t *testing.T) { }, }, } - if err := client.WriteJSON(response); err != nil { - t.Error(err) - } + assert.NoError(client.WriteJSON(response)) }(internalClient) defer func() { @@ -6270,73 +4965,48 @@ func TestDialoutStatus(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusOK { - t.Fatalf("Expected error %d, got %s: %s", http.StatusOK, res.Status, string(body)) - } + assert.NoError(err) + require.Equal(http.StatusOK, res.StatusCode, "Expected success, got %s", string(body)) var response BackendServerRoomResponse - if err := json.Unmarshal(body, &response); err != nil { - t.Fatal(err) - } - - if response.Type != "dialout" || response.Dialout == nil { - t.Fatalf("expected type dialout, got %s", string(body)) - } - if response.Dialout.Error != nil { - t.Fatalf("expected dialout success, got %s", string(body)) - } - if response.Dialout.CallId != callId { - t.Errorf("expected call id %s, got %s", callId, string(body)) + if assert.NoError(json.Unmarshal(body, &response)) { + assert.Equal("dialout", response.Type) + if assert.NotNil(response.Dialout) { + assert.Nil(response.Dialout.Error, "expected dialout success, got %s", string(body)) + assert.Equal(callId, response.Dialout.CallId) + } } key := "callstatus_" + callId - if msg, err := client.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, key, map[string]interface{}{ + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageTransientSet(msg, key, map[string]interface{}{ "callid": callId, "status": "accepted", - }, nil); err != nil { - t.Error(err) - } + }, nil)) } - if err := internalClient.SendInternalDialout(&DialoutInternalClientMessage{ + require.NoError(internalClient.SendInternalDialout(&DialoutInternalClientMessage{ RoomId: roomId, Type: "status", Status: &DialoutStatusInternalClientMessage{ CallId: callId, Status: "ringing", }, - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, key, map[string]interface{}{ + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageTransientSet(msg, key, map[string]interface{}{ "callid": callId, "status": "ringing", - }, - map[string]interface{}{ - "callid": callId, - "status": "accepted", - }); err != nil { - t.Error(err) - } + }, map[string]interface{}{ + "callid": callId, + "status": "accepted", + })) } old := removeCallStatusTTL @@ -6346,7 +5016,7 @@ func TestDialoutStatus(t *testing.T) { removeCallStatusTTL = 500 * time.Millisecond clearedCause := "cleared-call" - if err := internalClient.SendInternalDialout(&DialoutInternalClientMessage{ + require.NoError(internalClient.SendInternalDialout(&DialoutInternalClientMessage{ RoomId: roomId, Type: "status", Status: &DialoutStatusInternalClientMessage{ @@ -6354,39 +5024,28 @@ func TestDialoutStatus(t *testing.T) { Status: "cleared", Cause: clearedCause, }, - }); err != nil { - t.Fatal(err) - } + })) - if msg, err := client.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, key, map[string]interface{}{ + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageTransientSet(msg, key, map[string]interface{}{ "callid": callId, "status": "cleared", "cause": clearedCause, - }, - map[string]interface{}{ - "callid": callId, - "status": "ringing", - }); err != nil { - t.Error(err) - } + }, map[string]interface{}{ + "callid": callId, + "status": "ringing", + })) } ctx2, cancel := context.WithTimeout(ctx, removeCallStatusTTL*2) defer cancel() - if msg, err := client.RunUntilMessage(ctx2); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientRemove(msg, key, map[string]interface{}{ + if msg, err := client.RunUntilMessage(ctx2); assert.NoError(err) { + assert.NoError(checkMessageTransientRemove(msg, key, map[string]interface{}{ "callid": callId, "status": "cleared", "cause": clearedCause, - }); err != nil { - t.Error(err) - } + })) } } @@ -6402,6 +5061,8 @@ func TestGracefulShutdownInitial(t *testing.T) { func TestGracefulShutdownOnBye(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -6410,18 +5071,15 @@ func TestGracefulShutdownOnBye(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) - if _, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } + _, err := client.RunUntilHello(ctx) + require.NoError(err) hub.ScheduleShutdown() select { case <-hub.ShutdownChannel(): - t.Error("should not have shutdown") + assert.Fail("should not have shutdown") case <-time.After(100 * time.Millisecond): } @@ -6430,13 +5088,15 @@ func TestGracefulShutdownOnBye(t *testing.T) { select { case <-hub.ShutdownChannel(): case <-time.After(100 * time.Millisecond): - t.Error("should have shutdown") + assert.Fail("should have shutdown") } } func TestGracefulShutdownOnExpiration(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -6445,25 +5105,22 @@ func TestGracefulShutdownOnExpiration(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) - if _, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } + _, err := client.RunUntilHello(ctx) + require.NoError(err) hub.ScheduleShutdown() select { case <-hub.ShutdownChannel(): - t.Error("should not have shutdown") + assert.Fail("should not have shutdown") case <-time.After(100 * time.Millisecond): } client.Close() select { case <-hub.ShutdownChannel(): - t.Error("should not have shutdown") + assert.Fail("should not have shutdown") case <-time.After(100 * time.Millisecond): } @@ -6472,6 +5129,6 @@ func TestGracefulShutdownOnExpiration(t *testing.T) { select { case <-hub.ShutdownChannel(): case <-time.After(100 * time.Millisecond): - t.Error("should have shutdown") + assert.Fail("should have shutdown") } } diff --git a/lru_test.go b/lru_test.go index 77db25e0..98fbb664 100644 --- a/lru_test.go +++ b/lru_test.go @@ -24,46 +24,36 @@ package signaling import ( "fmt" "testing" + + "github.com/stretchr/testify/assert" ) func TestLruUnbound(t *testing.T) { + assert := assert.New(t) lru := NewLruCache(0) count := 10 for i := 0; i < count; i++ { key := fmt.Sprintf("%d", i) lru.Set(key, i) } - if lru.Len() != count { - t.Errorf("Expected %d entries, got %d", count, lru.Len()) - } + assert.Equal(count, lru.Len()) for i := 0; i < count; i++ { key := fmt.Sprintf("%d", i) - value := lru.Get(key) - if value == nil { - t.Errorf("No value found for %s", key) - continue - } else if value.(int) != i { - t.Errorf("Expected value to be %d, got %d", value.(int), i) + if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) { + assert.EqualValues(i, value) } } // The first key ("0") is now the oldest. lru.RemoveOldest() - if lru.Len() != count-1 { - t.Errorf("Expected %d entries after RemoveOldest, got %d", count-1, lru.Len()) - } + assert.Equal(count-1, lru.Len()) for i := 0; i < count; i++ { key := fmt.Sprintf("%d", i) value := lru.Get(key) if i == 0 { - if value != nil { - t.Errorf("The value for key %s should have been removed", key) - } - continue - } else if value == nil { - t.Errorf("No value found for %s", key) + assert.Nil(value, "The value for key %s should have been removed", key) continue - } else if value.(int) != i { - t.Errorf("Expected value to be %d, got %d", value.(int), i) + } else if assert.NotNil(value, "No value found for %s", key) { + assert.EqualValues(i, value) } } @@ -74,66 +64,47 @@ func TestLruUnbound(t *testing.T) { key := fmt.Sprintf("%d", i) lru.Set(key, i) } - if lru.Len() != count-1 { - t.Errorf("Expected %d entries, got %d", count-1, lru.Len()) - } + assert.Equal(count-1, lru.Len()) // NOTE: The same ordering as the Set calls above. for i := count - 1; i >= 1; i-- { key := fmt.Sprintf("%d", i) - value := lru.Get(key) - if value == nil { - t.Errorf("No value found for %s", key) - continue - } else if value.(int) != i { - t.Errorf("Expected value to be %d, got %d", value.(int), i) + if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) { + assert.EqualValues(i, value) } } // The last key ("9") is now the oldest. lru.RemoveOldest() - if lru.Len() != count-2 { - t.Errorf("Expected %d entries after RemoveOldest, got %d", count-2, lru.Len()) - } + assert.Equal(count-2, lru.Len()) for i := 0; i < count; i++ { key := fmt.Sprintf("%d", i) value := lru.Get(key) if i == 0 || i == count-1 { - if value != nil { - t.Errorf("The value for key %s should have been removed", key) - } + assert.Nil(value, "The value for key %s should have been removed", key) continue - } else if value == nil { - t.Errorf("No value found for %s", key) - continue - } else if value.(int) != i { - t.Errorf("Expected value to be %d, got %d", value.(int), i) + } else if assert.NotNil(value, "No value found for %s", key) { + assert.EqualValues(i, value) } } // Remove an arbitrary key from the cache key := fmt.Sprintf("%d", count/2) lru.Remove(key) - if lru.Len() != count-3 { - t.Errorf("Expected %d entries after RemoveOldest, got %d", count-3, lru.Len()) - } + assert.Equal(count-3, lru.Len()) for i := 0; i < count; i++ { key := fmt.Sprintf("%d", i) value := lru.Get(key) if i == 0 || i == count-1 || i == count/2 { - if value != nil { - t.Errorf("The value for key %s should have been removed", key) - } + assert.Nil(value, "The value for key %s should have been removed", key) continue - } else if value == nil { - t.Errorf("No value found for %s", key) - continue - } else if value.(int) != i { - t.Errorf("Expected value to be %d, got %d", value.(int), i) + } else if assert.NotNil(value, "No value found for %s", key) { + assert.EqualValues(i, value) } } } func TestLruBound(t *testing.T) { + assert := assert.New(t) size := 2 lru := NewLruCache(size) count := 10 @@ -141,23 +112,16 @@ func TestLruBound(t *testing.T) { key := fmt.Sprintf("%d", i) lru.Set(key, i) } - if lru.Len() != size { - t.Errorf("Expected %d entries, got %d", size, lru.Len()) - } + assert.Equal(size, lru.Len()) // Only the last "size" entries have been stored. for i := 0; i < count; i++ { key := fmt.Sprintf("%d", i) value := lru.Get(key) if i < count-size { - if value != nil { - t.Errorf("The value for key %s should have been removed", key) - } - continue - } else if value == nil { - t.Errorf("No value found for %s", key) + assert.Nil(value, "The value for key %s should have been removed", key) continue - } else if value.(int) != i { - t.Errorf("Expected value to be %d, got %d", value.(int), i) + } else if assert.NotNil(value, "No value found for %s", key) { + assert.EqualValues(i, value) } } } diff --git a/mcu_janus_publisher_test.go b/mcu_janus_publisher_test.go index dd81e795..ab7d96a6 100644 --- a/mcu_janus_publisher_test.go +++ b/mcu_janus_publisher_test.go @@ -23,9 +23,12 @@ package signaling import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestGetFmtpValueH264(t *testing.T) { + assert := assert.New(t) testcases := []struct { fmtp string profile string @@ -51,16 +54,17 @@ func TestGetFmtpValueH264(t *testing.T) { for _, tc := range testcases { value, found := getFmtpValue(tc.fmtp, "profile-level-id") if !found && tc.profile != "" { - t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp) + assert.Fail("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp) } else if found && tc.profile == "" { - t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value) + assert.Fail("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value) } else if found && tc.profile != value { - t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value) + assert.Fail("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value) } } } func TestGetFmtpValueVP9(t *testing.T) { + assert := assert.New(t) testcases := []struct { fmtp string profile string @@ -82,11 +86,11 @@ func TestGetFmtpValueVP9(t *testing.T) { for _, tc := range testcases { value, found := getFmtpValue(tc.fmtp, "profile-id") if !found && tc.profile != "" { - t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp) + assert.Fail("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp) } else if found && tc.profile == "" { - t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value) + assert.Fail("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value) } else if found && tc.profile != value { - t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value) + assert.Fail("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value) } } } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index 39f12a9c..cfdb865f 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -41,6 +41,8 @@ import ( "github.com/dlintw/goconf" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/server/v3/embed" ) @@ -105,7 +107,7 @@ func Test_sortConnectionsForCountry(t *testing.T) { sorted := sortConnectionsForCountry(test[0], country, nil) for idx, conn := range sorted { if test[1][idx] != conn { - t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + assert.Fail(t, "Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) } } }) @@ -179,7 +181,7 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) { sorted := sortConnectionsForCountry(test[0], country, continentMap) for idx, conn := range sorted { if test[1][idx] != conn { - t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + assert.Fail(t, "Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) } } }) @@ -342,7 +344,7 @@ func (c *testProxyServerClient) handleSendMessageError(fmt string, msg *ProxySer c.t.Helper() if !errors.Is(err, websocket.ErrCloseSent) || msg.Type != "event" || msg.Event.Type != "update-load" { - c.t.Errorf(fmt, msg, err) + assert.Fail(c.t, fmt, msg, err) } } @@ -385,6 +387,7 @@ func (c *testProxyServerClient) run() { c.ws = nil }() c.processMessage = c.processHello + assert := assert.New(c.t) for { c.mu.Lock() ws := c.ws @@ -396,36 +399,31 @@ func (c *testProxyServerClient) run() { msgType, reader, err := ws.NextReader() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { - c.t.Error(err) + assert.NoError(err) } return } body, err := io.ReadAll(reader) - if err != nil { - c.t.Error(err) + if !assert.NoError(err) { continue } - if msgType != websocket.TextMessage { - c.t.Errorf("unexpected message type %q (%s)", msgType, string(body)) + if !assert.Equal(websocket.TextMessage, msgType, "unexpected message type for %s", string(body)) { continue } var msg ProxyClientMessage - if err := json.Unmarshal(body, &msg); err != nil { - c.t.Errorf("could not decode message %s: %s", string(body), err) + if err := json.Unmarshal(body, &msg); !assert.NoError(err, "could not decode message %s", string(body)) { continue } - if err := msg.CheckValid(); err != nil { - c.t.Errorf("invalid message %s: %s", string(body), err) + if err := msg.CheckValid(); !assert.NoError(err, "invalid message %s", string(body)) { continue } response, err := c.processMessage(&msg) - if err != nil { - c.t.Error(err) + if !assert.NoError(err) { continue } @@ -605,8 +603,7 @@ func (h *TestProxyServerHandler) removeClient(client *testProxyServerClient) { func (h *TestProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ws, err := h.upgrader.Upgrade(w, r, nil) - if err != nil { - h.t.Error(err) + if !assert.NoError(h.t, err) { return } @@ -658,15 +655,14 @@ type proxyTestOptions struct { func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuProxy { t.Helper() + require := require.New(t) if options.etcd == nil { options.etcd = NewEtcdForTest(t) } grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, options.etcd) tokenKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(err) dir := t.TempDir() privkeyFile := path.Join(dir, "privkey.pem") pubkeyFile := path.Join(dir, "pubkey.pem") @@ -696,19 +692,13 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP etcdConfig.AddOption("etcd", "loglevel", "error") etcdClient, err := NewEtcdClient(etcdConfig, "") - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { - if err := etcdClient.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, etcdClient.Close()) }) mcu, err := NewMcuProxy(cfg, etcdClient, grpcClients, dnsMonitor) - if err != nil { - t.Fatal(err) - } + require.NoError(err) t.Cleanup(func() { mcu.Stop() }) @@ -716,20 +706,14 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if err := mcu.Start(ctx); err != nil { - t.Fatal(err) - } + require.NoError(mcu.Start(ctx)) proxy := mcu.(*mcuProxy) - if err := proxy.WaitForConnections(ctx); err != nil { - t.Fatal(err) - } + require.NoError(proxy.WaitForConnections(ctx)) for len(waitingMap) > 0 { - if err := ctx.Err(); err != nil { - t.Fatal(err) - } + require.NoError(ctx.Err()) for u := range waitingMap { proxy.connectionsMu.RLock() @@ -782,9 +766,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) { } pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) @@ -795,9 +777,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) { country: "DE", } sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) } @@ -829,8 +809,7 @@ func Test_ProxyWaitForPublisher(t *testing.T) { go func() { defer close(done) sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Error(err) + if !assert.NoError(t, err) { return } @@ -841,14 +820,12 @@ func Test_ProxyWaitForPublisher(t *testing.T) { time.Sleep(100 * time.Millisecond) pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) select { case <-done: case <-ctx.Done(): - t.Error(ctx.Err()) + assert.NoError(t, ctx.Err()) } defer pub.Close(context.Background()) } @@ -875,9 +852,7 @@ func Test_ProxyPublisherBandwidth(t *testing.T) { country: "DE", } pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub1.Close(context.Background()) @@ -914,15 +889,11 @@ func Test_ProxyPublisherBandwidth(t *testing.T) { country: "DE", } pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub2.Close(context.Background()) - if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl { - t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl) - } + assert.NotEqual(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl) } func Test_ProxyPublisherBandwidthOverload(t *testing.T) { @@ -947,9 +918,7 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) { country: "DE", } pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub1.Close(context.Background()) @@ -989,15 +958,11 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) { country: "DE", } pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub2.Close(context.Background()) - if pub1.(*mcuProxyPublisher).conn.rawUrl != pub2.(*mcuProxyPublisher).conn.rawUrl { - t.Errorf("servers should be the same, got %s / %s", pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl) } func Test_ProxyPublisherLoad(t *testing.T) { @@ -1022,9 +987,7 @@ func Test_ProxyPublisherLoad(t *testing.T) { country: "DE", } pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub1.Close(context.Background()) @@ -1041,15 +1004,11 @@ func Test_ProxyPublisherLoad(t *testing.T) { country: "DE", } pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub2.Close(context.Background()) - if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl { - t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl) - } + assert.NotEqual(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl) } func Test_ProxyPublisherCountry(t *testing.T) { @@ -1074,15 +1033,11 @@ func Test_ProxyPublisherCountry(t *testing.T) { country: "DE", } pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pubDE.Close(context.Background()) - if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) pubUSId := "the-publisher-us" pubUSSid := "1234567890" @@ -1093,15 +1048,11 @@ func Test_ProxyPublisherCountry(t *testing.T) { country: "US", } pubUS, err := mcu.NewPublisher(ctx, pubUSListener, pubUSId, pubUSSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubUSInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pubUS.Close(context.Background()) - if pubUS.(*mcuProxyPublisher).conn.rawUrl != serverUS.URL { - t.Errorf("expected server %s, go %s", serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl) } func Test_ProxyPublisherContinent(t *testing.T) { @@ -1126,15 +1077,11 @@ func Test_ProxyPublisherContinent(t *testing.T) { country: "DE", } pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pubDE.Close(context.Background()) - if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) pubFRId := "the-publisher-fr" pubFRSid := "1234567890" @@ -1145,15 +1092,11 @@ func Test_ProxyPublisherContinent(t *testing.T) { country: "FR", } pubFR, err := mcu.NewPublisher(ctx, pubFRListener, pubFRId, pubFRSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubFRInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pubFR.Close(context.Background()) - if pubFR.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl) } func Test_ProxySubscriberCountry(t *testing.T) { @@ -1178,15 +1121,11 @@ func Test_ProxySubscriberCountry(t *testing.T) { country: "DE", } pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) - if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) subListener := &MockMcuListener{ publicId: "subscriber-public", @@ -1195,15 +1134,11 @@ func Test_ProxySubscriberCountry(t *testing.T) { country: "US", } sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) - if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL { - t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) - } + assert.Equal(t, serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) } func Test_ProxySubscriberContinent(t *testing.T) { @@ -1228,15 +1163,11 @@ func Test_ProxySubscriberContinent(t *testing.T) { country: "DE", } pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) - if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) subListener := &MockMcuListener{ publicId: "subscriber-public", @@ -1245,15 +1176,11 @@ func Test_ProxySubscriberContinent(t *testing.T) { country: "FR", } sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) - if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl) } func Test_ProxySubscriberBandwidth(t *testing.T) { @@ -1278,15 +1205,11 @@ func Test_ProxySubscriberBandwidth(t *testing.T) { country: "DE", } pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) - if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) serverDE.UpdateBandwidth(0, 100) @@ -1315,15 +1238,11 @@ func Test_ProxySubscriberBandwidth(t *testing.T) { country: "US", } sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) - if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL { - t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) - } + assert.Equal(t, serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) } func Test_ProxySubscriberBandwidthOverload(t *testing.T) { @@ -1348,15 +1267,11 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) { country: "DE", } pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) - if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) serverDE.UpdateBandwidth(0, 100) serverUS.UpdateBandwidth(0, 102) @@ -1386,15 +1301,11 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) { country: "US", } sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) - if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL { - t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl) - } + assert.Equal(t, serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl) } type mockGrpcServerHub struct { @@ -1490,9 +1401,7 @@ func Test_ProxyRemotePublisher(t *testing.T) { defer hub1.removeSession(session1) pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) @@ -1508,9 +1417,7 @@ func Test_ProxyRemotePublisher(t *testing.T) { country: "DE", } sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) } @@ -1580,8 +1487,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) { go func() { defer close(done) sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Error(err) + if !assert.NoError(t, err) { return } @@ -1592,9 +1498,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) { time.Sleep(100 * time.Millisecond) pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) @@ -1606,7 +1510,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) { select { case <-done: case <-ctx.Done(): - t.Error(ctx.Err()) + assert.NoError(t, ctx.Err()) } } @@ -1663,9 +1567,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) { defer hub1.removeSession(session1) pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer pub.Close(context.Background()) @@ -1677,9 +1579,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) { mcu2.connectionsMu.RLock() count := len(mcu2.connections) mcu2.connectionsMu.RUnlock() - if expected := 1; count != expected { - t.Errorf("expected %d connections, got %+v", expected, count) - } + assert.Equal(t, 1, count) subListener := &MockMcuListener{ publicId: "subscriber-public", @@ -1688,23 +1588,17 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) { country: "DE", } sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close(context.Background()) - if sub.(*mcuProxySubscriber).conn.rawUrl != server1.URL { - t.Errorf("expected server %s, go %s", server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl) - } + assert.Equal(t, server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl) // The temporary connection has been added mcu2.connectionsMu.RLock() count = len(mcu2.connections) mcu2.connectionsMu.RUnlock() - if expected := 2; count != expected { - t.Errorf("expected %d connections, got %+v", expected, count) - } + assert.Equal(t, 2, count) sub.Close(context.Background()) @@ -1713,7 +1607,7 @@ loop: for { select { case <-ctx.Done(): - t.Error(ctx.Err()) + assert.NoError(t, ctx.Err()) default: mcu2.connectionsMu.RLock() count = len(mcu2.connections) diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index ff8994c4..d6cf5dea 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -25,6 +25,9 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *testing.T) { @@ -39,7 +42,7 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t select { case <-ctx.Done(): c.mu.Lock() - t.Errorf("Error waiting for subscriptions %+v to terminate: %s", c.subscriptions, ctx.Err()) + assert.NoError(t, ctx.Err(), "Error waiting for subscriptions %+v to terminate", c.subscriptions) c.mu.Unlock() return default: @@ -50,9 +53,7 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { result, err := NewLoopbackNatsClient() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { result.Close() }) diff --git a/natsclient_test.go b/natsclient_test.go index dfe51af5..a28a0c75 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -27,6 +27,8 @@ import ( "time" "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" natsserver "github.com/nats-io/nats-server/v2/test" ) @@ -46,9 +48,7 @@ func startLocalNatsServer(t *testing.T) string { func CreateLocalNatsClientForTest(t *testing.T) NatsClient { url := startLocalNatsServer(t) result, err := NewNatsClient(url) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { result.Close() }) @@ -56,11 +56,11 @@ func CreateLocalNatsClientForTest(t *testing.T) NatsClient { } func testNatsClient_Subscribe(t *testing.T, client NatsClient) { + require := require.New(t) + assert := assert.New(t) dest := make(chan *nats.Msg) sub, err := client.Subscribe("foo", dest) - if err != nil { - t.Fatal(err) - } + require.NoError(err) ch := make(chan struct{}) var received atomic.Int32 @@ -75,9 +75,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { case <-dest: total := received.Add(1) if total == max { - err := sub.Unsubscribe() - if err != nil { - t.Errorf("Unsubscribe failed with err: %s", err) + if err := sub.Unsubscribe(); !assert.NoError(err) { return } close(ch) @@ -89,18 +87,14 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { }() <-ready for i := int32(0); i < max; i++ { - if err := client.Publish("foo", []byte("hello")); err != nil { - t.Error(err) - } + assert.NoError(client.Publish("foo", []byte("hello"))) // Allow NATS goroutines to process messages. time.Sleep(10 * time.Millisecond) } <-ch - if r := received.Load(); r != max { - t.Fatalf("Received wrong # of messages: %d vs %d", r, max) - } + require.EqualValues(max, received.Load(), "Received wrong # of messages") } func TestNatsClient_Subscribe(t *testing.T) { @@ -115,9 +109,7 @@ func TestNatsClient_Subscribe(t *testing.T) { func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) { client.Close() - if err := client.Publish("foo", "bar"); err != nats.ErrConnectionClosed { - t.Errorf("Expected %v, got %v", nats.ErrConnectionClosed, err) - } + assert.ErrorIs(t, client.Publish("foo", "bar"), nats.ErrConnectionClosed) } func TestNatsClient_PublishAfterClose(t *testing.T) { @@ -133,9 +125,8 @@ func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) { client.Close() ch := make(chan *nats.Msg) - if _, err := client.Subscribe("foo", ch); err != nats.ErrConnectionClosed { - t.Errorf("Expected %v, got %v", nats.ErrConnectionClosed, err) - } + _, err := client.Subscribe("foo", ch) + assert.ErrorIs(t, err, nats.ErrConnectionClosed) } func TestNatsClient_SubscribeAfterClose(t *testing.T) { @@ -148,6 +139,7 @@ func TestNatsClient_SubscribeAfterClose(t *testing.T) { } func testNatsClient_BadSubjects(t *testing.T, client NatsClient) { + assert := assert.New(t) subjects := []string{ "foo bar", "foo.", @@ -155,9 +147,8 @@ func testNatsClient_BadSubjects(t *testing.T, client NatsClient) { ch := make(chan *nats.Msg) for _, s := range subjects { - if _, err := client.Subscribe(s, ch); err != nats.ErrBadSubject { - t.Errorf("Expected %v for subject %s, got %v", nats.ErrBadSubject, s, err) - } + _, err := client.Subscribe(s, ch) + assert.ErrorIs(err, nats.ErrBadSubject, "Expected error for subject %s", s) } } diff --git a/notifier_test.go b/notifier_test.go index 51fc9207..c40d7f00 100644 --- a/notifier_test.go +++ b/notifier_test.go @@ -26,6 +26,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestNotifierNoWaiter(t *testing.T) { @@ -48,9 +50,7 @@ func TestNotifierSimple(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := waiter.Wait(ctx); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(ctx)) }() notifier.Notify("foo") @@ -74,9 +74,7 @@ func TestNotifierWaitClosed(t *testing.T) { waiter := notifier.NewWaiter("foo") notifier.Release(waiter) - if err := waiter.Wait(context.Background()); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(context.Background())) } func TestNotifierWaitClosedMulti(t *testing.T) { @@ -87,12 +85,8 @@ func TestNotifierWaitClosedMulti(t *testing.T) { notifier.Release(waiter1) notifier.Release(waiter2) - if err := waiter1.Wait(context.Background()); err != nil { - t.Error(err) - } - if err := waiter2.Wait(context.Background()); err != nil { - t.Error(err) - } + assert.NoError(t, waiter1.Wait(context.Background())) + assert.NoError(t, waiter2.Wait(context.Background())) } func TestNotifierResetWillNotify(t *testing.T) { @@ -108,9 +102,7 @@ func TestNotifierResetWillNotify(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := waiter.Wait(ctx); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(ctx)) }() notifier.Reset() @@ -137,9 +129,7 @@ func TestNotifierDuplicate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := waiter.Wait(ctx); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(ctx)) }() } diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 9af71c7d..2191ddb6 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -38,6 +38,8 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" signaling "github.com/strukturag/nextcloud-spreed-signaling" ) @@ -57,6 +59,7 @@ func getWebsocketUrl(url string) string { } func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httptest.Server) { + require := require.New(t) tempdir := t.TempDir() var proxy *ProxyServer t.Cleanup(func() { @@ -67,43 +70,32 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httpte r := mux.NewRouter() key, err := rsa.GenerateKey(rand.Reader, KeypairSizeForTest) - if err != nil { - t.Fatalf("could not generate key: %s", err) - } + require.NoError(err) priv := &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), } privkey, err := os.CreateTemp(tempdir, "privkey*.pem") - if err != nil { - t.Fatalf("could not create temporary file for private key: %s", err) - } - if err := pem.Encode(privkey, priv); err != nil { - t.Fatalf("could not encode private key: %s", err) - } + require.NoError(err) + require.NoError(pem.Encode(privkey, priv)) + require.NoError(privkey.Close()) pubData, err := x509.MarshalPKIXPublicKey(&key.PublicKey) - if err != nil { - t.Fatalf("could not marshal public key: %s", err) - } + require.NoError(err) pub := &pem.Block{ Type: "RSA PUBLIC KEY", Bytes: pubData, } pubkey, err := os.CreateTemp(tempdir, "pubkey*.pem") - if err != nil { - t.Fatalf("could not create temporary file for public key: %s", err) - } - if err := pem.Encode(pubkey, pub); err != nil { - t.Fatalf("could not encode public key: %s", err) - } + require.NoError(err) + require.NoError(pem.Encode(pubkey, pub)) + require.NoError(pubkey.Close()) config := goconf.NewConfigFile() config.AddOption("tokens", TokenIdForTest, pubkey.Name()) - if proxy, err = NewProxyServer(r, "0.0", config); err != nil { - t.Fatalf("could not create proxy server: %s", err) - } + proxy, err = NewProxyServer(r, "0.0", config) + require.NoError(err) server := httptest.NewServer(r) t.Cleanup(func() { @@ -125,19 +117,14 @@ func TestTokenValid(t *testing.T) { } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenString, err := token.SignedString(key) - if err != nil { - t.Fatalf("could not create token: %s", err) - } + require.NoError(t, err) hello := &signaling.HelloProxyClientMessage{ Version: "1.0", Token: tokenString, } - session, err := proxy.NewSession(hello) - if session != nil { + if session, err := proxy.NewSession(hello); !assert.NoError(t, err) { defer session.Close() - } else if err != nil { - t.Error(err) } } @@ -153,20 +140,16 @@ func TestTokenNotSigned(t *testing.T) { } token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) - if err != nil { - t.Fatalf("could not create token: %s", err) - } + require.NoError(t, err) hello := &signaling.HelloProxyClientMessage{ Version: "1.0", Token: tokenString, } - session, err := proxy.NewSession(hello) - if session != nil { - defer session.Close() - t.Errorf("should not have created session") - } else if err != TokenAuthFailed { - t.Errorf("could have failed with TokenAuthFailed, got %s", err) + if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) { + if session != nil { + defer session.Close() + } } } @@ -182,20 +165,16 @@ func TestTokenUnknown(t *testing.T) { } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenString, err := token.SignedString(key) - if err != nil { - t.Fatalf("could not create token: %s", err) - } + require.NoError(t, err) hello := &signaling.HelloProxyClientMessage{ Version: "1.0", Token: tokenString, } - session, err := proxy.NewSession(hello) - if session != nil { - defer session.Close() - t.Errorf("should not have created session") - } else if err != TokenAuthFailed { - t.Errorf("could have failed with TokenAuthFailed, got %s", err) + if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) { + if session != nil { + defer session.Close() + } } } @@ -211,20 +190,16 @@ func TestTokenInFuture(t *testing.T) { } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenString, err := token.SignedString(key) - if err != nil { - t.Fatalf("could not create token: %s", err) - } + require.NoError(t, err) hello := &signaling.HelloProxyClientMessage{ Version: "1.0", Token: tokenString, } - session, err := proxy.NewSession(hello) - if session != nil { - defer session.Close() - t.Errorf("should not have created session") - } else if err != TokenNotValidYet { - t.Errorf("could have failed with TokenNotValidYet, got %s", err) + if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenNotValidYet) { + if session != nil { + defer session.Close() + } } } @@ -240,24 +215,21 @@ func TestTokenExpired(t *testing.T) { } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenString, err := token.SignedString(key) - if err != nil { - t.Fatalf("could not create token: %s", err) - } + require.NoError(t, err) hello := &signaling.HelloProxyClientMessage{ Version: "1.0", Token: tokenString, } - session, err := proxy.NewSession(hello) - if session != nil { - defer session.Close() - t.Errorf("should not have created session") - } else if err != TokenExpired { - t.Errorf("could have failed with TokenExpired, got %s", err) + if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenExpired) { + if session != nil { + defer session.Close() + } } } func TestPublicIPs(t *testing.T) { + assert := assert.New(t) public := []string{ "8.8.8.8", "172.15.1.2", @@ -275,35 +247,30 @@ func TestPublicIPs(t *testing.T) { } for _, s := range public { ip := net.ParseIP(s) - if len(ip) == 0 { - t.Errorf("invalid IP: %s", s) - } else if !IsPublicIP(ip) { - t.Errorf("should be public IP: %s", s) + if assert.NotEmpty(ip, "invalid IP: %s", s) { + assert.True(IsPublicIP(ip), "should be public IP: %s", s) } } for _, s := range private { ip := net.ParseIP(s) - if len(ip) == 0 { - t.Errorf("invalid IP: %s", s) - } else if IsPublicIP(ip) { - t.Errorf("should be private IP: %s", s) + if assert.NotEmpty(ip, "invalid IP: %s", s) { + assert.False(IsPublicIP(ip), "should be private IP: %s", s) } } } func TestWebsocketFeatures(t *testing.T) { signaling.CatchLogForTest(t) + assert := assert.New(t) _, _, server := newProxyServerForTest(t) conn, response, err := websocket.DefaultDialer.DialContext(context.Background(), getWebsocketUrl(server.URL), nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer conn.Close() // nolint if server := response.Header.Get("Server"); !strings.HasPrefix(server, "nextcloud-spreed-signaling-proxy/") { - t.Errorf("expected valid server header, got \"%s\"", server) + assert.Fail("expected valid server header, got \"%s\"", server) } features := response.Header.Get("X-Spreed-Signaling-Features") featuresList := make(map[string]bool) @@ -311,19 +278,15 @@ func TestWebsocketFeatures(t *testing.T) { f = strings.TrimSpace(f) if f != "" { if _, found := featuresList[f]; found { - t.Errorf("duplicate feature id \"%s\" in \"%s\"", f, features) + assert.Fail("duplicate feature id \"%s\" in \"%s\"", f, features) } featuresList[f] = true } } - if len(featuresList) == 0 { - t.Errorf("expected valid features header, got \"%s\"", features) - } + assert.NotEmpty(featuresList, "expected valid features header, got \"%s\"", features) if _, found := featuresList["remote-streams"]; !found { - t.Errorf("expected feature \"remote-streams\", got \"%s\"", features) + assert.Fail("expected feature \"remote-streams\", got \"%s\"", features) } - if err := conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}); err != nil { - t.Errorf("could not write close message: %s", err) - } + assert.NoError(conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{})) } diff --git a/proxy/proxy_tokens_etcd_test.go b/proxy/proxy_tokens_etcd_test.go index 625b7306..957f7d51 100644 --- a/proxy/proxy_tokens_etcd_test.go +++ b/proxy/proxy_tokens_etcd_test.go @@ -37,6 +37,8 @@ import ( "testing" "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/server/v3/embed" "go.etcd.io/etcd/server/v3/lease" @@ -73,9 +75,7 @@ func newEtcdForTesting(t *testing.T) *embed.Etcd { cfg.LogLevel = "warn" u, err := url.Parse(etcdListenUrl) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Find a free port to bind the server to. var etcd *embed.Etcd @@ -91,14 +91,12 @@ func newEtcdForTesting(t *testing.T) *embed.Etcd { etcd, err = embed.StartEtcd(cfg) if isErrorAddressAlreadyInUse(err) { continue - } else if err != nil { - t.Fatal(err) } + + require.NoError(t, err) break } - if etcd == nil { - t.Fatal("could not find free port") - } + require.NotNil(t, etcd, "could not find free port") t.Cleanup(func() { etcd.Close() @@ -118,9 +116,7 @@ func newTokensEtcdForTesting(t *testing.T) (*tokensEtcd, *embed.Etcd) { cfg.AddOption("tokens", "keyformat", "/%s, /testing/%s/key") tokens, err := NewProxyTokensEtcd(cfg) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { tokens.Close() }) @@ -134,11 +130,9 @@ func storeKey(t *testing.T, etcd *embed.Etcd, key string, pubkey crypto.PublicKe switch pubkey := pubkey.(type) { case rsa.PublicKey: data, err = x509.MarshalPKIXPublicKey(&pubkey) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) default: - t.Fatalf("unknown key type %T in %+v", pubkey, pubkey) + require.Fail(t, "unknown key type %T in %+v", pubkey, pubkey) } data = pem.EncodeToMemory(&pem.Block{ @@ -154,9 +148,7 @@ func storeKey(t *testing.T, etcd *embed.Etcd, key string, pubkey crypto.PublicKe func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.PrivateKey { key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) storeKey(t, etcd, name, key.PublicKey) return key @@ -164,24 +156,17 @@ func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.Privat func TestProxyTokensEtcd(t *testing.T) { signaling.CatchLogForTest(t) + assert := assert.New(t) tokens, etcd := newTokensEtcdForTesting(t) key1 := generateAndSaveKey(t, etcd, "/foo") key2 := generateAndSaveKey(t, etcd, "/testing/bar/key") - if token, err := tokens.Get("foo"); err != nil { - t.Error(err) - } else if token == nil { - t.Error("could not get token") - } else if !key1.PublicKey.Equal(token.key) { - t.Error("token keys mismatch") + if token, err := tokens.Get("foo"); assert.NoError(err) && assert.NotNil(token) { + assert.True(key1.PublicKey.Equal(token.key)) } - if token, err := tokens.Get("bar"); err != nil { - t.Error(err) - } else if token == nil { - t.Error("could not get token") - } else if !key2.PublicKey.Equal(token.key) { - t.Error("token keys mismatch") + if token, err := tokens.Get("bar"); assert.NoError(err) && assert.NotNil(token) { + assert.True(key2.PublicKey.Equal(token.key)) } } diff --git a/proxy_config_etcd_test.go b/proxy_config_etcd_test.go index 2276f8bd..353f690e 100644 --- a/proxy_config_etcd_test.go +++ b/proxy_config_etcd_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/server/v3/embed" ) @@ -43,9 +44,7 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig) cfg := goconf.NewConfigFile() cfg.AddOption("mcu", "keyprefix", "proxies/") p, err := NewProxyConfigEtcd(cfg, client, proxy) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { p.Stop() }) @@ -55,9 +54,7 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig) func SetEtcdProxy(t *testing.T, etcd *embed.Etcd, path string, proxy *TestProxyInformationEtcd) { t.Helper() data, err := json.Marshal(proxy) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) SetEtcdValue(etcd, path, data) } @@ -74,9 +71,7 @@ func TestProxyConfigEtcd(t *testing.T) { Address: "https://foo/", }) proxy.Expect("add", "https://foo/") - if err := config.Start(); err != nil { - t.Fatal(err) - } + require.NoError(t, config.Start()) proxy.WaitForEvents(ctx) proxy.Expect("add", "https://bar/") diff --git a/proxy_config_static_test.go b/proxy_config_static_test.go index ebb8ec2c..70884e8e 100644 --- a/proxy_config_static_test.go +++ b/proxy_config_static_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/stretchr/testify/require" ) func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) (ProxyConfig, *DnsMonitor) { @@ -38,9 +39,7 @@ func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string } dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually p, err := NewProxyConfigStatic(cfg, proxy, dnsMonitor) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { p.Stop() }) @@ -53,9 +52,7 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls .. if dns { cfg.AddOption("mcu", "dnsdiscovery", "true") } - if err := config.Reload(cfg); err != nil { - t.Fatal(err) - } + require.NoError(t, config.Reload(cfg)) } func TestProxyConfigStaticSimple(t *testing.T) { @@ -63,9 +60,7 @@ func TestProxyConfigStaticSimple(t *testing.T) { proxy := newMcuProxyForConfig(t) config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/") proxy.Expect("add", "https://foo/") - if err := config.Start(); err != nil { - t.Fatal(err) - } + require.NoError(t, config.Start()) proxy.Expect("keep", "https://foo/") proxy.Expect("add", "https://bar/") @@ -82,9 +77,7 @@ func TestProxyConfigStaticDNS(t *testing.T) { lookup := newMockDnsLookupForTest(t) proxy := newMcuProxyForConfig(t) config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/") - if err := config.Start(); err != nil { - t.Fatal(err) - } + require.NoError(t, config.Start()) time.Sleep(time.Millisecond) diff --git a/proxy_config_test.go b/proxy_config_test.go index 5a3db60d..47f51add 100644 --- a/proxy_config_test.go +++ b/proxy_config_test.go @@ -29,6 +29,8 @@ import ( "strings" "sync" "testing" + + "github.com/stretchr/testify/assert" ) var ( @@ -61,9 +63,7 @@ func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig { t: t, } t.Cleanup(func() { - if len(proxy.expected) > 0 { - t.Errorf("expected events %+v were not triggered", proxy.expected) - } + assert.Empty(t, proxy.expected) }) return proxy } @@ -99,7 +99,7 @@ func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) { defer p.mu.Lock() select { case <-ctx.Done(): - p.t.Error(ctx.Err()) + assert.NoError(p.t, ctx.Err()) case <-waiter: } } @@ -125,7 +125,7 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) { defer p.mu.Unlock() if len(p.expected) == 0 { - p.t.Errorf("no event expected, got %+v from %s:%d", event, caller.File, caller.Line) + assert.Fail(p.t, "no event expected, got %+v from %s:%d", event, caller.File, caller.Line) return } @@ -145,7 +145,7 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) { expected := p.expected[0] p.expected = p.expected[1:] if !reflect.DeepEqual(expected, *event) { - p.t.Errorf("expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line) + assert.Fail(p.t, "expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line) } } diff --git a/room_ping_test.go b/room_ping_test.go index 1b9ec766..0bc775d4 100644 --- a/room_ping_test.go +++ b/room_ping_test.go @@ -28,9 +28,12 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) { + require := require.New(t) r := mux.NewRouter() registerBackendHandler(t, r) @@ -40,30 +43,23 @@ func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) { }) config, err := getTestConfig(server) - if err != nil { - t.Fatal(err) - } + require.NoError(err) backend, err := NewBackendClient(config, 1, "0.0", nil) - if err != nil { - t.Fatal(err) - } + require.NoError(err) p, err := NewRoomPing(backend, backend.capabilities) - if err != nil { - t.Fatal(err) - } + require.NoError(err) u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } + require.NoError(err) return u, p } func TestSingleRoomPing(t *testing.T) { CatchLogForTest(t) + assert := assert.New(t) u, ping := NewRoomPingForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -78,13 +74,9 @@ func TestSingleRoomPing(t *testing.T) { SessionId: "123", }, } - if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 1 { - t.Errorf("expected one ping request, got %+v", requests) - } else if len(requests[0].Ping.Entries) != 1 { - t.Errorf("expected one entry, got %+v", requests[0].Ping.Entries) + assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1)) + if requests := getPingRequests(t); assert.Len(requests, 1) { + assert.Len(requests[0].Ping.Entries, 1) } clearPingRequests(t) @@ -97,24 +89,19 @@ func TestSingleRoomPing(t *testing.T) { SessionId: "456", }, } - if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 1 { - t.Errorf("expected one ping request, got %+v", requests) - } else if len(requests[0].Ping.Entries) != 1 { - t.Errorf("expected one entry, got %+v", requests[0].Ping.Entries) + assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2)) + if requests := getPingRequests(t); assert.Len(requests, 1) { + assert.Len(requests[0].Ping.Entries, 1) } clearPingRequests(t) ping.publishActiveSessions() - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.Empty(getPingRequests(t)) } func TestMultiRoomPing(t *testing.T) { CatchLogForTest(t) + assert := assert.New(t) u, ping := NewRoomPingForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -129,12 +116,8 @@ func TestMultiRoomPing(t *testing.T) { SessionId: "123", }, } - if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1)) + assert.Empty(getPingRequests(t)) room2 := &Room{ id: "sample-room-2", @@ -145,23 +128,18 @@ func TestMultiRoomPing(t *testing.T) { SessionId: "456", }, } - if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2)) + assert.Empty(getPingRequests(t)) ping.publishActiveSessions() - if requests := getPingRequests(t); len(requests) != 1 { - t.Errorf("expected one ping request, got %+v", requests) - } else if len(requests[0].Ping.Entries) != 2 { - t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries) + if requests := getPingRequests(t); assert.Len(requests, 1) { + assert.Len(requests[0].Ping.Entries, 2) } } func TestMultiRoomPing_Separate(t *testing.T) { CatchLogForTest(t) + assert := assert.New(t) u, ping := NewRoomPingForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -176,35 +154,26 @@ func TestMultiRoomPing_Separate(t *testing.T) { SessionId: "123", }, } - if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1)) + assert.Empty(getPingRequests(t)) entries2 := []BackendPingEntry{ { UserId: "bar", SessionId: "456", }, } - if err := ping.SendPings(ctx, room1.Id(), u, entries2); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries2)) + assert.Empty(getPingRequests(t)) ping.publishActiveSessions() - if requests := getPingRequests(t); len(requests) != 1 { - t.Errorf("expected one ping request, got %+v", requests) - } else if len(requests[0].Ping.Entries) != 2 { - t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries) + if requests := getPingRequests(t); assert.Len(requests, 1) { + assert.Len(requests[0].Ping.Entries, 2) } } func TestMultiRoomPing_DeleteRoom(t *testing.T) { CatchLogForTest(t) + assert := assert.New(t) u, ping := NewRoomPingForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -219,12 +188,8 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) { SessionId: "123", }, } - if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1)) + assert.Empty(getPingRequests(t)) room2 := &Room{ id: "sample-room-2", @@ -235,19 +200,13 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) { SessionId: "456", }, } - if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil { - t.Error(err) - } - if requests := getPingRequests(t); len(requests) != 0 { - t.Errorf("expected no ping requests, got %+v", requests) - } + assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2)) + assert.Empty(getPingRequests(t)) ping.DeleteRoom(room2.Id()) ping.publishActiveSessions() - if requests := getPingRequests(t); len(requests) != 1 { - t.Errorf("expected one ping request, got %+v", requests) - } else if len(requests[0].Ping.Entries) != 1 { - t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries) + if requests := getPingRequests(t); assert.Len(requests, 1) { + assert.Len(requests[0].Ping.Entries, 1) } } diff --git a/room_test.go b/room_test.go index cb0abbde..13bfa12d 100644 --- a/room_test.go +++ b/room_test.go @@ -27,11 +27,14 @@ import ( "encoding/json" "fmt" "io" + "net/http" "strconv" "testing" "time" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRoom_InCall(t *testing.T) { @@ -63,59 +66,47 @@ func TestRoom_InCall(t *testing.T) { } for _, test := range tests { inCall, ok := IsInCall(test.Value) - if ok != test.Valid { - t.Errorf("%+v should be valid %v, got %v", test.Value, test.Valid, ok) - } - if inCall != test.InCall { - t.Errorf("%+v should convert to %v, got %v", test.Value, test.InCall, inCall) + if test.Valid { + assert.True(t, ok, "%+v should be valid", test.Value) + } else { + assert.False(t, ok, "%+v should not be valid", test.Value) } + assert.EqualValues(t, test.InCall, inCall, "conversion failed for %+v", test.Value) } } func TestRoom_Update(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) - if err != nil { - t.Fatal(err) - } + require.NoError(err) b, err := NewBackendServer(config, hub, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b.Start(router); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b.Start(router)) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Simulate backend request from Nextcloud to update the room. roomProperties := json.RawMessage("{\"foo\":\"bar\"}") @@ -130,54 +121,32 @@ func TestRoom_Update(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) // The client receives a roomlist update and a changed room event. The // ordering is not defined because messages are sent by asynchronous event // handlers. message1, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } + assert.NoError(err) message2, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } + assert.NoError(err) if msg, err := checkMessageRoomlistUpdate(message1); err != nil { - if err := checkMessageRoomId(message1, roomId); err != nil { - t.Error(err) - } - if msg, err := checkMessageRoomlistUpdate(message2); err != nil { - t.Error(err) - } else if msg.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, msg) - } else if len(msg.Properties) == 0 || !bytes.Equal(msg.Properties, roomProperties) { - t.Errorf("Expected room properties %s, got %+v", string(roomProperties), msg) + assert.NoError(checkMessageRoomId(message1, roomId)) + if msg, err := checkMessageRoomlistUpdate(message2); assert.NoError(err) { + assert.Equal(roomId, msg.RoomId) + assert.Equal(string(roomProperties), string(msg.Properties)) } } else { - if msg.RoomId != roomId { - t.Errorf("Expected room id %s, got %+v", roomId, msg) - } else if len(msg.Properties) == 0 || !bytes.Equal(msg.Properties, roomProperties) { - t.Errorf("Expected room properties %s, got %+v", string(roomProperties), msg) - } - if err := checkMessageRoomId(message2, roomId); err != nil { - t.Error(err) - } + assert.Equal(roomId, msg.RoomId) + assert.Equal(string(roomProperties), string(msg.Properties)) + assert.NoError(checkMessageRoomId(message2, roomId)) } // Allow up to 100 milliseconds for asynchronous event processing. @@ -206,55 +175,41 @@ loop: time.Sleep(time.Millisecond) } - if err != nil { - t.Error(err) - } + assert.NoError(err) } func TestRoom_Delete(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) - if err != nil { - t.Fatal(err) - } + require.NoError(err) b, err := NewBackendServer(config, hub, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b.Start(router); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b.Start(router)) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event. - if err := client.RunUntilJoined(ctx, hello.Hello); err != nil { - t.Error(err) - } + assert.NoError(client.RunUntilJoined(ctx, hello.Hello)) // Simulate backend request from Nextcloud to update the room. msg := &BackendServerRoomRequest{ @@ -267,47 +222,31 @@ func TestRoom_Delete(t *testing.T) { } data, err := json.Marshal(msg) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res.Body.Close() body, err := io.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if res.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body)) // The client is no longer invited to the room and leaves it. The ordering // of messages is not defined as they get published through events and handled // by asynchronous channels. message1, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } + assert.NoError(err) if err := checkMessageType(message1, "event"); err != nil { // Ordering should be "leave room", "disinvited". - if err := checkMessageRoomId(message1, ""); err != nil { - t.Error(err) - } - message2, err := client.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - if _, err := checkMessageRoomlistDisinvite(message2); err != nil { - t.Error(err) + assert.NoError(checkMessageRoomId(message1, "")) + if message2, err := client.RunUntilMessage(ctx); assert.NoError(err) { + _, err := checkMessageRoomlistDisinvite(message2) + assert.NoError(err) } } else { // Ordering should be "disinvited", "leave room". - if _, err := checkMessageRoomlistDisinvite(message1); err != nil { - t.Error(err) - } + _, err := checkMessageRoomlistDisinvite(message1) + assert.NoError(err) message2, err := client.RunUntilMessage(ctx) if err != nil { // The connection should get closed after the "disinvited". @@ -315,10 +254,10 @@ func TestRoom_Delete(t *testing.T) { websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - t.Error(err) + assert.NoError(err) } - } else if err := checkMessageRoomId(message2, ""); err != nil { - t.Error(err) + } else { + assert.NoError(checkMessageRoomId(message2, "")) } } @@ -350,151 +289,106 @@ loop: time.Sleep(time.Millisecond) } - if err != nil { - t.Error(err) - } + assert.NoError(err) } func TestRoom_RoomSessionData(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) - if err != nil { - t.Fatal(err) - } + require.NoError(err) b, err := NewBackendServer(config, hub, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b.Start(router); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b.Start(router)) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(authAnonymousUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(authAnonymousUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room-with-sessiondata" - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // We will receive a "joined" event with the userid from the room session data. expected := "userid-from-sessiondata" - if message, err := client.RunUntilMessage(ctx); err != nil { - t.Error(err) - } else if err := client.checkMessageJoinedSession(message, hello.Hello.SessionId, expected); err != nil { - t.Error(err) - } else if message.Event.Join[0].RoomSessionId != roomId+"-"+hello.Hello.SessionId { - t.Errorf("Expected join room session id %s, got %+v", roomId+"-"+hello.Hello.SessionId, message.Event.Join[0]) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + if assert.NoError(client.checkMessageJoinedSession(message, hello.Hello.SessionId, expected)) { + assert.Equal(roomId+"-"+hello.Hello.SessionId, message.Event.Join[0].RoomSessionId) + } } session := hub.GetSessionByPublicId(hello.Hello.SessionId) - if session == nil { - t.Fatalf("Could not find session %s", hello.Hello.SessionId) - } - - if userid := session.UserId(); userid != expected { - t.Errorf("Expected userid %s, got %s", expected, userid) - } + require.NotNil(session, "Could not find session %s", hello.Hello.SessionId) + assert.Equal(expected, session.UserId()) room := hub.getRoom(roomId) - if room == nil { - t.Fatalf("Room not found") - } + assert.NotNil(room, "Room not found") entries, wg := room.publishActiveSessions() - if entries != 1 { - t.Errorf("expected 1 entries, got %d", entries) - } + assert.Equal(1, entries) wg.Wait() } func TestRoom_InCallAll(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) - if err != nil { - t.Fatal(err) - } + require.NoError(err) b, err := NewBackendServer(config, hub, "no-version") - if err != nil { - t.Fatal(err) - } - if err := b.Start(router); err != nil { - t.Fatal(err) - } + require.NoError(err) + require.NoError(b.Start(router)) client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello)) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) - if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) - if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } + assert.NoError(client1.RunUntilJoined(ctx, hello2.Hello)) // Simulate backend request from Nextcloud to update the "inCall" flag of all participants. msg1 := &BackendServerRoomRequest{ @@ -506,32 +400,20 @@ func TestRoom_InCallAll(t *testing.T) { } data1, err := json.Marshal(msg1) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res1, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data1) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res1.Body.Close() body1, err := io.ReadAll(res1.Body) - if err != nil { - t.Error(err) - } - if res1.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res1.Status, string(body1)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res1.StatusCode, "Expected successful request, got %s", string(body1)) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else if err := checkMessageInCallAll(msg, roomId, FlagInCall); err != nil { - t.Fatal(err) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageInCallAll(msg, roomId, FlagInCall)) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else if err := checkMessageInCallAll(msg, roomId, FlagInCall); err != nil { - t.Fatal(err) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageInCallAll(msg, roomId, FlagInCall)) } // Simulate backend request from Nextcloud to update the "inCall" flag of all participants. @@ -544,31 +426,19 @@ func TestRoom_InCallAll(t *testing.T) { } data2, err := json.Marshal(msg2) - if err != nil { - t.Fatal(err) - } + require.NoError(err) res2, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data2) - if err != nil { - t.Fatal(err) - } + require.NoError(err) defer res2.Body.Close() body2, err := io.ReadAll(res2.Body) - if err != nil { - t.Error(err) - } - if res2.StatusCode != 200 { - t.Errorf("Expected successful request, got %s: %s", res2.Status, string(body2)) - } + assert.NoError(err) + assert.Equal(http.StatusOK, res2.StatusCode, "Expected successful request, got %s", string(body2)) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else if err := checkMessageInCallAll(msg, roomId, 0); err != nil { - t.Fatal(err) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageInCallAll(msg, roomId, 0)) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else if err := checkMessageInCallAll(msg, roomId, 0); err != nil { - t.Fatal(err) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(checkMessageInCallAll(msg, roomId, 0)) } } diff --git a/roomsessions_builtin_test.go b/roomsessions_builtin_test.go index db1394ce..c69e3460 100644 --- a/roomsessions_builtin_test.go +++ b/roomsessions_builtin_test.go @@ -23,13 +23,13 @@ package signaling import ( "testing" + + "github.com/stretchr/testify/require" ) func TestBuiltinRoomSessions(t *testing.T) { sessions, err := NewBuiltinRoomSessions(nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) testRoomSessions(t, sessions) } diff --git a/roomsessions_test.go b/roomsessions_test.go index f7ba3e46..65018199 100644 --- a/roomsessions_test.go +++ b/roomsessions_test.go @@ -24,9 +24,11 @@ package signaling import ( "context" "encoding/json" - "errors" "net/url" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type DummySession struct { @@ -107,70 +109,53 @@ func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSes session := &DummySession{ publicId: sessionId, } - if err := sessions.SetRoomSession(session, roomSessionId); err != nil { - t.Fatalf("Expected no error, got %s", err) - } - if sid, err := sessions.GetSessionId(roomSessionId); err != nil { - t.Errorf("Expected session id %s, got error %s", sessionId, err) - } else if sid != sessionId { - t.Errorf("Expected session id %s, got %s", sessionId, sid) + require.NoError(t, sessions.SetRoomSession(session, roomSessionId)) + if sid, err := sessions.GetSessionId(roomSessionId); assert.NoError(t, err) { + assert.Equal(t, sessionId, sid) } return session } func testRoomSessions(t *testing.T, sessions RoomSessions) { - if sid, err := sessions.GetSessionId("unknown"); err != nil && err != ErrNoSuchRoomSession { - t.Errorf("Expected error about invalid room session, got %s", err) - } else if err == nil { - t.Errorf("Expected error about invalid room session, got session id %s", sid) + assert := assert.New(t) + if sid, err := sessions.GetSessionId("unknown"); err == nil { + assert.Fail("Expected error about invalid room session, got session id %s", sid) + } else { + assert.ErrorIs(err, ErrNoSuchRoomSession) } s1 := checkSession(t, sessions, "session1", "room1") s2 := checkSession(t, sessions, "session2", "room2") - if sid, err := sessions.GetSessionId("room1"); err != nil { - t.Errorf("Expected session id %s, got error %s", s1.PublicId(), err) - } else if sid != s1.PublicId() { - t.Errorf("Expected session id %s, got %s", s1.PublicId(), sid) + if sid, err := sessions.GetSessionId("room1"); assert.NoError(err) { + assert.Equal(s1.PublicId(), sid) } sessions.DeleteRoomSession(s1) - if sid, err := sessions.GetSessionId("room1"); err != nil && err != ErrNoSuchRoomSession { - t.Errorf("Expected error about invalid room session, got %s", err) - } else if err == nil { - t.Errorf("Expected error about invalid room session, got session id %s", sid) + if sid, err := sessions.GetSessionId("room1"); err == nil { + assert.Fail("Expected error about invalid room session, got session id %s", sid) + } else { + assert.ErrorIs(err, ErrNoSuchRoomSession) } - if sid, err := sessions.GetSessionId("room2"); err != nil { - t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err) - } else if sid != s2.PublicId() { - t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid) + if sid, err := sessions.GetSessionId("room2"); assert.NoError(err) { + assert.Equal(s2.PublicId(), sid) } - if err := sessions.SetRoomSession(s1, "room-session"); err != nil { - t.Error(err) - } - if err := sessions.SetRoomSession(s2, "room-session"); err != nil { - t.Error(err) - } + assert.NoError(sessions.SetRoomSession(s1, "room-session")) + assert.NoError(sessions.SetRoomSession(s2, "room-session")) sessions.DeleteRoomSession(s1) - if sid, err := sessions.GetSessionId("room-session"); err != nil { - t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err) - } else if sid != s2.PublicId() { - t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid) + if sid, err := sessions.GetSessionId("room-session"); assert.NoError(err) { + assert.Equal(s2.PublicId(), sid) } - if err := sessions.SetRoomSession(s2, "room-session2"); err != nil { - t.Error(err) - } + assert.NoError(sessions.SetRoomSession(s2, "room-session2")) if sid, err := sessions.GetSessionId("room-session"); err == nil { - t.Errorf("expected error %s, got sid %s", ErrNoSuchRoomSession, sid) - } else if !errors.Is(err, ErrNoSuchRoomSession) { - t.Errorf("expected %s, got %s", ErrNoSuchRoomSession, err) + assert.Fail("Expected error about invalid room session, got session id %s", sid) + } else { + assert.ErrorIs(err, ErrNoSuchRoomSession) } - if sid, err := sessions.GetSessionId("room-session2"); err != nil { - t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err) - } else if sid != s2.PublicId() { - t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid) + if sid, err := sessions.GetSessionId("room-session2"); assert.NoError(err) { + assert.Equal(s2.PublicId(), sid) } } diff --git a/session_test.go b/session_test.go index 3377472a..b157c80e 100644 --- a/session_test.go +++ b/session_test.go @@ -23,18 +23,16 @@ package signaling import ( "testing" + + "github.com/stretchr/testify/assert" ) func assertSessionHasPermission(t *testing.T, session Session, permission Permission) { t.Helper() - if !session.HasPermission(permission) { - t.Errorf("Session %s doesn't have permission %s", session.PublicId(), permission) - } + assert.True(t, session.HasPermission(permission), "Session %s doesn't have permission %s", session.PublicId(), permission) } func assertSessionHasNotPermission(t *testing.T, session Session, permission Permission) { t.Helper() - if session.HasPermission(permission) { - t.Errorf("Session %s has permission %s but shouldn't", session.PublicId(), permission) - } + assert.False(t, session.HasPermission(permission), "Session %s has permission %s but shouldn't", session.PublicId(), permission) } diff --git a/single_notifier_test.go b/single_notifier_test.go index 48a1c5a4..7b95bebc 100644 --- a/single_notifier_test.go +++ b/single_notifier_test.go @@ -26,6 +26,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestSingleNotifierNoWaiter(t *testing.T) { @@ -48,9 +50,7 @@ func TestSingleNotifierSimple(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := waiter.Wait(ctx); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(ctx)) }() notifier.Notify() @@ -74,9 +74,7 @@ func TestSingleNotifierWaitClosed(t *testing.T) { waiter := notifier.NewWaiter() notifier.Release(waiter) - if err := waiter.Wait(context.Background()); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(context.Background())) } func TestSingleNotifierWaitClosedMulti(t *testing.T) { @@ -87,12 +85,8 @@ func TestSingleNotifierWaitClosedMulti(t *testing.T) { notifier.Release(waiter1) notifier.Release(waiter2) - if err := waiter1.Wait(context.Background()); err != nil { - t.Error(err) - } - if err := waiter2.Wait(context.Background()); err != nil { - t.Error(err) - } + assert.NoError(t, waiter1.Wait(context.Background())) + assert.NoError(t, waiter2.Wait(context.Background())) } func TestSingleNotifierResetWillNotify(t *testing.T) { @@ -108,9 +102,7 @@ func TestSingleNotifierResetWillNotify(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := waiter.Wait(ctx); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(ctx)) }() notifier.Reset() @@ -137,9 +129,7 @@ func TestSingleNotifierDuplicate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := waiter.Wait(ctx); err != nil { - t.Error(err) - } + assert.NoError(t, waiter.Wait(ctx)) }() } diff --git a/stats_prometheus_test.go b/stats_prometheus_test.go index 08d6a7a6..626777af 100644 --- a/stats_prometheus_test.go +++ b/stats_prometheus_test.go @@ -29,6 +29,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" ) func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64) { @@ -37,10 +38,11 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64 desc := <-ch v := testutil.ToFloat64(collector) if v != value { + assert := assert.New(t) pc := make([]uintptr, 10) n := runtime.Callers(2, pc) if n == 0 { - t.Errorf("Expected value %f for %s, got %f", value, desc, v) + assert.Fail("Expected value %f for %s, got %f", value, desc, v) return } @@ -57,20 +59,20 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64 break } } - t.Errorf("Expected value %f for %s, got %f at\n%s", value, desc, v, stack) + assert.Fail("Expected value %f for %s, got %f at\n%s", value, desc, v, stack) } } func collectAndLint(t *testing.T, collectors ...prometheus.Collector) { + assert := assert.New(t) for _, collector := range collectors { problems, err := testutil.CollectAndLint(collector) - if err != nil { - t.Errorf("Error linting %+v: %s", collector, err) + if !assert.NoError(err) { continue } for _, problem := range problems { - t.Errorf("Problem with %s: %s", problem.Metric, problem.Text) + assert.Fail("Problem with %s: %s", problem.Metric, problem.Text) } } } diff --git a/testclient_test.go b/testclient_test.go index 2c02ad5a..e02422b7 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -40,6 +40,8 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -225,9 +227,7 @@ type TestClient struct { func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient { // Reference "hub" to prevent compiler error. conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) messageChan := make(chan []byte) readErrorChan := make(chan error, 1) @@ -238,8 +238,7 @@ func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Se if err != nil { readErrorChan <- err return - } else if messageType != websocket.TextMessage { - t.Errorf("Expect text message, got %d", messageType) + } else if !assert.Equal(t, websocket.TextMessage, messageType) { return } @@ -266,13 +265,8 @@ func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient client := NewTestClientContext(ctx, t, server, hub) msg, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - - if msg.Type != "welcome" { - t.Errorf("Expected welcome message, got %+v", msg) - } + require.NoError(t, err) + assert.Equal(t, "welcome", msg.Type) return client } @@ -380,9 +374,7 @@ func (c *TestClient) WriteJSON(data interface{}) error { } func (c *TestClient) EnsuerWriteJSON(data interface{}) { - if err := c.WriteJSON(data); err != nil { - c.t.Fatalf("Could not write JSON %+v: %s", data, err) - } + require.NoError(c.t, c.WriteJSON(data), "Could not write JSON %+v", data) } func (c *TestClient) SendHello(userid string) error { @@ -443,9 +435,7 @@ func (c *TestClient) CreateHelloV2Token(userid string, issuedAt time.Time, expir func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error { tokenString, err := c.CreateHelloV2Token(userid, issuedAt, expiresAt) - if err != nil { - c.t.Fatal(err) - } + require.NoError(c.t, err) params := HelloV2AuthParams{ Token: tokenString, @@ -493,9 +483,7 @@ func (c *TestClient) SendHelloInternalWithFeatures(features []string) error { func (c *TestClient) SendHelloParams(url string, version string, clientType string, features []string, params interface{}) error { data, err := json.Marshal(params) - if err != nil { - c.t.Fatal(err) - } + require.NoError(c.t, err) hello := &ClientMessage{ Id: "1234", @@ -524,9 +512,7 @@ func (c *TestClient) SendBye() error { func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data interface{}) error { payload, err := json.Marshal(data) - if err != nil { - c.t.Fatal(err) - } + require.NoError(c.t, err) message := &ClientMessage{ Id: "abcd", @@ -541,9 +527,7 @@ func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data i func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data interface{}) error { payload, err := json.Marshal(data) - if err != nil { - c.t.Fatal(err) - } + require.NoError(c.t, err) message := &ClientMessage{ Id: "abcd", @@ -608,9 +592,7 @@ func (c *TestClient) SendInternalDialout(msg *DialoutInternalClientMessage) erro func (c *TestClient) SetTransientData(key string, value interface{}, ttl time.Duration) error { payload, err := json.Marshal(value) - if err != nil { - c.t.Fatal(err) - } + require.NoError(c.t, err) message := &ClientMessage{ Id: "efgh", diff --git a/testutils_test.go b/testutils_test.go index 862cbedd..f2d507a4 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -30,6 +30,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/require" ) var listenSignalOnce sync.Once @@ -76,7 +78,7 @@ func ensureNoGoroutinesLeak(t *testing.T, f func(t *testing.T)) { if after != before { io.Copy(os.Stderr, &prev) // nolint dumpGoroutines("After:", os.Stderr) - t.Fatalf("Number of Go routines has changed from %d to %d", before, after) + require.Equal(t, before, after, "Number of Go routines has changed") } } diff --git a/throttle_test.go b/throttle_test.go index 23efae07..a95102a7 100644 --- a/throttle_test.go +++ b/throttle_test.go @@ -25,14 +25,15 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newMemoryThrottlerForTest(t *testing.T) *memoryThrottler { t.Helper() result, err := NewMemoryThrottler() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { result.Close() @@ -54,12 +55,11 @@ func (t *throttlerTiming) getNow() time.Time { func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) { t.t.Helper() - if duration != t.expectedSleep { - t.t.Errorf("expected sleep %s, got %s", t.expectedSleep, duration) - } + assert.Equal(t.t, t.expectedSleep, duration) } func TestThrottler(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -71,38 +71,31 @@ func TestThrottler(t *testing.T) { ctx := context.Background() throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle1(ctx) timing.now = timing.now.Add(time.Millisecond) throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle2(ctx) timing.now = timing.now.Add(time.Millisecond) throttle3, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle3(ctx) timing.now = timing.now.Add(time.Millisecond) throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle4(ctx) } func TestThrottlerIPv6(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -115,40 +108,33 @@ func TestThrottlerIPv6(t *testing.T) { // Make sure full /64 subnets are throttled for IPv6. throttle1, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle1(ctx) timing.now = timing.now.Add(time.Millisecond) throttle2, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::2", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle2(ctx) // A diffent /64 subnet is not throttled yet. timing.now = timing.now.Add(time.Millisecond) throttle3, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0013::1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle3(ctx) // A different action is not throttled. timing.now = timing.now.Add(time.Millisecond) throttle4, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action2") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle4(ctx) } func TestThrottler_Bruteforce(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -162,9 +148,7 @@ func TestThrottler_Bruteforce(t *testing.T) { for i := 0; i < maxBruteforceAttempts; i++ { timing.now = timing.now.Add(time.Millisecond) throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) if i == 0 { timing.expectedSleep = 100 * time.Millisecond } else { @@ -177,14 +161,12 @@ func TestThrottler_Bruteforce(t *testing.T) { } timing.now = timing.now.Add(time.Millisecond) - if _, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1"); err == nil { - t.Error("expected bruteforce error") - } else if err != ErrBruteforceDetected { - t.Errorf("expected error %s, got %s", ErrBruteforceDetected, err) - } + _, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + assert.ErrorIs(err, ErrBruteforceDetected) } func TestThrottler_Cleanup(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -196,59 +178,46 @@ func TestThrottler_Cleanup(t *testing.T) { ctx := context.Background() throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle1(ctx) throttle2, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle2(ctx) timing.now = timing.now.Add(time.Hour) throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle3(ctx) throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle4(ctx) timing.now = timing.now.Add(-time.Hour).Add(maxBruteforceAge).Add(time.Second) th.cleanup(timing.now) - if entries := th.getEntries("192.168.0.1", "action1"); len(entries) != 1 { - t.Errorf("should have removed one entry, got %+v", entries) - } - if entries := th.getEntries("192.168.0.1", "action2"); len(entries) != 1 { - t.Errorf("should have kept entry, got %+v", entries) - } + assert.Len(th.getEntries("192.168.0.1", "action1"), 1) + assert.Len(th.getEntries("192.168.0.1", "action2"), 1) th.mu.RLock() if _, found := th.clients["192.168.0.2"]; found { - t.Error("should have removed client \"192.168.0.2\"") + assert.Fail("should have removed client \"192.168.0.2\"") } th.mu.RUnlock() throttle5, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle5(ctx) } func TestThrottler_ExpirePartial(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -260,32 +229,27 @@ func TestThrottler_ExpirePartial(t *testing.T) { ctx := context.Background() throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle1(ctx) timing.now = timing.now.Add(time.Minute) throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle2(ctx) timing.now = timing.now.Add(maxBruteforceAge).Add(-time.Minute + time.Second) throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle3(ctx) } func TestThrottler_ExpireAll(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -297,32 +261,27 @@ func TestThrottler_ExpireAll(t *testing.T) { ctx := context.Background() throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle1(ctx) timing.now = timing.now.Add(time.Millisecond) throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 200 * time.Millisecond throttle2(ctx) timing.now = timing.now.Add(maxBruteforceAge).Add(time.Second) throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil { - t.Error(err) - } + assert.NoError(err) timing.expectedSleep = 100 * time.Millisecond throttle3(ctx) } func TestThrottler_Negative(t *testing.T) { + assert := assert.New(t) timing := &throttlerTiming{ t: t, now: time.Now(), @@ -336,8 +295,8 @@ func TestThrottler_Negative(t *testing.T) { for i := 0; i < maxBruteforceAttempts*10; i++ { timing.now = timing.now.Add(time.Millisecond) throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") - if err != nil && err != ErrBruteforceDetected { - t.Error(err) + if err != nil { + assert.ErrorIs(err, ErrBruteforceDetected) } if i == 0 { timing.expectedSleep = 100 * time.Millisecond diff --git a/transient_data_test.go b/transient_data_test.go index 2f856cc4..66ac3d64 100644 --- a/transient_data_test.go +++ b/transient_data_test.go @@ -25,6 +25,9 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func (t *TransientData) SetTTLChannel(ch chan<- struct{}) { @@ -35,106 +38,57 @@ func (t *TransientData) SetTTLChannel(ch chan<- struct{}) { } func Test_TransientData(t *testing.T) { + assert := assert.New(t) data := NewTransientData() - if data.Set("foo", nil) { - t.Errorf("should not have set value") - } - if !data.Set("foo", "bar") { - t.Errorf("should have set value") - } - if data.Set("foo", "bar") { - t.Errorf("should not have set value") - } - if !data.Set("foo", "baz") { - t.Errorf("should have set value") - } - if data.CompareAndSet("foo", "bar", "lala") { - t.Errorf("should not have set value") - } - if !data.CompareAndSet("foo", "baz", "lala") { - t.Errorf("should have set value") - } - if data.CompareAndSet("test", nil, nil) { - t.Errorf("should not have set value") - } - if !data.CompareAndSet("test", nil, "123") { - t.Errorf("should have set value") - } - if data.CompareAndSet("test", nil, "456") { - t.Errorf("should not have set value") - } - if data.CompareAndRemove("test", "1234") { - t.Errorf("should not have removed value") - } - if !data.CompareAndRemove("test", "123") { - t.Errorf("should have removed value") - } - if data.Remove("lala") { - t.Errorf("should not have removed value") - } - if !data.Remove("foo") { - t.Errorf("should have removed value") - } + assert.False(data.Set("foo", nil)) + assert.True(data.Set("foo", "bar")) + assert.False(data.Set("foo", "bar")) + assert.True(data.Set("foo", "baz")) + assert.False(data.CompareAndSet("foo", "bar", "lala")) + assert.True(data.CompareAndSet("foo", "baz", "lala")) + assert.False(data.CompareAndSet("test", nil, nil)) + assert.True(data.CompareAndSet("test", nil, "123")) + assert.False(data.CompareAndSet("test", nil, "456")) + assert.False(data.CompareAndRemove("test", "1234")) + assert.True(data.CompareAndRemove("test", "123")) + assert.False(data.Remove("lala")) + assert.True(data.Remove("foo")) ttlCh := make(chan struct{}) data.SetTTLChannel(ttlCh) - if !data.SetTTL("test", "1234", time.Millisecond) { - t.Errorf("should have set value") - } - if value := data.GetData()["test"]; value != "1234" { - t.Errorf("expected 1234, got %v", value) - } + assert.True(data.SetTTL("test", "1234", time.Millisecond)) + assert.Equal("1234", data.GetData()["test"]) // Data is removed after the TTL <-ttlCh - if value := data.GetData()["test"]; value != nil { - t.Errorf("expected no value, got %v", value) - } + assert.Nil(data.GetData()["test"]) - if !data.SetTTL("test", "1234", time.Millisecond) { - t.Errorf("should have set value") - } - if value := data.GetData()["test"]; value != "1234" { - t.Errorf("expected 1234, got %v", value) - } - if !data.SetTTL("test", "2345", 3*time.Millisecond) { - t.Errorf("should have set value") - } - if value := data.GetData()["test"]; value != "2345" { - t.Errorf("expected 2345, got %v", value) - } + assert.True(data.SetTTL("test", "1234", time.Millisecond)) + assert.Equal("1234", data.GetData()["test"]) + assert.True(data.SetTTL("test", "2345", 3*time.Millisecond)) + assert.Equal("2345", data.GetData()["test"]) // Data is removed after the TTL only if the value still matches time.Sleep(2 * time.Millisecond) - if value := data.GetData()["test"]; value != "2345" { - t.Errorf("expected 2345, got %v", value) - } + assert.Equal("2345", data.GetData()["test"]) // Data is removed after the (second) TTL <-ttlCh - if value := data.GetData()["test"]; value != nil { - t.Errorf("expected no value, got %v", value) - } + assert.Nil(data.GetData()["test"]) // Setting existing key will update the TTL - if !data.SetTTL("test", "1234", time.Millisecond) { - t.Errorf("should have set value") - } - if data.SetTTL("test", "1234", 3*time.Millisecond) { - t.Errorf("should not have set value") - } + assert.True(data.SetTTL("test", "1234", time.Millisecond)) + assert.False(data.SetTTL("test", "1234", 3*time.Millisecond)) // Data still exists after the first TTL time.Sleep(2 * time.Millisecond) - if value := data.GetData()["test"]; value != "1234" { - t.Errorf("expected 1234, got %v", value) - } + assert.Equal("1234", data.GetData()["test"]) // Data is removed after the (updated) TTL <-ttlCh - if value := data.GetData()["test"]; value != nil { - t.Errorf("expected no value, got %v", value) - } + assert.Nil(data.GetData()["test"]) } func Test_TransientMessages(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -142,226 +96,139 @@ func Test_TransientMessages(t *testing.T) { client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() - if err := client1.SendHello(testDefaultUserId + "1"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SendHello(testDefaultUserId + "1")) hello1, err := client1.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if err := client1.SetTransientData("foo", "bar", 0); err != nil { - t.Fatal(err) - } - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_in_room"); err != nil { - t.Fatal(err) - } + require.NoError(client1.SetTransientData("foo", "bar", 0)) + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageError(msg, "not_in_room")) } client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) hello2, err := client2.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Join room by id. roomId := "test-room" - if room, err := client1.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Give message processing some time. time.Sleep(10 * time.Millisecond) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) - if session1 == nil { - t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) - } + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) - if session2 == nil { - t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) - } + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) // Client 1 may modify transient data. session1.SetPermissions([]Permission{PERMISSION_TRANSIENT_DATA}) // Client 2 may not modify transient data. session2.SetPermissions([]Permission{}) - if err := client2.SetTransientData("foo", "bar", 0); err != nil { - t.Fatal(err) - } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SetTransientData("foo", "bar", 0)) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageError(msg, "not_allowed")) } - if err := client1.SetTransientData("foo", "bar", 0); err != nil { - t.Fatal(err) - } + require.NoError(client1.SetTransientData("foo", "bar", 0)) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil { - t.Fatal(err) - } + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientSet(msg, "foo", "bar", nil)) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil { - t.Fatal(err) - } + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientSet(msg, "foo", "bar", nil)) } - if err := client2.RemoveTransientData("foo"); err != nil { - t.Fatal(err) - } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "not_allowed"); err != nil { - t.Fatal(err) - } + require.NoError(client2.RemoveTransientData("foo")) + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageError(msg, "not_allowed")) } // Setting the same value is ignored by the server. - if err := client1.SetTransientData("foo", "bar", 0); err != nil { - t.Fatal(err) - } + require.NoError(client1.SetTransientData("foo", "bar", 0)) ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if msg, err := client1.RunUntilMessage(ctx2); err != nil { - if err != context.DeadlineExceeded { - t.Fatal(err) - } + if msg, err := client1.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no payload, got %+v", msg) } else { - t.Errorf("Expected no payload, got %+v", msg) + require.ErrorIs(err, context.DeadlineExceeded) } data := map[string]interface{}{ "hello": "world", } - if err := client1.SetTransientData("foo", data, 0); err != nil { - t.Fatal(err) - } + require.NoError(client1.SetTransientData("foo", data, 0)) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil { - t.Fatal(err) - } + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientSet(msg, "foo", data, "bar")) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil { - t.Fatal(err) - } + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientSet(msg, "foo", data, "bar")) } - if err := client1.RemoveTransientData("foo"); err != nil { - t.Fatal(err) - } + require.NoError(client1.RemoveTransientData("foo")) - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientRemove(msg, "foo", data); err != nil { - t.Fatal(err) - } + if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientRemove(msg, "foo", data)) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageTransientRemove(msg, "foo", data); err != nil { - t.Fatal(err) - } + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientRemove(msg, "foo", data)) } // Removing a non-existing key is ignored by the server. - if err := client1.RemoveTransientData("foo"); err != nil { - t.Fatal(err) - } + require.NoError(client1.RemoveTransientData("foo")) ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel3() - if msg, err := client1.RunUntilMessage(ctx3); err != nil { - if err != context.DeadlineExceeded { - t.Fatal(err) - } + if msg, err := client1.RunUntilMessage(ctx3); err == nil { + assert.Fail("Expected no payload, got %+v", msg) } else { - t.Errorf("Expected no payload, got %+v", msg) + require.ErrorIs(err, context.DeadlineExceeded) } - if err := client1.SetTransientData("abc", data, 10*time.Millisecond); err != nil { - t.Fatal(err) - } + require.NoError(client1.SetTransientData("abc", data, 10*time.Millisecond)) client3 := NewTestClient(t, server, hub) defer client3.CloseWithBye() - if err := client3.SendHello(testDefaultUserId + "3"); err != nil { - t.Fatal(err) - } + require.NoError(client3.SendHello(testDefaultUserId + "3")) hello3, err := client3.RunUntilHello(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - if room, err := client3.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client3.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) _, ignored, err := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello) - if err != nil { - t.Fatal(err) - } + require.NoError(err) var msg *ServerMessage if len(ignored) == 0 { - if msg, err = client3.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } + msg, err = client3.RunUntilMessage(ctx) + require.NoError(err) } else if len(ignored) == 1 { msg = ignored[0] } else { - t.Fatalf("Received too many messages: %+v", ignored) + require.Fail("Received too many messages: %+v", ignored) } - if err := checkMessageTransientInitial(msg, map[string]interface{}{ + require.NoError(checkMessageTransientInitial(msg, map[string]interface{}{ "abc": data, - }); err != nil { - t.Fatal(err) - } + })) time.Sleep(10 * time.Millisecond) - if msg, err = client3.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else if err := checkMessageTransientRemove(msg, "abc", data); err != nil { - t.Fatal(err) + if msg, err = client3.RunUntilMessage(ctx); assert.NoError(err) { + require.NoError(checkMessageTransientRemove(msg, "abc", data)) } } diff --git a/virtualsession_test.go b/virtualsession_test.go index e351564f..bc9ca103 100644 --- a/virtualsession_test.go +++ b/virtualsession_test.go @@ -27,11 +27,16 @@ import ( "errors" "fmt" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestVirtualSession(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) roomId := "the-room-id" @@ -41,54 +46,34 @@ func TestVirtualSession(t *testing.T) { compat: true, } room, err := hub.createRoom(roomId, emptyProperties, backend) - if err != nil { - t.Fatalf("Could not create room: %s", err) - } + require.NoError(err) defer room.Close() clientInternal := NewTestClient(t, server, hub) defer clientInternal.CloseWithBye() - if err := clientInternal.SendHelloInternal(); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.SendHelloInternal()) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := clientInternal.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != "" { - t.Errorf("Expected empty user id, got %+v", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } + if hello, err := clientInternal.RunUntilHello(ctx); assert.NoError(err) { + assert.Empty(hello.Hello.UserId) + assert.NotEmpty(hello.Hello.SessionId) + assert.NotEmpty(hello.Hello.ResumeId) } hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } + assert.NoError(err) - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Ignore "join" events. - if err := client.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.DrainMessages(ctx)) internalSessionId := "session1" userId := "user1" @@ -106,64 +91,39 @@ func TestVirtualSession(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgAdd); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgAdd)) msg1, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // The public session id will be generated by the server, so don't check for it. - if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil { - t.Fatal(err) - } + require.NoError(client.checkMessageJoinedSession(msg1, "", userId)) sessionId := msg1.Event.Join[0].SessionId session := hub.GetSessionByPublicId(sessionId) - if session == nil { - t.Fatalf("Could not get virtual session %s", sessionId) - } - if session.ClientType() != HelloClientTypeVirtual { - t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType()) - } - if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId { - t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid) + if assert.NotNil(session, "Could not get virtual session %s", sessionId) { + assert.Equal(HelloClientTypeVirtual, session.ClientType()) + sid := session.(*VirtualSession).SessionId() + assert.Equal(internalSessionId, sid) } // Also a participants update event will be triggered for the virtual user. msg2, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - updateMsg, err := checkMessageParticipantsInCall(msg2) - if err != nil { - t.Error(err) - } else if updateMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId) - } else if len(updateMsg.Users) != 1 { - t.Errorf("Expected one user, got %+v", updateMsg.Users) - } else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId { - t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0]) - } else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual { - t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0]) - } else if inCall, ok := updateMsg.Users[0]["inCall"].(float64); !ok || inCall != (FlagInCall|FlagWithPhone) { - t.Errorf("Expected user in call with phone, got %+v", updateMsg.Users[0]) + require.NoError(err) + if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) { + assert.Equal(roomId, updateMsg.RoomId) + if assert.Len(updateMsg.Users, 1) { + assert.Equal(sessionId, updateMsg.Users[0]["sessionId"]) + assert.Equal(true, updateMsg.Users[0]["virtual"]) + assert.EqualValues((FlagInCall | FlagWithPhone), updateMsg.Users[0]["inCall"]) + } } msg3, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - flagsMsg, err := checkMessageParticipantFlags(msg3) - if err != nil { - t.Error(err) - } else if flagsMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId) - } else if flagsMsg.SessionId != sessionId { - t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId) - } else if flagsMsg.Flags != FLAG_MUTED_SPEAKING { - t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags) + if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) { + assert.Equal(roomId, flagsMsg.RoomId) + assert.Equal(sessionId, flagsMsg.SessionId) + assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags) } newFlags := uint32(FLAG_TALKING) @@ -180,49 +140,35 @@ func TestVirtualSession(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgFlags); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgFlags)) msg4, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - flagsMsg, err = checkMessageParticipantFlags(msg4) - if err != nil { - t.Error(err) - } else if flagsMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId) - } else if flagsMsg.SessionId != sessionId { - t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId) - } else if flagsMsg.Flags != newFlags { - t.Errorf("Expected flags %d, got %+v", newFlags, flagsMsg.Flags) + if flagsMsg, err := checkMessageParticipantFlags(msg4); assert.NoError(err) { + assert.Equal(roomId, flagsMsg.RoomId) + assert.Equal(sessionId, flagsMsg.SessionId) + assert.EqualValues(newFlags, flagsMsg.Flags) } // A new client will receive the initial flags of the virtual session. client2 := NewTestClient(t, server, hub) defer client2.CloseWithBye() - if err := client2.SendHello(testDefaultUserId + "2"); err != nil { - t.Fatal(err) - } + require.NoError(client2.SendHello(testDefaultUserId + "2")) - if _, err := client2.RunUntilHello(ctx); err != nil { - t.Error(err) - } + _, err = client2.RunUntilHello(ctx) + require.NoError(err) - if room, err := client2.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) gotFlags := false var receivedMessages []*ServerMessage for !gotFlags { messages, err := client2.GetPendingMessages(ctx) if err != nil { - t.Error(err) + assert.NoError(err) if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { break } @@ -234,26 +180,18 @@ func TestVirtualSession(t *testing.T) { continue } - if msg.Event.Flags.RoomId != roomId { - t.Errorf("Expected flags in room %s, got %s", roomId, msg.Event.Flags.RoomId) - } else if msg.Event.Flags.SessionId != sessionId { - t.Errorf("Expected flags for session %s, got %s", sessionId, msg.Event.Flags.SessionId) - } else if msg.Event.Flags.Flags != newFlags { - t.Errorf("Expected flags %d, got %d", newFlags, msg.Event.Flags.Flags) - } else { + if assert.Equal(roomId, msg.Event.Flags.RoomId) && + assert.Equal(sessionId, msg.Event.Flags.SessionId) && + assert.EqualValues(newFlags, msg.Event.Flags.Flags) { gotFlags = true break } } } - if !gotFlags { - t.Errorf("Didn't receive initial flags in %+v", receivedMessages) - } + assert.True(gotFlags, "Didn't receive initial flags in %+v", receivedMessages) // Ignore "join" messages from second client - if err := client.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.DrainMessages(ctx)) // When sending to a virtual session, the message is sent to the actual // client and contains a "Recipient" block with the internal session id. @@ -263,32 +201,21 @@ func TestVirtualSession(t *testing.T) { } data := "from-client-to-virtual" - if err := client.SendMessage(recipient, data); err != nil { - t.Fatal(err) - } + require.NoError(client.SendMessage(recipient, data)) msg2, err = clientInternal.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } else if err := checkMessageType(msg2, "message"); err != nil { - t.Fatal(err) - } else if err := checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello); err != nil { - t.Error(err) - } + require.NoError(err) + require.NoError(checkMessageType(msg2, "message")) + require.NoError(checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello)) - if msg2.Message.Recipient == nil { - t.Errorf("Expected recipient, got none") - } else if msg2.Message.Recipient.Type != "session" { - t.Errorf("Expected recipient type session, got %s", msg2.Message.Recipient.Type) - } else if msg2.Message.Recipient.SessionId != internalSessionId { - t.Errorf("Expected recipient %s, got %s", internalSessionId, msg2.Message.Recipient.SessionId) + if assert.NotNil(msg2.Message.Recipient) { + assert.Equal("session", msg2.Message.Recipient.Type) + assert.Equal(internalSessionId, msg2.Message.Recipient.SessionId) } var payload string - if err := json.Unmarshal(msg2.Message.Data, &payload); err != nil { - t.Error(err) - } else if payload != data { - t.Errorf("Expected payload %s, got %s", data, payload) + if err := json.Unmarshal(msg2.Message.Data, &payload); assert.NoError(err) { + assert.Equal(data, payload) } msgRemove := &ClientMessage{ @@ -303,16 +230,10 @@ func TestVirtualSession(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgRemove); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgRemove)) - msg5, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - if err := client.checkMessageRoomLeaveSession(msg5, sessionId); err != nil { - t.Error(err) + if msg5, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.NoError(client.checkMessageRoomLeaveSession(msg5, sessionId)) } } @@ -342,6 +263,8 @@ func checkHasEntryWithInCall(message *RoomEventServerMessage, sessionId string, func TestVirtualSessionCustomInCall(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) roomId := "the-room-id" @@ -351,9 +274,7 @@ func TestVirtualSessionCustomInCall(t *testing.T) { compat: true, } room, err := hub.createRoom(roomId, emptyProperties, backend) - if err != nil { - t.Fatalf("Could not create room: %s", err) - } + require.NoError(err) defer room.Close() clientInternal := NewTestClient(t, server, hub) @@ -361,67 +282,40 @@ func TestVirtualSessionCustomInCall(t *testing.T) { features := []string{ ClientFeatureInternalInCall, } - if err := clientInternal.SendHelloInternalWithFeatures(features); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.SendHelloInternalWithFeatures(features)) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() helloInternal, err := clientInternal.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } else { - if helloInternal.Hello.UserId != "" { - t.Errorf("Expected empty user id, got %+v", helloInternal.Hello) - } - if helloInternal.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", helloInternal.Hello) - } - if helloInternal.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", helloInternal.Hello) - } - } - if room, err := clientInternal.JoinRoomWithRoomSession(ctx, roomId, ""); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + if assert.NoError(err) { + assert.Empty(helloInternal.Hello.UserId) + assert.NotEmpty(helloInternal.Hello.SessionId) + assert.NotEmpty(helloInternal.Hello.ResumeId) } + roomMsg, err := clientInternal.JoinRoomWithRoomSession(ctx, roomId, "") + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) hello, err := client.RunUntilHello(ctx) - if err != nil { - t.Error(err) - } - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } - - if _, additional, err := clientInternal.RunUntilJoinedAndReturn(ctx, helloInternal.Hello, hello.Hello); err != nil { - t.Error(err) - } else if len(additional) != 1 { - t.Errorf("expected one additional message, got %+v", additional) - } else if additional[0].Type != "event" { - t.Errorf("expected event message, got %+v", additional[0]) - } else if additional[0].Event.Target != "participants" { - t.Errorf("expected event participants message, got %+v", additional[0]) - } else if additional[0].Event.Type != "update" { - t.Errorf("expected event participants update message, got %+v", additional[0]) - } else if additional[0].Event.Update.Users[0]["sessionId"].(string) != helloInternal.Hello.SessionId { - t.Errorf("expected event update message for internal session, got %+v", additional[0]) - } else if additional[0].Event.Update.Users[0]["inCall"].(float64) != 0 { - t.Errorf("expected event update message with session not in call, got %+v", additional[0]) - } - if err := client.RunUntilJoined(ctx, helloInternal.Hello, hello.Hello); err != nil { - t.Error(err) + assert.NoError(err) + roomMsg, err = client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + + if _, additional, err := clientInternal.RunUntilJoinedAndReturn(ctx, helloInternal.Hello, hello.Hello); assert.NoError(err) { + if assert.Len(additional, 1) && assert.Equal("event", additional[0].Type) { + assert.Equal("participants", additional[0].Event.Target) + assert.Equal("update", additional[0].Event.Type) + assert.Equal(helloInternal.Hello.SessionId, additional[0].Event.Update.Users[0]["sessionId"]) + assert.EqualValues(0, additional[0].Event.Update.Users[0]["inCall"]) + } } + assert.NoError(client.RunUntilJoined(ctx, helloInternal.Hello, hello.Hello)) internalSessionId := "session1" userId := "user1" @@ -439,65 +333,38 @@ func TestVirtualSessionCustomInCall(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgAdd); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgAdd)) msg1, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // The public session id will be generated by the server, so don't check for it. - if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil { - t.Fatal(err) - } + require.NoError(client.checkMessageJoinedSession(msg1, "", userId)) sessionId := msg1.Event.Join[0].SessionId session := hub.GetSessionByPublicId(sessionId) - if session == nil { - t.Fatalf("Could not get virtual session %s", sessionId) - } - if session.ClientType() != HelloClientTypeVirtual { - t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType()) - } - if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId { - t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid) + if assert.NotNil(session) { + assert.Equal(HelloClientTypeVirtual, session.ClientType()) + sid := session.(*VirtualSession).SessionId() + assert.Equal(internalSessionId, sid) } // Also a participants update event will be triggered for the virtual user. msg2, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - updateMsg, err := checkMessageParticipantsInCall(msg2) - if err != nil { - t.Error(err) - } else if updateMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId) - } else if len(updateMsg.Users) != 2 { - t.Errorf("Expected two users, got %+v", updateMsg.Users) - } + require.NoError(err) + if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) { + assert.Equal(roomId, updateMsg.RoomId) + assert.Len(updateMsg.Users, 2) - if err := checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0); err != nil { - t.Error(err) - } - if err := checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", 0); err != nil { - t.Error(err) + assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0)) + assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", 0)) } msg3, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - flagsMsg, err := checkMessageParticipantFlags(msg3) - if err != nil { - t.Error(err) - } else if flagsMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId) - } else if flagsMsg.SessionId != sessionId { - t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId) - } else if flagsMsg.Flags != FLAG_MUTED_SPEAKING { - t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags) + if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) { + assert.Equal(roomId, flagsMsg.RoomId) + assert.Equal(sessionId, flagsMsg.SessionId) + assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags) } // The internal session can change its "inCall" flags @@ -510,27 +377,15 @@ func TestVirtualSessionCustomInCall(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgInCall); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgInCall)) msg4, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - updateMsg2, err := checkMessageParticipantsInCall(msg4) - if err != nil { - t.Error(err) - } else if updateMsg2.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, updateMsg2.RoomId) - } else if len(updateMsg2.Users) != 2 { - t.Errorf("Expected two users, got %+v", updateMsg2.Users) - } - if err := checkHasEntryWithInCall(updateMsg2, sessionId, "virtual", 0); err != nil { - t.Error(err) - } - if err := checkHasEntryWithInCall(updateMsg2, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio); err != nil { - t.Error(err) + require.NoError(err) + if updateMsg, err := checkMessageParticipantsInCall(msg4); assert.NoError(err) { + assert.Equal(roomId, updateMsg.RoomId) + assert.Len(updateMsg.Users, 2) + assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0)) + assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio)) } // The internal session can change the "inCall" flags of a virtual session @@ -548,33 +403,23 @@ func TestVirtualSessionCustomInCall(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgInCall2); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgInCall2)) msg5, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - updateMsg3, err := checkMessageParticipantsInCall(msg5) - if err != nil { - t.Error(err) - } else if updateMsg3.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, updateMsg3.RoomId) - } else if len(updateMsg3.Users) != 2 { - t.Errorf("Expected two users, got %+v", updateMsg3.Users) - } - if err := checkHasEntryWithInCall(updateMsg3, sessionId, "virtual", newInCall); err != nil { - t.Error(err) - } - if err := checkHasEntryWithInCall(updateMsg3, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio); err != nil { - t.Error(err) + require.NoError(err) + if updateMsg, err := checkMessageParticipantsInCall(msg5); assert.NoError(err) { + assert.Equal(roomId, updateMsg.RoomId) + assert.Len(updateMsg.Users, 2) + assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", newInCall)) + assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio)) } } func TestVirtualSessionCleanup(t *testing.T) { t.Parallel() CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) roomId := "the-room-id" @@ -584,53 +429,34 @@ func TestVirtualSessionCleanup(t *testing.T) { compat: true, } room, err := hub.createRoom(roomId, emptyProperties, backend) - if err != nil { - t.Fatalf("Could not create room: %s", err) - } + require.NoError(err) defer room.Close() clientInternal := NewTestClient(t, server, hub) defer clientInternal.CloseWithBye() - if err := clientInternal.SendHelloInternal(); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.SendHelloInternal()) client := NewTestClient(t, server, hub) defer client.CloseWithBye() - if err := client.SendHello(testDefaultUserId); err != nil { - t.Fatal(err) - } + require.NoError(client.SendHello(testDefaultUserId)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if hello, err := clientInternal.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != "" { - t.Errorf("Expected empty user id, got %+v", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - if hello.Hello.ResumeId == "" { - t.Errorf("Expected resume id, got %+v", hello.Hello) - } - } - if _, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) + if hello, err := clientInternal.RunUntilHello(ctx); assert.NoError(err) { + assert.Empty(hello.Hello.UserId) + assert.NotEmpty(hello.Hello.SessionId) + assert.NotEmpty(hello.Hello.ResumeId) } + _, err = client.RunUntilHello(ctx) + assert.NoError(err) - if room, err := client.JoinRoom(ctx, roomId); err != nil { - t.Fatal(err) - } else if room.Room.RoomId != roomId { - t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) - } + roomMsg, err := client.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) // Ignore "join" events. - if err := client.DrainMessages(ctx); err != nil { - t.Error(err) - } + assert.NoError(client.DrainMessages(ctx)) internalSessionId := "session1" userId := "user1" @@ -648,72 +474,45 @@ func TestVirtualSessionCleanup(t *testing.T) { }, }, } - if err := clientInternal.WriteJSON(msgAdd); err != nil { - t.Fatal(err) - } + require.NoError(clientInternal.WriteJSON(msgAdd)) msg1, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // The public session id will be generated by the server, so don't check for it. - if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil { - t.Fatal(err) - } + require.NoError(client.checkMessageJoinedSession(msg1, "", userId)) sessionId := msg1.Event.Join[0].SessionId session := hub.GetSessionByPublicId(sessionId) - if session == nil { - t.Fatalf("Could not get virtual session %s", sessionId) - } - if session.ClientType() != HelloClientTypeVirtual { - t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType()) - } - if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId { - t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid) + if assert.NotNil(session) { + assert.Equal(HelloClientTypeVirtual, session.ClientType()) + sid := session.(*VirtualSession).SessionId() + assert.Equal(internalSessionId, sid) } // Also a participants update event will be triggered for the virtual user. msg2, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } - updateMsg, err := checkMessageParticipantsInCall(msg2) - if err != nil { - t.Error(err) - } else if updateMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId) - } else if len(updateMsg.Users) != 1 { - t.Errorf("Expected one user, got %+v", updateMsg.Users) - } else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId { - t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0]) - } else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual { - t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0]) - } else if inCall, ok := updateMsg.Users[0]["inCall"].(float64); !ok || inCall != (FlagInCall|FlagWithPhone) { - t.Errorf("Expected user in call with phone, got %+v", updateMsg.Users[0]) + require.NoError(err) + if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) { + assert.Equal(roomId, updateMsg.RoomId) + if assert.Len(updateMsg.Users, 1) { + assert.Equal(sessionId, updateMsg.Users[0]["sessionId"]) + assert.Equal(true, updateMsg.Users[0]["virtual"]) + assert.EqualValues((FlagInCall | FlagWithPhone), updateMsg.Users[0]["inCall"]) + } } msg3, err := client.RunUntilMessage(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(err) - flagsMsg, err := checkMessageParticipantFlags(msg3) - if err != nil { - t.Error(err) - } else if flagsMsg.RoomId != roomId { - t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId) - } else if flagsMsg.SessionId != sessionId { - t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId) - } else if flagsMsg.Flags != FLAG_MUTED_SPEAKING { - t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags) + if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) { + assert.Equal(roomId, flagsMsg.RoomId) + assert.Equal(sessionId, flagsMsg.SessionId) + assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags) } // The virtual sessions are closed when the parent session is deleted. clientInternal.CloseWithBye() - if msg2, err := client.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else if err := client.checkMessageRoomLeaveSession(msg2, sessionId); err != nil { - t.Error(err) - } + msg2, err = client.RunUntilMessage(ctx) + require.NoError(err) + assert.NoError(client.checkMessageRoomLeaveSession(msg2, sessionId)) }