Skip to content

Commit

Permalink
lsps2: framework for scid management
Browse files Browse the repository at this point in the history
There's a background watcher that scans the lightning node for in-use scids. These scids are cached in memory.
There's a table containing the scids in use for the LSP.

`ScidService.ReserveNewScid` attempts to reserve an unused scid. It tries a few times, just in case it runs into a collision.

To tie this PR together into something working:
- Initialize a `lightning.ScidCache`
- Run `LightningClient.WatchScids` on either the CLN or LND client
- Initialize a `jit.ScidCleaner` and start it
- Initialize a `jit.ScidService`
- Call `ScidService.ReserveNewScid` to reserve a scid
  • Loading branch information
JssDWt committed Aug 7, 2023
1 parent 72c2b47 commit 42ce542
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 2 deletions.
67 changes: 65 additions & 2 deletions cln/cln_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cln

import (
"context"
"encoding/hex"
"fmt"
"log"
Expand Down Expand Up @@ -259,7 +260,7 @@ func (c *ClnClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) {
return hex.DecodeString(*dest)
}

var pollingInterval = 400 * time.Millisecond
var onlinePollingInterval = 400 * time.Millisecond

func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error {
peerIDStr := hex.EncodeToString(peerID)
Expand All @@ -272,11 +273,73 @@ func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error {
select {
case <-time.After(time.Until(deadline)):
return fmt.Errorf("timeout")
case <-time.After(pollingInterval):
case <-time.After(onlinePollingInterval):
}
}
}

func (c *ClnClient) WaitChannelActive(peerID []byte, deadline time.Time) error {
return nil
}

const scidPollingInterval = time.Second * 10

func (c *ClnClient) WatchScids(
ctx context.Context,
cache lightning.ScidCacheWriter,
) error {
ticker := time.NewTicker(scidPollingInterval)
defer ticker.Stop()

err := c.updateScids(cache)
if err != nil {
return err
}

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
err = c.updateScids(cache)
if err != nil {
return err
}
}
}
}

func (c *ClnClient) updateScids(cache lightning.ScidCacheWriter) error {
peers, err := c.client.ListPeers()
if err != nil {
return err
}

var scids []basetypes.ShortChannelID
for _, peer := range peers {
for _, ch := range peer.Channels {
s, err := mapToScids(ch.Alias.Local, ch.ShortChannelId)
if err != nil {
return err
}

scids = append(scids, s...)
}
}

cache.ReplaceScids(scids)
return nil
}

func mapToScids(scids ...string) ([]basetypes.ShortChannelID, error) {
var result []basetypes.ShortChannelID
for _, scid := range scids {
s, err := basetypes.NewShortChannelIDFromString(scid)
if err != nil {
return nil, err
}
result = append(result, *s)
}

return result, nil
}
2 changes: 2 additions & 0 deletions lightning/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lightning

import (
"context"
"time"

"github.com/breez/lspd/basetypes"
Expand Down Expand Up @@ -38,4 +39,5 @@ type Client interface {
GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error)
WaitOnline(peerID []byte, deadline time.Time) error
WaitChannelActive(peerID []byte, deadline time.Time) error
WatchScids(ctx context.Context, cache ScidCacheWriter) error
}
72 changes: 72 additions & 0 deletions lightning/scid_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package lightning

import (
"sync"

"github.com/breez/lspd/basetypes"
)

type ScidCacheReader interface {
ContainsScid(scid basetypes.ShortChannelID) bool
}

type ScidCacheWriter interface {
AddScids(scids ...basetypes.ShortChannelID)
RemoveScids(scids ...basetypes.ShortChannelID)
ReplaceScids(scids []basetypes.ShortChannelID)
}

type ScidCache struct {
activeScids map[uint64]bool
mtx sync.RWMutex
}

func NewScidCache() *ScidCache {
return &ScidCache{
activeScids: make(map[uint64]bool),
}
}

func (s *ScidCache) ContainsScid(
scid basetypes.ShortChannelID,
) bool {
s.mtx.RLock()
defer s.mtx.RUnlock()

_, ok := s.activeScids[uint64(scid)]
return ok
}

func (s *ScidCache) AddScids(
scids ...basetypes.ShortChannelID,
) {
s.mtx.Lock()
defer s.mtx.Unlock()

for _, scid := range scids {
s.activeScids[uint64(scid)] = true
}
}

func (s *ScidCache) RemoveScids(
scids ...basetypes.ShortChannelID,
) {
s.mtx.Lock()
defer s.mtx.Unlock()

for _, scid := range scids {
delete(s.activeScids, uint64(scid))
}
}

func (s *ScidCache) ReplaceScids(
scids []basetypes.ShortChannelID,
) {
s.mtx.Lock()
defer s.mtx.Unlock()

s.activeScids = make(map[uint64]bool)
for _, scid := range scids {
s.activeScids[uint64(scid)] = true
}
}
66 changes: 66 additions & 0 deletions lnd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,72 @@ func (c *LndClient) GetClosedChannels(nodeID string, channelPoints map[string]ui
return r, nil
}

