diff --git a/cln/cln_client.go b/cln/cln_client.go index c80ffe0e..58a2bc47 100644 --- a/cln/cln_client.go +++ b/cln/cln_client.go @@ -1,6 +1,7 @@ package cln import ( + "context" "encoding/hex" "fmt" "log" @@ -259,7 +260,7 @@ func (c *ClnClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) { return hex.DecodeString(*dest) } -var pollingInterval = 400 * time.Millisecond +var onlinePollingInterval = 400 * time.Millisecond func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error { peerIDStr := hex.EncodeToString(peerID) @@ -272,7 +273,7 @@ func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error { select { case <-time.After(time.Until(deadline)): return fmt.Errorf("timeout") - case <-time.After(pollingInterval): + case <-time.After(onlinePollingInterval): } } } @@ -280,3 +281,65 @@ func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error { func (c *ClnClient) WaitChannelActive(peerID []byte, deadline time.Time) error { return nil } + +const scidPollingInterval = time.Second * 10 + +func (c *ClnClient) WatchScids( + ctx context.Context, + cache lightning.ScidCacheWriter, +) error { + ticker := time.NewTicker(scidPollingInterval) + defer ticker.Stop() + + err := c.updateScids(cache) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + err = c.updateScids(cache) + if err != nil { + return err + } + } + } +} + +func (c *ClnClient) updateScids(cache lightning.ScidCacheWriter) error { + peers, err := c.client.ListPeers() + if err != nil { + return err + } + + var scids []basetypes.ShortChannelID + for _, peer := range peers { + for _, ch := range peer.Channels { + s, err := mapToScids(ch.Alias.Local, ch.ShortChannelId) + if err != nil { + return err + } + + scids = append(scids, s...) + } + } + + cache.ReplaceScids(scids) + return nil +} + +func mapToScids(scids ...string) ([]basetypes.ShortChannelID, error) { + var result []basetypes.ShortChannelID + for _, scid := range scids { + s, err := basetypes.NewShortChannelIDFromString(scid) + if err != nil { + return nil, err + } + result = append(result, *s) + } + + return result, nil +} diff --git a/lightning/client.go b/lightning/client.go index 009e562f..fe741e29 100644 --- a/lightning/client.go +++ b/lightning/client.go @@ -1,6 +1,7 @@ package lightning import ( + "context" "time" "github.com/breez/lspd/basetypes" @@ -38,4 +39,5 @@ type Client interface { GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) WaitOnline(peerID []byte, deadline time.Time) error WaitChannelActive(peerID []byte, deadline time.Time) error + WatchScids(ctx context.Context, cache ScidCacheWriter) error } diff --git a/lightning/scid_cache.go b/lightning/scid_cache.go new file mode 100644 index 00000000..3b82841b --- /dev/null +++ b/lightning/scid_cache.go @@ -0,0 +1,72 @@ +package lightning + +import ( + "sync" + + "github.com/breez/lspd/basetypes" +) + +type ScidCacheReader interface { + ContainsScid(scid basetypes.ShortChannelID) bool +} + +type ScidCacheWriter interface { + AddScids(scids ...basetypes.ShortChannelID) + RemoveScids(scids ...basetypes.ShortChannelID) + ReplaceScids(scids []basetypes.ShortChannelID) +} + +type ScidCache struct { + activeScids map[uint64]bool + mtx sync.RWMutex +} + +func NewScidCache() *ScidCache { + return &ScidCache{ + activeScids: make(map[uint64]bool), + } +} + +func (s *ScidCache) ContainsScid( + scid basetypes.ShortChannelID, +) bool { + s.mtx.RLock() + defer s.mtx.RUnlock() + + _, ok := s.activeScids[uint64(scid)] + return ok +} + +func (s *ScidCache) AddScids( + scids ...basetypes.ShortChannelID, +) { + s.mtx.Lock() + defer s.mtx.Unlock() + + for _, scid := range scids { + s.activeScids[uint64(scid)] = true + } +} + +func (s *ScidCache) RemoveScids( + scids ...basetypes.ShortChannelID, +) { + s.mtx.Lock() + defer s.mtx.Unlock() + + for _, scid := range scids { + delete(s.activeScids, uint64(scid)) + } +} + +func (s *ScidCache) ReplaceScids( + scids []basetypes.ShortChannelID, +) { + s.mtx.Lock() + defer s.mtx.Unlock() + + s.activeScids = make(map[uint64]bool) + for _, scid := range scids { + s.activeScids[uint64(scid)] = true + } +} diff --git a/lnd/client.go b/lnd/client.go index 90ae714d..aee69729 100644 --- a/lnd/client.go +++ b/lnd/client.go @@ -365,6 +365,72 @@ func (c *LndClient) GetClosedChannels(nodeID string, channelPoints map[string]ui return r, nil } +func (c *LndClient) WatchScids( + ctx context.Context, + cache lightning.ScidCacheWriter, +) error { + stream, err := c.client.SubscribeChannelEvents( + ctx, + &lnrpc.ChannelEventSubscription{}, + ) + if err != nil { + return err + } + + // Sync current scids first + chans, err := c.client.ListChannels( + context.Background(), + &lnrpc.ListChannelsRequest{}, + ) + if err != nil { + return err + } + + var initialScids []basetypes.ShortChannelID + for _, ch := range chans.Channels { + initialScids = append( + initialScids, + mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)...) + } + cache.ReplaceScids(initialScids) + + // Watch channels for changes + for { + upd, err := stream.Recv() + if err != nil { + return err + } + + switch upd.Type { + case lnrpc.ChannelEventUpdate_OPEN_CHANNEL: + ch := upd.GetOpenChannel() + if ch == nil { + continue + } + + scids := mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...) + cache.AddScids(scids...) + case lnrpc.ChannelEventUpdate_CLOSED_CHANNEL: + ch := upd.GetClosedChannel() + if ch == nil { + continue + } + + scids := mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...) + cache.RemoveScids(scids...) + } + } +} + +func mapToScids(scids ...uint64) []basetypes.ShortChannelID { + var result []basetypes.ShortChannelID + for _, scid := range scids { + result = append(result, basetypes.ShortChannelID(scid)) + } + + return result +} + func (c *LndClient) getWaitingCloseChannels(nodeID string) ([]*lnrpc.PendingChannelsResponse_WaitingCloseChannel, error) { pendingResponse, err := c.client.PendingChannels(context.Background(), &lnrpc.PendingChannelsRequest{}) if err != nil { diff --git a/lsps2/scid_cleaner.go b/lsps2/scid_cleaner.go new file mode 100644 index 00000000..29649376 --- /dev/null +++ b/lsps2/scid_cleaner.go @@ -0,0 +1,39 @@ +package lsps2 + +import ( + "context" + "log" + "time" +) + +type ScidCleaner struct { + interval time.Duration + store ScidStore +} + +func NewScidCleaner( + store ScidStore, + interval time.Duration, +) *ScidCleaner { + return &ScidCleaner{ + interval: interval, + store: store, + } +} + +func (s *ScidCleaner) Start(ctx context.Context) error { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case t := <-ticker.C: + err := s.store.RemoveExpired(ctx, t) + if err != nil { + log.Printf("Failed to remove expired scids: %v", err) + } + } + } +} diff --git a/lsps2/scid_service.go b/lsps2/scid_service.go new file mode 100644 index 00000000..f60e20f1 --- /dev/null +++ b/lsps2/scid_service.go @@ -0,0 +1,84 @@ +package lsps2 + +import ( + "context" + "crypto/rand" + "fmt" + "log" + "math/big" + "time" + + "github.com/breez/lspd/basetypes" + "github.com/breez/lspd/lightning" +) + +const maxScidAttempts = 5 + +var two = big.NewInt(2) +var sixtyfour = big.NewInt(64) +var maxUint64 = two.Exp(two, sixtyfour, nil) + +func newScid() (*basetypes.ShortChannelID, error) { + s, err := rand.Int(rand.Reader, maxUint64) + if err != nil { + return nil, err + } + + scid := basetypes.ShortChannelID(s.Uint64()) + return &scid, nil +} + +type ScidService struct { + lspId []byte + store ScidStore + cache lightning.ScidCacheReader +} + +func NewScidService( + lspId []byte, + store ScidStore, + cache lightning.ScidCacheReader, +) *ScidService { + return &ScidService{lspId: lspId, store: store, cache: cache} +} + +func (s *ScidService) ReserveNewScid( + ctx context.Context, + expiry time.Time, +) (*basetypes.ShortChannelID, error) { + for attempts := 0; attempts < maxScidAttempts; attempts++ { + scid, err := newScid() + if err != nil { + log.Printf("NewScid() error: %v", err) + continue + } + + if s.cache.ContainsScid(*scid) { + log.Printf( + "Collision with existing channel when generating new scid %s", + scid.ToString(), + ) + continue + } + + err = s.store.AddScid(ctx, s.lspId, *scid, expiry) + if err == ErrScidExists { + log.Printf( + "Collision on inserting random new scid %s", + scid.ToString(), + ) + continue + } + + if err != nil { + return nil, fmt.Errorf("failed to insert scid reservation: %w", err) + } + + return scid, nil + } + + return nil, fmt.Errorf( + "failed to reserve scid after %v attempts", + maxScidAttempts, + ) +} diff --git a/lsps2/scid_store.go b/lsps2/scid_store.go new file mode 100644 index 00000000..3cccfa4f --- /dev/null +++ b/lsps2/scid_store.go @@ -0,0 +1,26 @@ +package lsps2 + +import ( + "context" + "fmt" + "time" + + "github.com/breez/lspd/basetypes" +) + +var ErrScidExists = fmt.Errorf("scid already exists") + +type ScidStore interface { + AddScid( + ctx context.Context, + lspId []byte, + scid basetypes.ShortChannelID, + expiry time.Time, + ) error + RemoveScid( + ctx context.Context, + lspId []byte, + scid basetypes.ShortChannelID, + ) (bool, error) + RemoveExpired(ctx context.Context, before time.Time) error +} diff --git a/postgresql/migrations/000014_lsps2_scids.down.sql b/postgresql/migrations/000014_lsps2_scids.down.sql new file mode 100644 index 00000000..b8a42220 --- /dev/null +++ b/postgresql/migrations/000014_lsps2_scids.down.sql @@ -0,0 +1,2 @@ +DROP INDEX scid_reservations_expiry_idx; +DROP TABLE public.scid_reservations; diff --git a/postgresql/migrations/000014_lsps2_scids.up.sql b/postgresql/migrations/000014_lsps2_scids.up.sql new file mode 100644 index 00000000..bde1b059 --- /dev/null +++ b/postgresql/migrations/000014_lsps2_scids.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE public.scid_reservations ( + id SERIAL PRIMARY KEY, + lspid bytea NOT NULL, + scid bigint NOT NULL, + expiry bigint NOT NULL +); + +CREATE UNIQUE INDEX scid_reservations_lspid_scid_idx ON public.scid_reservations (lspid, scid); +CREATE INDEX scid_reservations_expiry_idx ON public.scid_reservations (expiry); diff --git a/postgresql/scid_store.go b/postgresql/scid_store.go new file mode 100644 index 00000000..9454a2a2 --- /dev/null +++ b/postgresql/scid_store.go @@ -0,0 +1,88 @@ +package postgresql + +import ( + "context" + "log" + "strings" + "time" + + "github.com/breez/lspd/basetypes" + "github.com/breez/lspd/lsps2" + "github.com/jackc/pgx/v4/pgxpool" +) + +type ScidStore struct { + pool *pgxpool.Pool +} + +func NewScidStore(pool *pgxpool.Pool) *ScidStore { + return &ScidStore{ + pool: pool, + } +} + +func (s *ScidStore) AddScid( + ctx context.Context, + lspId []byte, + scid basetypes.ShortChannelID, + expiry time.Time, +) error { + _, err := s.pool.Exec( + ctx, + `INSERT INTO public.scid_reservations (lspid, scid, expiry) + VALUES (?, ?, ?)`, + lspId, + int64(uint64(scid)), // store the scid as int64 + expiry.Unix(), + ) + + if err != nil && strings.Contains(err.Error(), "already exists") { + return lsps2.ErrScidExists + } + + return err +} + +func (s *ScidStore) RemoveScid( + ctx context.Context, + lspId []byte, + scid basetypes.ShortChannelID, +) (bool, error) { + res, err := s.pool.Exec( + ctx, + `DELETE FROM public.scid_reservations + WHERE lspid = ? AND scid = ?`, + lspId, + int64(uint64(scid)), // convert scid to int64 + ) + + if err != nil { + return false, err + } + + return res.RowsAffected() > 0, nil +} + +func (s *ScidStore) RemoveExpired(ctx context.Context, before time.Time) error { + rows, err := s.pool.Exec( + ctx, + `DELETE FROM public.scid_reservations + WHERE expiry < ?`, + before.Unix(), + ) + + if err != nil { + return err + } + + rowsAffected := rows.RowsAffected() + if rowsAffected > 0 { + log.Printf( + "Deleted %d scids from scid_reservations that expired before %s", + rowsAffected, + before.String(), + ) + } + + return nil +}