Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lsps2: framework for scid management #112

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions cln/cln_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cln

import (
"context"
"encoding/hex"
"fmt"
"log"
Expand Down Expand Up @@ -259,7 +260,7 @@ func (c *ClnClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) {
return hex.DecodeString(*dest)
}

var pollingInterval = 400 * time.Millisecond
var onlinePollingInterval = 400 * time.Millisecond

func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error {
peerIDStr := hex.EncodeToString(peerID)
Expand All @@ -272,11 +273,73 @@ func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error {
select {
case <-time.After(time.Until(deadline)):
return fmt.Errorf("timeout")
case <-time.After(pollingInterval):
case <-time.After(onlinePollingInterval):
}
}
}

func (c *ClnClient) WaitChannelActive(peerID []byte, deadline time.Time) error {
return nil
}

const scidPollingInterval = time.Second * 10

func (c *ClnClient) WatchScids(
ctx context.Context,
cache lightning.ScidCacheWriter,
) error {
ticker := time.NewTicker(scidPollingInterval)
defer ticker.Stop()

err := c.updateScids(cache)
if err != nil {
return err
}

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
err = c.updateScids(cache)
if err != nil {
return err
}
}
}
}

func (c *ClnClient) updateScids(cache lightning.ScidCacheWriter) error {
peers, err := c.client.ListPeers()
if err != nil {
return err
}

var scids []basetypes.ShortChannelID
for _, peer := range peers {
for _, ch := range peer.Channels {
s, err := mapToScids(ch.Alias.Local, ch.ShortChannelId)
if err != nil {
return err
}

scids = append(scids, s...)
}
}

cache.ReplaceScids(scids)
return nil
}

func mapToScids(scids ...string) ([]basetypes.ShortChannelID, error) {
var result []basetypes.ShortChannelID
for _, scid := range scids {
s, err := basetypes.NewShortChannelIDFromString(scid)
if err != nil {
return nil, err
}
result = append(result, *s)
}

return result, nil
}
2 changes: 2 additions & 0 deletions lightning/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lightning

import (
"context"
"time"

"github.com/breez/lspd/basetypes"
Expand Down Expand Up @@ -38,4 +39,5 @@ type Client interface {
GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error)
WaitOnline(peerID []byte, deadline time.Time) error
WaitChannelActive(peerID []byte, deadline time.Time) error
WatchScids(ctx context.Context, cache ScidCacheWriter) error
}
72 changes: 72 additions & 0 deletions lightning/scid_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package lightning

import (
"sync"

"github.com/breez/lspd/basetypes"
)

type ScidCacheReader interface {
ContainsScid(scid basetypes.ShortChannelID) bool
}

type ScidCacheWriter interface {
AddScids(scids ...basetypes.ShortChannelID)
RemoveScids(scids ...basetypes.ShortChannelID)
ReplaceScids(scids []basetypes.ShortChannelID)
}

type ScidCache struct {
activeScids map[uint64]bool
mtx sync.RWMutex
}

func NewScidCache() *ScidCache {
return &ScidCache{
activeScids: make(map[uint64]bool),
}
}

func (s *ScidCache) ContainsScid(
scid basetypes.ShortChannelID,
) bool {
s.mtx.RLock()
defer s.mtx.RUnlock()

_, ok := s.activeScids[uint64(scid)]
return ok
}

func (s *ScidCache) AddScids(
scids ...basetypes.ShortChannelID,
) {
s.mtx.Lock()
defer s.mtx.Unlock()

for _, scid := range scids {
s.activeScids[uint64(scid)] = true
}
}

func (s *ScidCache) RemoveScids(
scids ...basetypes.ShortChannelID,
) {
s.mtx.Lock()
defer s.mtx.Unlock()

for _, scid := range scids {
delete(s.activeScids, uint64(scid))
}
}

func (s *ScidCache) ReplaceScids(
scids []basetypes.ShortChannelID,
) {
s.mtx.Lock()
defer s.mtx.Unlock()

s.activeScids = make(map[uint64]bool)
for _, scid := range scids {
s.activeScids[uint64(scid)] = true
}
}
66 changes: 66 additions & 0 deletions lnd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,72 @@ func (c *LndClient) GetClosedChannels(nodeID string, channelPoints map[string]ui
return r, nil
}

