diff --git a/cmd/outline-ss-server/source.go b/cmd/outline-ss-server/source.go index 74413182..2421159d 100644 --- a/cmd/outline-ss-server/source.go +++ b/cmd/outline-ss-server/source.go @@ -16,7 +16,7 @@ type Source struct { type KeyUpdater struct { Ciphers service.CipherList - ciphersByID map[string]*service.CipherEntry + ciphersByID map[string]func() } func (c *KeyUpdater) AddKey(key key.Key) error { @@ -26,9 +26,9 @@ func (c *KeyUpdater) AddKey(key key.Key) error { } entry := service.MakeCipherEntry(key.ID, cryptoKey, key.Secret) slog.Info("Added key ", "keyID", key.ID) - c.Ciphers.AddEntry(&entry) - // Store the entry in a map for fast removal - c.ciphersByID[key.ID] = &entry + rmFunc := c.Ciphers.AddEntry(&entry) + // Store the remove callback in a map for fast removal + c.ciphersByID[key.ID] = rmFunc return nil } @@ -36,9 +36,9 @@ func (c *KeyUpdater) RemoveKey(key key.Key) error { if c.Ciphers == nil { return fmt.Errorf("no Cipher available while removing key %v", key.ID) } - entry, exists := c.ciphersByID[key.ID] + rmFunc, exists := c.ciphersByID[key.ID] if exists { - c.Ciphers.RemoveEntry(entry) + rmFunc() return nil } else { return fmt.Errorf("key %v was not found", key.ID) @@ -63,7 +63,7 @@ func newFileUpdater(c service.CipherList, s key.Source, logger *slog.Logger) *Fi fileSource: s, } fu.KeyUpdater.Ciphers = c - fu.KeyUpdater.ciphersByID = make(map[string]*service.CipherEntry) + fu.KeyUpdater.ciphersByID = make(map[string]func()) return fu } diff --git a/service/cipher_list.go b/service/cipher_list.go index 3faa9d0e..d42bc6fd 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -64,7 +64,7 @@ type CipherList interface { // which is a List of *CipherEntry. Update takes ownership of `contents`, // which must not be read or written after this call. Update(contents *list.List) - AddEntry(e *CipherEntry) + AddEntry(e *CipherEntry) func() RemoveEntry(entry *CipherEntry) } @@ -121,19 +121,11 @@ func (cl *cipherList) Update(src *list.List) { cl.mu.Unlock() } -func (cl *cipherList) AddEntry(e *CipherEntry) { +func (cl *cipherList) AddEntry(e *CipherEntry) func() { cl.mu.Lock() - cl.list.PushFront(e) - cl.mu.Unlock() -} - -func (cl *cipherList) RemoveEntry(entry *CipherEntry) { - cl.mu.Lock() - for e := cl.list.Front(); e != nil; e = e.Next() { - if e.Value.(*CipherEntry) == entry { - cl.list.Remove(e) - break - } + defer cl.mu.Unlock() + el := cl.list.PushFront(e) + return func() { + cl.list.Remove(el) } - cl.mu.Unlock() } diff --git a/service/cipher_list_test.go b/service/cipher_list_test.go index 6f6437c2..f6c99505 100644 --- a/service/cipher_list_test.go +++ b/service/cipher_list_test.go @@ -82,6 +82,29 @@ func TestAddRemoveEntries(t *testing.T) { require.Equal(t, found, false, "The entry was removed but it's still in the list.") } +func TestAddEntry(t *testing.T) { + ciphers := NewCipherList() + key, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, "testPassword") + if err != nil { + t.Fatalf("Failed to create key: %v", err) + } + entry := MakeCipherEntry("cipher1", key, "password") + removeFunc := ciphers.AddEntry(&entry) + + // Verify the entry was added + entries := ciphers.SnapshotForClientIP(netip.Addr{}) + found := contains(entries, &entry) + require.Equal(t, found, true, "Did not find entry that was added!") + + // Use the returned function to remove the entry + removeFunc() + + // Verify the entry was removed + entries = ciphers.SnapshotForClientIP(netip.Addr{}) + found = contains(entries, &entry) + require.Equal(t, found, false, "The entry was removed but it's still in the list.") +} + func contains(entries []*list.Element, entry *CipherEntry) bool { found := false for _, e := range entries {