func (c *LndClient) WatchScids(
ctx context.Context,
cache lightning.ScidCacheWriter,
) error {
stream, err := c.client.SubscribeChannelEvents(
ctx,
&lnrpc.ChannelEventSubscription{},
)
if err != nil {
return err
}

// Sync current scids first
chans, err := c.client.ListChannels(
context.Background(),
&lnrpc.ListChannelsRequest{},
)
if err != nil {
return err
}

var initialScids []basetypes.ShortChannelID
for _, ch := range chans.Channels {
initialScids = append(
initialScids,
mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)...)
}
cache.ReplaceScids(initialScids)

// Watch channels for changes
for {
upd, err := stream.Recv()
if err != nil {
return err
}

switch upd.Type {
case lnrpc.ChannelEventUpdate_OPEN_CHANNEL:
ch := upd.GetOpenChannel()
if ch == nil {
continue
}

scids := mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)
cache.AddScids(scids...)
case lnrpc.ChannelEventUpdate_CLOSED_CHANNEL:
ch := upd.GetClosedChannel()
if ch == nil {
continue
}

scids := mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)
cache.RemoveScids(scids...)
}
}
}

func mapToScids(scids ...uint64) []basetypes.ShortChannelID {
var result []basetypes.ShortChannelID
for _, scid := range scids {
result = append(result, basetypes.ShortChannelID(scid))
}

return result
}

func (c *LndClient) getWaitingCloseChannels(nodeID string) ([]*lnrpc.PendingChannelsResponse_WaitingCloseChannel, error) {
pendingResponse, err := c.client.PendingChannels(context.Background(), &lnrpc.PendingChannelsRequest{})
if err != nil {
Expand Down
39 changes: 39 additions & 0 deletions lsps2/scid_cleaner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package lsps2

import (
"context"
"log"
"time"
)

type ScidCleaner struct {
interval time.Duration
store ScidStore
}

func NewScidCleaner(
store ScidStore,
interval time.Duration,
) *ScidCleaner {
return &ScidCleaner{
interval: interval,
store: store,
}
}

func (s *ScidCleaner) Start(ctx context.Context) error {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return ctx.Err()
case t := <-ticker.C:
err := s.store.RemoveExpired(ctx, t)
if err != nil {
log.Printf("Failed to remove expired scids: %v", err)
}
}
}
}
82 changes: 82 additions & 0 deletions lsps2/scid_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package lsps2

import (
"context"
"crypto/rand"
"fmt"
"log"
"math/big"
"time"

"github.com/breez/lspd/basetypes"
"github.com/breez/lspd/lightning"
)

const maxScidAttempts = 5

var two = big.NewInt(2)
var sixtyfour = big.NewInt(64)
var maxUint64 = two.Exp(two, sixtyfour, nil)

func newScid() (*basetypes.ShortChannelID, error) {
s, err := rand.Int(rand.Reader, maxUint64)
if err != nil {
return nil, err
}

scid := basetypes.ShortChannelID(s.Uint64())
return &scid, nil
}

type ScidService struct {
store ScidStore
cache lightning.ScidCacheReader
}

func NewScidService(
store ScidStore,
cache lightning.ScidCacheReader,
) *ScidService {
return &ScidService{store: store, cache: cache}
}

func (s *ScidService) ReserveNewScid(
ctx context.Context,
expiry time.Time,
) (*basetypes.ShortChannelID, error) {
for attempts := 0; attempts < maxScidAttempts; attempts++ {
scid, err := newScid()
if err != nil {
log.Printf("NewScid() error: %v", err)
continue
}

if s.cache.ContainsScid(*scid) {
log.Printf(
"Collision with existing channel when generating new scid %s",
scid.ToString(),
)
continue
}

err = s.store.AddScid(ctx, *scid, expiry)
if err == ErrScidExists {
log.Printf(
"Collision on inserting random new scid %s",
scid.ToString(),
)
continue
}

if err != nil {
return nil, fmt.Errorf("failed to insert scid reservation: %w", err)
}

return scid, nil
}

return nil, fmt.Errorf(
"failed to reserve scid after %v attempts",
maxScidAttempts,
)
}
20 changes: 20 additions & 0 deletions lsps2/scid_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package lsps2

import (
"context"
"fmt"
"time"

"github.com/breez/lspd/basetypes"
)

var ErrScidExists = fmt.Errorf("scid already exists")

type ScidStore interface {
AddScid(
ctx context.Context,
scid basetypes.ShortChannelID,
expiry time.Time,
) error
RemoveExpired(ctx context.Context, before time.Time) error
}
2 changes: 2 additions & 0 deletions postgresql/migrations/000014_lsps2.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DROP INDEX scid_reservations_expiry_idx;
DROP TABLE public.scid_reservations;
Loading

0 comments on commit 42ce542

Please sign in to comment.