func (c *LndClient) WatchScids(
ctx context.Context,
cache lightning.ScidCacheWriter,
) error {
stream, err := c.client.SubscribeChannelEvents(
ctx,
&lnrpc.ChannelEventSubscription{},
)
if err != nil {
return err
}

// Sync current scids first
chans, err := c.client.ListChannels(
context.Background(),
&lnrpc.ListChannelsRequest{},
)
if err != nil {
return err
}

var initialScids []basetypes.ShortChannelID
for _, ch := range chans.Channels {
initialScids = append(
initialScids,
mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)...)
}
cache.ReplaceScids(initialScids)

// Watch channels for changes
for {
upd, err := stream.Recv()
if err != nil {
return err
}

switch upd.Type {
case lnrpc.ChannelEventUpdate_OPEN_CHANNEL:
ch := upd.GetOpenChannel()
if ch == nil {
continue
}

scids := mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)
cache.AddScids(scids...)
case lnrpc.ChannelEventUpdate_CLOSED_CHANNEL:
ch := upd.GetClosedChannel()
if ch == nil {
continue
}

scids := mapToScids(append(ch.AliasScids, ch.ChanId, ch.ZeroConfConfirmedScid)...)
cache.RemoveScids(scids...)
}
}
}

func mapToScids(scids ...uint64) []basetypes.ShortChannelID {
var result []basetypes.ShortChannelID
for _, scid := range scids {
result = append(result, basetypes.ShortChannelID(scid))
}

return result
}

func (c *LndClient) getWaitingCloseChannels(nodeID string) ([]*lnrpc.PendingChannelsResponse_WaitingCloseChannel, error) {
pendingResponse, err := c.client.PendingChannels(context.Background(), &lnrpc.PendingChannelsRequest{})
if err != nil {
Expand Down
39 changes: 39 additions & 0 deletions lsps2/scid_cleaner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package lsps2

import (
"context"
"log"
"time"
)

type ScidCleaner struct {
interval time.Duration
store ScidStore
}

func NewScidCleaner(
store ScidStore,
interval time.Duration,
) *ScidCleaner {
return &ScidCleaner{
interval: interval,
store: store,
}
}

func (s *ScidCleaner) Start(ctx context.Context) error {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return ctx.Err()
case t := <-ticker.C:
err := s.store.RemoveExpired(ctx, t)
if err != nil {
log.Printf("Failed to remove expired scids: %v", err)
}
}
}
}
84 changes: 84 additions & 0 deletions lsps2/scid_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package lsps2

import (
"context"
"crypto/rand"
"fmt"
"log"
"math/big"
"time"

"github.com/breez/lspd/basetypes"
"github.com/breez/lspd/lightning"
)

const maxScidAttempts = 5

var two = big.NewInt(2)
var sixtyfour = big.NewInt(64)
var maxUint64 = two.Exp(two, sixtyfour, nil)

func newScid() (*basetypes.ShortChannelID, error) {
s, err := rand.Int(rand.Reader, maxUint64)
if err != nil {
return nil, err
}

scid := basetypes.ShortChannelID(s.Uint64())
return &scid, nil
}

type ScidService struct {
lspId []byte
store ScidStore
cache lightning.ScidCacheReader
}

func NewScidService(
lspId []byte,
store ScidStore,
cache lightning.ScidCacheReader,
) *ScidService {
return &ScidService{lspId: lspId, store: store, cache: cache}
}

func (s *ScidService) ReserveNewScid(
ctx context.Context,
expiry time.Time,
) (*basetypes.ShortChannelID, error) {
for attempts := 0; attempts < maxScidAttempts; attempts++ {
scid, err := newScid()
if err != nil {
log.Printf("NewScid() error: %v", err)
continue
}

if s.cache.ContainsScid(*scid) {
log.Printf(
"Collision with existing channel when generating new scid %s",
scid.ToString(),
)
continue
}

err = s.store.AddScid(ctx, s.lspId, *scid, expiry)
if err == ErrScidExists {
log.Printf(
"Collision on inserting random new scid %s",
scid.ToString(),
)
continue
}

if err != nil {
return nil, fmt.Errorf("failed to insert scid reservation: %w", err)
}

return scid, nil
}

return nil, fmt.Errorf(
"failed to reserve scid after %v attempts",
maxScidAttempts,
)
}
Loading