Skip to content

Commit

Permalink
Don't update capabilities concurrently from same host.
Browse files Browse the repository at this point in the history
If capabilities are expired and requested from multiple clients concurrently,
this could cause concurrent (duplicate) requests to the same Nextcloud host.
With this change, only a single request is sent to Nextcloud in such cases.
  • Loading branch information
fancycode committed Oct 9, 2024
1 parent d692a3b commit 8f1ec78
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 52 deletions.
111 changes: 59 additions & 52 deletions capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,32 @@ var (
)

type capabilitiesEntry struct {
c *Capabilities
mu sync.RWMutex
nextUpdate time.Time
etag string
mustRevalidate bool
capabilities map[string]interface{}
}

func newCapabilitiesEntry() *capabilitiesEntry {
return &capabilitiesEntry{}
func newCapabilitiesEntry(c *Capabilities) *capabilitiesEntry {
return &capabilitiesEntry{
c: c,
}
}

func (e *capabilitiesEntry) valid(now time.Time) bool {
e.mu.RLock()
defer e.mu.RUnlock()

return e.validLocked(now)
}

func (e *capabilitiesEntry) validLocked(now time.Time) bool {
return e.nextUpdate.After(now)
}

func (e *capabilitiesEntry) updateRequest(r *http.Request) {
e.mu.RLock()
defer e.mu.RUnlock()

if e.etag != "" {
r.Header.Set("If-None-Match", e.etag)
}
Expand All @@ -103,10 +107,50 @@ func (e *capabilitiesEntry) errorIfMustRevalidate(err error) error {
return err
}

func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error {
func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Time) error {
e.mu.Lock()
defer e.mu.Unlock()

if e.validLocked(now) {
// Capabilities were updated while waiting for the lock.
return nil
}

capUrl := *u
if !strings.Contains(capUrl.Path, "ocs/v2.php") {
if !strings.HasSuffix(capUrl.Path, "/") {
capUrl.Path += "/"
}
capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities"
} else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 {
capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities"
}

log.Printf("Capabilities expired for %s, updating", capUrl.String())

client, pool, err := e.c.pool.Get(ctx, &capUrl)
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return err
}
defer pool.Put(client)

req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+e.c.version)
e.updateRequest(req)

response, err := client.Do(req)
if err != nil {
return err
}
defer response.Body.Close()

url := response.Request.URL
e.etag = response.Header.Get("ETag")

Expand Down Expand Up @@ -231,11 +275,15 @@ func (c *Capabilities) getCapabilities(key string) (*capabilitiesEntry, bool) {

now := c.getNow()
entry, found := c.entries[key]
if found && entry.valid(now) {
return entry, true
if !found {
// Upgrade to write-lock
c.mu.RUnlock()
defer c.mu.RLock()

entry = c.newCapabilitiesEntry(key)
}

return entry, false
return entry, entry.valid(now)
}

func (c *Capabilities) invalidateCapabilities(key string) {
Expand All @@ -260,7 +308,7 @@ func (c *Capabilities) newCapabilitiesEntry(key string) *capabilitiesEntry {

entry, found := c.entries[key]
if !found {
entry = newCapabilitiesEntry()
entry = newCapabilitiesEntry(c)
c.entries[key] = entry
}

Expand All @@ -279,48 +327,7 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
return entry.GetCapabilities(), true, nil
}

capUrl := *u
if !strings.Contains(capUrl.Path, "ocs/v2.php") {
if !strings.HasSuffix(capUrl.Path, "/") {
capUrl.Path += "/"
}
capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities"
} else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 {
capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities"
}

log.Printf("Capabilities expired for %s, updating", capUrl.String())

client, pool, err := c.pool.Get(ctx, &capUrl)
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return nil, false, err
}
defer pool.Put(client)

req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return nil, false, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+c.version)
if entry != nil {
entry.updateRequest(req)
}

resp, err := client.Do(req)
if err != nil {
return nil, false, err
}
defer resp.Body.Close()

if entry == nil {
entry = c.newCapabilitiesEntry(key)
}

if err := entry.update(resp, c.getNow()); err != nil {
if err := entry.update(ctx, u, c.getNow()); err != nil {
return nil, false, err
}

Expand Down
53 changes: 53 additions & 0 deletions capabilities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -528,3 +529,55 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
value = called.Load()
assert.EqualValues(2, value)
}

func TestConcurrentExpired(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
called.Add(1)
return nil
})

ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}

count := 100
start := make(chan struct{}, 0)

Check failure on line 553 in capabilities_test.go

View workflow job for this annotation

GitHub Actions / golang

S1019: should use make(chan struct{}) instead (gosimple)
var numCached atomic.Uint32
var numFetched atomic.Uint32
var finished sync.WaitGroup
for i := 0; i < count; i++ {
finished.Add(1)
go func() {
defer finished.Done()
<-start
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
if cached {
numCached.Add(1)
} else {
numFetched.Add(1)
}
}
}()
}

SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
})

close(start)
finished.Wait()

assert.EqualValues(2, called.Load())
assert.EqualValues(1, numFetched.Load())
assert.EqualValues(count-1, numCached.Load())
}

0 comments on commit 8f1ec78

Please sign in to comment.