diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 999e43a1..b136366f 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -5,7 +5,7 @@ on: pull_request: env: BITCOIN_VERSION: '25.0' - LSP_REF: 'breez-node-v0.17.2-beta' + LSP_REF: 'unencrypted-failure-messages' CLIENT_REF: 'v0.16.4-breez-3' GO_VERSION: '^1.19' CLN_VERSION: 'v23.11' @@ -158,6 +158,7 @@ jobs: testLsps2ZeroConfUtxo ] lsp: [ + LND, CLN ] client: [ diff --git a/cln/cln_client.go b/cln/cln_client.go index 49e1da88..f45e9b57 100644 --- a/cln/cln_client.go +++ b/cln/cln_client.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/breez/lspd/common" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" @@ -88,8 +89,9 @@ func (c *ClnClient) GetInfo() (*lightning.GetInfoResult, error) { } return &lightning.GetInfoResult{ - Alias: info.Alias, - Pubkey: info.Id, + ChainHash: common.MapChainHash(info.Network), + Alias: info.Alias, + Pubkey: info.Id, }, nil } diff --git a/cln/cln_interceptor.go b/cln/cln_interceptor.go index f58ad239..267222e8 100644 --- a/cln/cln_interceptor.go +++ b/cln/cln_interceptor.go @@ -161,7 +161,7 @@ func (i *ClnHtlcInterceptor) intercept() error { interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) case common.INTERCEPT_FAIL_HTLC_WITH_CODE: interceptorClient.Send( - i.failWithCode(request, interceptResult.FailureCode), + i.failWithMessage(request, interceptResult.FailureMessage), ) case common.INTERCEPT_IGNORE: // Do nothing @@ -202,12 +202,12 @@ func (i *ClnHtlcInterceptor) resumeWithOnion(request *proto.HtlcAccepted, interc payload, err := hex.DecodeString(request.Onion.Payload) if err != nil { log.Printf("paymenthash: %s, resumeWithOnion: hex.DecodeString(%v) error: %v", request.Htlc.PaymentHash, request.Onion.Payload, err) - return i.failWithCode(request, common.FAILURE_TEMPORARY_CHANNEL_FAILURE) + return i.failWithMessage(request, common.FAILURE_TEMPORARY_CHANNEL_FAILURE) } newPayload, err := encodePayloadWithNextHop(payload, interceptResult.Scid, interceptResult.AmountMsat, interceptResult.FeeMsat) if err != nil { log.Printf("paymenthash: %s, encodePayloadWithNextHop error: %v", request.Htlc.PaymentHash, err) - return i.failWithCode(request, common.FAILURE_TEMPORARY_CHANNEL_FAILURE) + return i.failWithMessage(request, common.FAILURE_TEMPORARY_CHANNEL_FAILURE) } newPayloadStr := hex.EncodeToString(newPayload) @@ -234,14 +234,14 @@ func (i *ClnHtlcInterceptor) defaultResolution(request *proto.HtlcAccepted) *pro } } -func (i *ClnHtlcInterceptor) failWithCode(request *proto.HtlcAccepted, code common.InterceptFailureCode) *proto.HtlcResolution { - log.Printf("paymenthash: %s, failing htlc with code: '%x'", request.Htlc.PaymentHash, code) +func (i *ClnHtlcInterceptor) failWithMessage(request *proto.HtlcAccepted, message []byte) *proto.HtlcResolution { + log.Printf("paymenthash: %s, failing htlc with message '%x'", request.Htlc.PaymentHash, message) return &proto.HtlcResolution{ Correlationid: request.Correlationid, Outcome: &proto.HtlcResolution_Fail{ Fail: &proto.HtlcFail{ Failure: &proto.HtlcFail_FailureMessage{ - FailureMessage: i.mapFailureCode(code), + FailureMessage: hex.EncodeToString(message), }, }, }, @@ -305,23 +305,3 @@ func encodePayloadWithNextHop(payload []byte, scid lightning.ShortChannelID, amo } return newPayloadBuf.Bytes(), nil } - -func (i *ClnHtlcInterceptor) mapFailureCode(original common.InterceptFailureCode) string { - switch original { - case common.FAILURE_TEMPORARY_CHANNEL_FAILURE: - return "1007" - case common.FAILURE_AMOUNT_BELOW_MINIMUM: - return "100B" - case common.FAILURE_INCORRECT_CLTV_EXPIRY: - return "100D" - case common.FAILURE_TEMPORARY_NODE_FAILURE: - return "2002" - case common.FAILURE_UNKNOWN_NEXT_PEER: - return "400A" - case common.FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: - return "400F" - default: - log.Printf("Unknown failure code %v, default to temporary channel failure.", original) - return "1007" // temporary channel failure - } -} diff --git a/common/chain_hash.go b/common/chain_hash.go new file mode 100644 index 00000000..feafe565 --- /dev/null +++ b/common/chain_hash.go @@ -0,0 +1,19 @@ +package common + +import ( + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +func MapChainHash(network string) *chainhash.Hash { + switch network { + case "bitcoin": + return chaincfg.MainNetParams.GenesisHash + case "testnet": + return chaincfg.TestNet3Params.GenesisHash + case "regtest": + return chaincfg.RegressionNetParams.GenesisHash + default: + return chaincfg.MainNetParams.GenesisHash + } +} diff --git a/common/chan_update.go b/common/chan_update.go new file mode 100644 index 00000000..19ebae6c --- /dev/null +++ b/common/chan_update.go @@ -0,0 +1,38 @@ +package common + +import ( + "bytes" + "time" + + "github.com/breez/lspd/lightning" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/lnwire" +) + +func ConstructChanUpdate( + chainhash chainhash.Hash, + node []byte, + destination []byte, + scid lightning.ShortChannelID, + timeLockDelta uint16, + htlcMinimumMsat, + htlcMaximumMsat uint64, +) lnwire.ChannelUpdate { + channelFlags := lnwire.ChanUpdateChanFlags(0) + if bytes.Compare(node, destination) > 0 { + channelFlags = 1 + } + + return lnwire.ChannelUpdate{ + ChainHash: chainhash, + ShortChannelID: scid.ToLnwire(), + Timestamp: uint32(time.Now().Unix()), + TimeLockDelta: timeLockDelta, + HtlcMinimumMsat: lnwire.MilliSatoshi(htlcMinimumMsat), + HtlcMaximumMsat: lnwire.MilliSatoshi(htlcMaximumMsat), + BaseFee: 0, + FeeRate: 0, + MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + ChannelFlags: channelFlags, + } +} diff --git a/common/intercept_handler.go b/common/intercept_handler.go index 96ea2a3f..d5dbc899 100644 --- a/common/intercept_handler.go +++ b/common/intercept_handler.go @@ -1,10 +1,13 @@ package common import ( + "bytes" "fmt" + "log" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwire" ) type InterceptAction int @@ -16,17 +19,41 @@ const ( INTERCEPT_IGNORE InterceptAction = 3 ) -type InterceptFailureCode uint16 +type InterceptFailureCode []byte var ( - FAILURE_TEMPORARY_CHANNEL_FAILURE InterceptFailureCode = 0x1007 - FAILURE_AMOUNT_BELOW_MINIMUM InterceptFailureCode = 0x100B - FAILURE_INCORRECT_CLTV_EXPIRY InterceptFailureCode = 0x100D - FAILURE_TEMPORARY_NODE_FAILURE InterceptFailureCode = 0x2002 - FAILURE_UNKNOWN_NEXT_PEER InterceptFailureCode = 0x400A - FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS InterceptFailureCode = 0x400F + FAILURE_TEMPORARY_CHANNEL_FAILURE InterceptFailureCode = []byte{0x10, 0x07} + FAILURE_AMOUNT_BELOW_MINIMUM InterceptFailureCode = []byte{0x10, 0x0B} + FAILURE_INCORRECT_CLTV_EXPIRY InterceptFailureCode = []byte{0x10, 0x0D} + FAILURE_TEMPORARY_NODE_FAILURE InterceptFailureCode = []byte{0x20, 0x02} + FAILURE_UNKNOWN_NEXT_PEER InterceptFailureCode = []byte{0x40, 0x0A} + FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS InterceptFailureCode = []byte{0x40, 0x0F} ) +func FailureTemporaryChannelFailure(update *lnwire.ChannelUpdate) []byte { + var buf bytes.Buffer + msg := lnwire.NewTemporaryChannelFailure(update) + err := lnwire.EncodeFailureMessage(&buf, msg, 0) + if err != nil { + log.Printf("Failed to encode failure message for temporary channel failure: %v", err) + return FAILURE_TEMPORARY_CHANNEL_FAILURE + } + + return buf.Bytes() +} + +func FailureIncorrectCltvExpiry(cltvExpiry uint32, update lnwire.ChannelUpdate) []byte { + var buf bytes.Buffer + msg := lnwire.NewIncorrectCltvExpiry(cltvExpiry, update) + err := lnwire.EncodeFailureMessage(&buf, msg, 0) + if err != nil { + log.Printf("Failed to encode failure message for incorrect cltv expiry: %v", err) + return FAILURE_INCORRECT_CLTV_EXPIRY + } + + return buf.Bytes() +} + type InterceptRequest struct { // Identifier that uniquely identifies this htlc. // For cln, that's hash of the next onion or the shared secret. @@ -49,7 +76,7 @@ func (r *InterceptRequest) HtlcId() string { type InterceptResult struct { Action InterceptAction - FailureCode InterceptFailureCode + FailureMessage []byte Destination []byte AmountMsat uint64 FeeMsat *uint64 diff --git a/common/nodes_service.go b/common/nodes_service.go index e983c3c0..128b8672 100644 --- a/common/nodes_service.go +++ b/common/nodes_service.go @@ -7,11 +7,14 @@ import ( "github.com/breez/lspd/config" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" ecies "github.com/ecies/go/v2" "golang.org/x/sync/singleflight" ) type Node struct { + NodeId []byte + ChainHash chainhash.Hash Client lightning.Client NodeConfig *config.NodeConfig PrivateKey *btcec.PrivateKey diff --git a/go.mod b/go.mod index 66370c03..82bdfeda 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/docker/docker v20.10.24+incompatible github.com/docker/go-connections v0.4.0 github.com/elementsproject/glightning v0.0.0-20230525134205-ef34d849f564 - github.com/golang/protobuf v1.5.3 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/jackc/pgtype v1.14.0 github.com/jackc/pgx/v5 v5.4.3 @@ -38,6 +37,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/ethereum/go-ethereum v1.13.5 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/google/uuid v1.3.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.0 // indirect @@ -194,7 +194,7 @@ require ( sigs.k8s.io/yaml v1.2.0 // indirect ) -replace github.com/lightningnetwork/lnd v0.17.2-beta => github.com/breez/lnd v0.15.0-beta.rc6.0.20231122093500-0c939786ced7 +replace github.com/lightningnetwork/lnd v0.17.2-beta => github.com/breez/lnd v0.15.0-beta.rc6.0.20240105103917-ec16df3d9d48 replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display diff --git a/interceptor/intercept_handler.go b/interceptor/intercept_handler.go index 08120f43..26ddb62e 100644 --- a/interceptor/intercept_handler.go +++ b/interceptor/intercept_handler.go @@ -11,7 +11,6 @@ import ( "github.com/breez/lspd/chain" "github.com/breez/lspd/common" - "github.com/breez/lspd/config" "github.com/breez/lspd/lightning" "github.com/breez/lspd/lsps0" "github.com/breez/lspd/notifications" @@ -22,7 +21,7 @@ import ( type Interceptor struct { client lightning.Client - config *config.NodeConfig + node *common.Node store InterceptStore openingService common.OpeningService feeEstimator chain.FeeEstimator @@ -33,7 +32,7 @@ type Interceptor struct { func NewInterceptHandler( client lightning.Client, - config *config.NodeConfig, + node *common.Node, store InterceptStore, openingService common.OpeningService, feeEstimator chain.FeeEstimator, @@ -42,7 +41,7 @@ func NewInterceptHandler( ) *Interceptor { return &Interceptor{ client: client, - config: config, + node: node, store: store, openingService: openingService, feeEstimator: feeEstimator, @@ -59,8 +58,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if err != nil { log.Printf("paymentInfo(%x) error: %v", req.PaymentHash, err) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_NODE_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FAILURE_TEMPORARY_NODE_FAILURE, }, nil } @@ -111,8 +110,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if err != nil { log.Printf("IsConnected(%x) error: %v", nextHop, err) return &common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureTemporaryChannelFailure(nil), }, nil } @@ -133,8 +132,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if channelPoint == nil { log.Printf("paymentHash: %s, probe and channelPoint == nil", reqPaymentHashStr) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, }, nil } } @@ -161,26 +160,37 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes }, nil } + // In case we fail with an error, this is the used channel update. + chanUpdate := common.ConstructChanUpdate( + i.node.ChainHash, + i.node.NodeId, + destination, + req.Scid, + uint16(i.node.NodeConfig.TimeLockDelta), + i.node.NodeConfig.MinHtlcMsat, + req.IncomingAmountMsat, + ) + // The first htlc of a MPP will open the channel. if channelPoint == nil { // TODO: When opening_fee_params is enforced, turn this check in a temporary channel failure. if params == nil { log.Printf("DEPRECATED: Intercepted htlc with deprecated fee mechanism. Using default fees. payment hash: %s", reqPaymentHashStr) params = &common.OpeningFeeParams{ - MinFeeMsat: uint64(i.config.ChannelMinimumFeeMsat), - Proportional: uint32(i.config.ChannelFeePermyriad * 100), + MinFeeMsat: uint64(i.node.NodeConfig.ChannelMinimumFeeMsat), + Proportional: uint32(i.node.NodeConfig.ChannelFeePermyriad * 100), ValidUntil: time.Now().UTC().Add(time.Duration(time.Hour * 24)).Format(lsps0.TIME_FORMAT), - MinLifetime: uint32(i.config.MaxInactiveDuration / 600), + MinLifetime: uint32(i.node.NodeConfig.MaxInactiveDuration / 600), MaxClientToSelfDelay: uint32(10000), } } // Make sure the cltv delta is enough. - if int64(req.IncomingExpiry)-int64(req.OutgoingExpiry) < int64(i.config.TimeLockDelta) { - log.Printf("paymentHash: %s, outgoingExpiry: %v, incomingExpiry: %v, i.config.TimeLockDelta: %v", reqPaymentHashStr, req.OutgoingExpiry, req.IncomingExpiry, i.config.TimeLockDelta) + if int64(req.IncomingExpiry)-int64(req.OutgoingExpiry) < int64(i.node.NodeConfig.TimeLockDelta) { + log.Printf("paymentHash: %s, outgoingExpiry: %v, incomingExpiry: %v, i.node.NodeConfig.TimeLockDelta: %v", reqPaymentHashStr, req.OutgoingExpiry, req.IncomingExpiry, i.node.NodeConfig.TimeLockDelta) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureIncorrectCltvExpiry(req.IncomingExpiry, chanUpdate), }, nil } @@ -188,8 +198,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if err != nil { log.Printf("paymentHash: %s, time.Parse(%s, %s) failed. Failing channel open: %v", reqPaymentHashStr, lsps0.TIME_FORMAT, params.ValidUntil, err) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } @@ -199,8 +209,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if !i.openingService.IsCurrentChainFeeCheaper(token, params) { log.Printf("Intercepted expired payment registration. Failing payment. payment hash: %s, valid until: %s", reqPaymentHashStr, params.ValidUntil) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } @@ -211,8 +221,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if err != nil { log.Printf("paymentHash: %s, openChannel(%x, %v) err: %v", reqPaymentHashStr, destination, incomingAmountMsat, err) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } } @@ -238,8 +248,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes if err != nil { log.Printf("paymentHash: %s, insertChannel error: %v", reqPaymentHashStr, err) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } @@ -248,7 +258,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes channelID = chanResult.InitialChannelID } - useLegacyOnionBlob := slices.Contains(i.config.LegacyOnionTokens, token) + useLegacyOnionBlob := slices.Contains(i.node.NodeConfig.LegacyOnionTokens, token) return common.InterceptResult{ Action: common.INTERCEPT_RESUME_WITH_ONION, Destination: destination, @@ -271,8 +281,8 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("paymentHash: %s, Error: Channel failed to open... timed out. ", reqPaymentHashStr) return common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil }) @@ -297,7 +307,7 @@ func (i *Interceptor) notifyAndWait(reqPaymentHashStr string, nextHop []byte, is } log.Printf("paymentHash %s, Notified %x of pending htlc", reqPaymentHashStr, nextHop) - d, err := time.ParseDuration(i.config.NotificationTimeout) + d, err := time.ParseDuration(i.node.NodeConfig.NotificationTimeout) if err != nil { log.Printf("WARN: No NotificationTimeout set. Using default 1m") d = time.Minute @@ -349,8 +359,8 @@ func (i *Interceptor) notifyAndWait(reqPaymentHashStr string, nextHop []byte, is } func (i *Interceptor) openChannel(paymentHash, destination []byte, incomingAmountMsat int64, tag *string) (*wire.OutPoint, error) { - capacity := incomingAmountMsat/1000 + i.config.AdditionalChannelCapacity - if capacity == i.config.PublicChannelAmount { + capacity := incomingAmountMsat/1000 + i.node.NodeConfig.AdditionalChannelCapacity + if capacity == i.node.NodeConfig.PublicChannelAmount { capacity++ } @@ -368,7 +378,7 @@ func (i *Interceptor) openChannel(paymentHash, destination []byte, incomingAmoun feeStr = fmt.Sprintf("%.5f", *feeEstimation) } else { log.Printf("Error estimating chain fee, fallback to target conf: %v", err) - targetConf = &i.config.TargetConf + targetConf = &i.node.NodeConfig.TargetConf confStr = fmt.Sprintf("%v", *targetConf) } } @@ -384,7 +394,7 @@ func (i *Interceptor) openChannel(paymentHash, destination []byte, incomingAmoun channelPoint, err := i.client.OpenChannel(&lightning.OpenChannelRequest{ Destination: destination, CapacitySat: uint64(capacity), - MinConfs: i.config.MinConfs, + MinConfs: i.node.NodeConfig.MinConfs, IsPrivate: true, IsZeroConf: true, FeeSatPerVByte: feeEstimation, diff --git a/itest/cltv_test.go b/itest/cltv_test.go index bbe0f5c9..780e29f8 100644 --- a/itest/cltv_test.go +++ b/itest/cltv_test.go @@ -54,5 +54,5 @@ func testInvalidCltv(p *testParams) { // Decrement the delay in the first hop, so the cltv delta will become 143 (too little) route.Hops[0].Delay-- _, err := alice.PayViaRoute(outerAmountMsat, outerInvoice.paymentHash, outerInvoice.paymentSecret, route) - assert.Contains(p.t, err.Error(), "WIRE_TEMPORARY_CHANNEL_FAILURE") + assert.Contains(p.t, err.Error(), "WIRE_INCORRECT_CLTV_EXPIRY") } diff --git a/itest/lnd_breez_client.go b/itest/lnd_breez_client.go index c86be474..f6da55d9 100644 --- a/itest/lnd_breez_client.go +++ b/itest/lnd_breez_client.go @@ -2,8 +2,11 @@ package itest import ( "context" + "encoding/hex" "flag" + "log" "sync" + "time" "github.com/breez/lntest" "github.com/breez/lntest/lnd" @@ -15,11 +18,12 @@ var lndMobileExecutable = flag.String( ) type lndBreezClient struct { - name string - harness *lntest.TestHarness - node *lntest.LndNode - cancel context.CancelFunc - mtx sync.Mutex + name string + harness *lntest.TestHarness + node *lntest.LndNode + customMsgQueue chan *lntest.CustomMsgRequest + cancel context.CancelFunc + mtx sync.Mutex } func newLndBreezClient(h *lntest.TestHarness, m *lntest.Miner, name string) BreezClient { @@ -63,6 +67,8 @@ func (c *lndBreezClient) Start() { ctx, cancel := context.WithCancel(c.harness.Ctx) c.cancel = cancel go c.startChannelAcceptor(ctx) + c.customMsgQueue = make(chan *lntest.CustomMsgRequest, 100) + c.startCustomMsgListener(ctx) } func (c *lndBreezClient) Stop() error { @@ -86,9 +92,51 @@ func (c *lndBreezClient) SetHtlcAcceptor(totalMsat uint64) { // No need for a htlc acceptor in the LND breez client } +func (c *lndBreezClient) startCustomMsgListener(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Second): + } + + if !c.node.IsStarted() { + log.Printf("%s: cannot listen to custom messages, node is not started.", c.name) + break + } + + listener, err := c.node.LightningClient().SubscribeCustomMessages( + ctx, + &lnd.SubscribeCustomMessagesRequest{}, + ) + if err != nil { + log.Printf("%s: client.SubscribeCustomMessages() error: %v", c.name, err) + break + } + for { + if ctx.Err() != nil { + return + } + msg, err := listener.Recv() + if err != nil { + log.Printf("%s: listener.Recv() error: %v", c.name, err) + break + } + + c.customMsgQueue <- &lntest.CustomMsgRequest{ + PeerId: hex.EncodeToString(msg.Peer), + Type: msg.Type, + Data: msg.Data, + } + } + } + }() +} + func (c *lndBreezClient) ReceiveCustomMessage() *lntest.CustomMsgRequest { - // TODO: Not implemented. - return nil + msg := <-c.customMsgQueue + return msg } func (c *lndBreezClient) startChannelAcceptor(ctx context.Context) error { @@ -98,6 +146,10 @@ func (c *lndBreezClient) startChannelAcceptor(ctx context.Context) error { } for { + if ctx.Err() != nil { + return ctx.Err() + } + request, err := client.Recv() if err != nil { return err diff --git a/itest/lspd_test.go b/itest/lspd_test.go index 87bf0040..9c0cecb6 100644 --- a/itest/lspd_test.go +++ b/itest/lspd_test.go @@ -24,7 +24,7 @@ func TestLspd(t *testing.T) { lndTestCases = append(lndTestCases, c) } } - runTests(t, lndTestCases, "LND-lsp-CLN-client", lndLspFunc, clnClientFunc) + runTests(t, testCases, "LND-lsp-CLN-client", lndLspFunc, clnClientFunc) runTests(t, lndTestCases, "LND-lsp-LND-client", legacyOnionLndLspFunc, lndClientFunc) runTests(t, testCases, "CLN-lsp-CLN-client", clnLspFunc, clnClientFunc) } diff --git a/itest/notification_test.go b/itest/notification_test.go index d116b214..fbc6fe9f 100644 --- a/itest/notification_test.go +++ b/itest/notification_test.go @@ -137,10 +137,10 @@ func testOfflineNotificationRegularForward(p *testParams) { p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode()) }() + <-time.After(time.Second * 2) url := "http://" + addr + "/api/v1/notify" SubscribeNotifications(p.lsp, p.BreezClient(), url, false) - <-time.After(time.Second * 2) log.Printf("Adding bob's invoice") amountMsat := uint64(2100000) bobInvoice := p.BreezClient().Node().CreateBolt11Invoice(&lntest.CreateInvoiceOptions{ diff --git a/lightning/client.go b/lightning/client.go index 236f0cd3..d8351336 100644 --- a/lightning/client.go +++ b/lightning/client.go @@ -3,12 +3,14 @@ package lightning import ( "time" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" ) type GetInfoResult struct { - Alias string - Pubkey string + ChainHash *chainhash.Hash + Alias string + Pubkey string } type GetChannelResult struct { diff --git a/lightning/short_channel_id.go b/lightning/short_channel_id.go index 45ca07ff..f4ced868 100644 --- a/lightning/short_channel_id.go +++ b/lightning/short_channel_id.go @@ -49,3 +49,7 @@ func (c *ShortChannelID) ToString() string { outputIndex := u & 0xFFFF return fmt.Sprintf("%dx%dx%d", blockHeight, txIndex, outputIndex) } + +func (c *ShortChannelID) ToLnwire() lnwire.ShortChannelID { + return lnwire.NewShortChanIDFromInt(uint64(*c)) +} diff --git a/lnd/client.go b/lnd/client.go index 0006da3f..84d3ce96 100644 --- a/lnd/client.go +++ b/lnd/client.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/breez/lspd/common" "github.com/breez/lspd/config" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -220,8 +221,9 @@ func (c *LndClient) GetInfo() (*lightning.GetInfoResult, error) { } return &lightning.GetInfoResult{ - Alias: info.Alias, - Pubkey: info.IdentityPubkey, + ChainHash: common.MapChainHash(info.Chains[0].Network), + Alias: info.Alias, + Pubkey: info.IdentityPubkey, }, nil } diff --git a/lnd/custom_msg_client.go b/lnd/custom_msg_client.go new file mode 100644 index 00000000..2f038780 --- /dev/null +++ b/lnd/custom_msg_client.go @@ -0,0 +1,149 @@ +package lnd + +import ( + "context" + "encoding/hex" + "fmt" + "log" + "sync" + "time" + + "github.com/breez/lspd/config" + "github.com/breez/lspd/lightning" + "github.com/lightningnetwork/lnd/lnrpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type CustomMsgClient struct { + lightning.CustomMsgClient + client *LndClient + initWg sync.WaitGroup + stopRequested bool + ctx context.Context + cancel context.CancelFunc + recvQueue chan *lightning.CustomMessage +} + +func NewCustomMsgClient(conf *config.ClnConfig, client *LndClient) *CustomMsgClient { + c := &CustomMsgClient{ + client: client, + recvQueue: make(chan *lightning.CustomMessage, 10000), + } + + c.initWg.Add(1) + return c +} + +func (c *CustomMsgClient) Start() error { + ctx, cancel := context.WithCancel(context.Background()) + c.ctx = ctx + c.cancel = cancel + c.stopRequested = false + return c.listen() +} + +func (i *CustomMsgClient) WaitStarted() { + i.initWg.Wait() +} + +func (i *CustomMsgClient) listen() error { + inited := false + + defer func() { + if !inited { + i.initWg.Done() + } + log.Printf("CLN custom msg listen(): stopping.") + }() + + for { + if i.ctx.Err() != nil { + return i.ctx.Err() + } + + log.Printf("Connecting LND custom msg stream.") + msgClient, err := i.client.client.SubscribeCustomMessages( + i.ctx, + &lnrpc.SubscribeCustomMessagesRequest{}, + ) + if err != nil { + log.Printf("client.SubscribeCustomMessages(): %v", err) + <-time.After(time.Second) + continue + } + + for { + if i.ctx.Err() != nil { + return i.ctx.Err() + } + + if !inited { + inited = true + i.initWg.Done() + } + + // Stop receiving if stop if requested. + if i.stopRequested { + return nil + } + + request, err := msgClient.Recv() + if err != nil { + // If it is just the error result of the context cancellation + // the we exit silently. + status, ok := status.FromError(err) + if ok && status.Code() == codes.Canceled { + log.Printf("Got code canceled. Break.") + break + } + + // Otherwise it an unexpected error, we log. + log.Printf("unexpected error in interceptor.Recv() %v", err) + break + } + + i.recvQueue <- &lightning.CustomMessage{ + PeerId: hex.EncodeToString(request.Peer), + Type: request.Type, + Data: request.Data, + } + } + + <-time.After(time.Second) + } +} + +func (c *CustomMsgClient) Recv() (*lightning.CustomMessage, error) { + select { + case msg := <-c.recvQueue: + return msg, nil + case <-c.ctx.Done(): + return nil, c.ctx.Err() + } +} + +func (c *CustomMsgClient) Send(msg *lightning.CustomMessage) error { + peerId, err := hex.DecodeString(msg.PeerId) + if err != nil { + return fmt.Errorf("hex.DecodeString(%s) err: %w", msg.PeerId, err) + } + _, err = c.client.client.SendCustomMessage( + c.ctx, + &lnrpc.SendCustomMessageRequest{ + Peer: peerId, + Type: msg.Type, + Data: msg.Data, + }, + ) + return err +} + +func (i *CustomMsgClient) Stop() error { + // Setting stopRequested to true will make the interceptor stop receiving. + i.stopRequested = true + + // Close the grpc connection. + i.cancel() + return nil +} diff --git a/lnd/interceptor.go b/lnd/interceptor.go index c3eef64b..d689f00c 100644 --- a/lnd/interceptor.go +++ b/lnd/interceptor.go @@ -9,7 +9,6 @@ import ( "github.com/breez/lspd/common" "github.com/breez/lspd/config" - "github.com/breez/lspd/interceptor" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" @@ -24,7 +23,7 @@ import ( type LndHtlcInterceptor struct { fwsync *ForwardingHistorySync - interceptor *interceptor.Interceptor + interceptor common.InterceptHandler config *config.NodeConfig client *LndClient stopRequested bool @@ -38,7 +37,7 @@ func NewLndHtlcInterceptor( conf *config.NodeConfig, client *LndClient, fwsync *ForwardingHistorySync, - interceptor *interceptor.Interceptor, + interceptor common.InterceptHandler, ) (*LndHtlcInterceptor, error) { i := &LndHtlcInterceptor{ config: conf, @@ -150,11 +149,12 @@ func (i *LndHtlcInterceptor) intercept() error { case common.INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(i.createOnionResponse(interceptResult, request)) case common.INTERCEPT_FAIL_HTLC_WITH_CODE: - log.Printf("paymenthash %x, failing htlc with code '%x'", request.PaymentHash, interceptResult.FailureCode) + log.Printf("paymenthash %x, failing htlc with message '%x'", request.PaymentHash, interceptResult.FailureMessage) interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ - IncomingCircuitKey: request.IncomingCircuitKey, - Action: routerrpc.ResolveHoldForwardAction_FAIL, - FailureCode: i.mapFailureCode(interceptResult.FailureCode), + IncomingCircuitKey: request.IncomingCircuitKey, + Action: routerrpc.ResolveHoldForwardAction_FAIL, + FailureMessage: interceptResult.FailureMessage, + FailureMessageUnencrypted: true, }) case common.INTERCEPT_RESUME: fallthrough @@ -176,20 +176,6 @@ func (i *LndHtlcInterceptor) intercept() error { } } -func (i *LndHtlcInterceptor) mapFailureCode(original common.InterceptFailureCode) lnrpc.Failure_FailureCode { - switch original { - case common.FAILURE_TEMPORARY_CHANNEL_FAILURE: - return lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE - case common.FAILURE_TEMPORARY_NODE_FAILURE: - return lnrpc.Failure_TEMPORARY_NODE_FAILURE - case common.FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: - return lnrpc.Failure_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS - default: - log.Printf("Unknown failure code %v, default to temporary channel failure.", original) - return lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE - } -} - func (i *LndHtlcInterceptor) constructOnion( interceptResult common.InterceptResult, reqOutgoingExpiry uint32, diff --git a/lsps2/intercept_handler.go b/lsps2/intercept_handler.go index 4a83cf5e..f5aa96dd 100644 --- a/lsps2/intercept_handler.go +++ b/lsps2/intercept_handler.go @@ -13,10 +13,13 @@ import ( "github.com/breez/lspd/common" "github.com/breez/lspd/lightning" "github.com/breez/lspd/lsps0" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" ) type InterceptorConfig struct { + NodeId []byte + ChainHash chainhash.Hash AdditionalChannelCapacitySat uint64 MinConfs *uint32 TargetConf uint32 @@ -111,7 +114,7 @@ type paymentChanOpenedEvent struct { type paymentFailureEvent struct { paymentId string - code common.InterceptFailureCode + message common.InterceptFailureCode } func (i *Interceptor) Start(ctx context.Context) { @@ -132,7 +135,7 @@ func (i *Interceptor) Start(ctx context.Context) { case paymentId := <-i.paymentReady: i.handlePaymentReady(paymentId) case ev := <-i.paymentFailure: - i.handlePaymentFailure(ev.paymentId, ev.code) + i.handlePaymentFailure(ev.paymentId, ev.message) case ev := <-i.paymentChanOpened: i.handlePaymentChanOpened(ev) } @@ -186,7 +189,7 @@ func (i *Interceptor) handleNewPart(part *partState) { // a goroutine. i.paymentFailure <- &paymentFailureEvent{ paymentId: paymentId, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(nil), } case <-payment.timeoutChan: // Stop listening for timeouts when the payment is ready. @@ -281,7 +284,17 @@ func (i *Interceptor) processPart(payment *paymentState, part *partState) { // Make sure the cltv delta is enough (actual cltv delta + 2). if int64(part.req.IncomingExpiry)-int64(part.req.OutgoingExpiry) < int64(i.config.TimeLockDelta)+2 { - i.failPart(payment, part, common.FAILURE_INCORRECT_CLTV_EXPIRY) + peerid, _ := hex.DecodeString(payment.registration.PeerId) + chanUpdate := common.ConstructChanUpdate( + i.config.ChainHash, + i.config.NodeId, + peerid, + payment.fakeScid, + uint16(i.config.TimeLockDelta), + i.config.HtlcMinimumMsat, + payment.paymentSizeMsat, + ) + i.failPart(payment, part, common.FailureIncorrectCltvExpiry(part.req.IncomingExpiry, chanUpdate)) return } @@ -373,6 +386,16 @@ func (i *Interceptor) handlePaymentReady(paymentId string) { // a goroutine. func (i *Interceptor) ensureChannelOpen(payment *paymentState) { destination, _ := hex.DecodeString(payment.registration.PeerId) + peerid, _ := hex.DecodeString(payment.registration.PeerId) + chanUpdate := common.ConstructChanUpdate( + i.config.ChainHash, + i.config.NodeId, + peerid, + payment.fakeScid, + uint16(i.config.TimeLockDelta), + i.config.HtlcMinimumMsat, + payment.paymentSizeMsat, + ) if payment.registration.ChannelPoint == nil { @@ -389,7 +412,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_UNKNOWN_NEXT_PEER, + message: common.FAILURE_UNKNOWN_NEXT_PEER, } return } @@ -409,7 +432,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_UNKNOWN_NEXT_PEER, + message: common.FAILURE_UNKNOWN_NEXT_PEER, } return } @@ -464,8 +487,8 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) code := common.FAILURE_UNKNOWN_NEXT_PEER - if strings.Contains(err.Error(), "not enough funds") { - code = common.FAILURE_TEMPORARY_CHANNEL_FAILURE + if strings.Contains(err.Error(), "not enough") { + code = common.FailureTemporaryChannelFailure(&chanUpdate) } // TODO: Verify that a client disconnect before receiving @@ -476,7 +499,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { // temporary_channel_failure should be returned. i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: code, + message: code, } return } @@ -499,7 +522,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(&chanUpdate), } return } @@ -521,7 +544,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { case <-time.After(time.Until(deadline)): i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(&chanUpdate), } return } @@ -600,11 +623,22 @@ func (i *Interceptor) handlePaymentChanOpened(event *paymentChanOpenedEvent) { event.paymentId, feeRemainingMsat, ) + + peerid, _ := hex.DecodeString(payment.registration.PeerId) + chanUpdate := common.ConstructChanUpdate( + i.config.ChainHash, + i.config.NodeId, + peerid, + payment.fakeScid, + uint16(i.config.TimeLockDelta), + i.config.HtlcMinimumMsat, + payment.paymentSizeMsat, + ) // TODO: Verify temporary_channel_failure is the way to go here, maybe // unknown_next_peer is more appropriate. i.paymentFailure <- &paymentFailureEvent{ paymentId: event.paymentId, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(&chanUpdate), } return } @@ -620,11 +654,11 @@ func (i *Interceptor) handlePaymentChanOpened(event *paymentChanOpenedEvent) { func (i *Interceptor) handlePaymentFailure( paymentId string, - code common.InterceptFailureCode, + message []byte, ) { i.finalizeAllParts(paymentId, &common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: code, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: message, }) } @@ -660,11 +694,11 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes func (i *Interceptor) failPart( payment *paymentState, part *partState, - code common.InterceptFailureCode, + message []byte, ) { part.resolution <- &common.InterceptResult{ - Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureCode: code, + Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureMessage: message, } delete(payment.parts, part.req.HtlcId()) if len(payment.parts) == 0 { diff --git a/lsps2/intercept_test.go b/lsps2/intercept_test.go index 4f136193..f5a800cf 100644 --- a/lsps2/intercept_test.go +++ b/lsps2/intercept_test.go @@ -284,7 +284,7 @@ func Test_NoMpp_AmtBelowMinimum(t *testing.T) { res := i.Intercept(createPart(&part{amt: defaultMinViableAmount - 1})) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureMessage) assertEmpty(t, i) } @@ -309,7 +309,7 @@ func Test_NoMpp_AmtAboveMaximum(t *testing.T) { res := i.Intercept(createPart(&part{amt: defaultConfig().MaxPaymentSizeMsat + 1})) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureMessage) assertEmpty(t, i) } @@ -322,7 +322,7 @@ func Test_NoMpp_CltvDeltaBelowMinimum(t *testing.T) { res := i.Intercept(createPart(&part{cltvDelta: 145})) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureMessage[:2]) assertEmpty(t, i) } @@ -350,7 +350,7 @@ func Test_NoMpp_ParamsExpired(t *testing.T) { res := i.Intercept(createPart(nil)) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureMessage) assertEmpty(t, i) } @@ -376,7 +376,7 @@ func Test_NoMpp_ChannelAlreadyOpened_Complete_Fails(t *testing.T) { res := i.Intercept(createPart(nil)) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureMessage) assertEmpty(t, i) } @@ -408,7 +408,7 @@ func Test_Mpp_SinglePart_AmtTooSmall(t *testing.T) { res := i.Intercept(createPart(&part{amt: defaultPaymentSizeMsat - 1})) end := time.Now() assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_TEMPORARY_CHANNEL_FAILURE, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_TEMPORARY_CHANNEL_FAILURE, res.FailureMessage[:2]) assert.GreaterOrEqual(t, end.Sub(start).Milliseconds(), config.MppTimeout.Milliseconds()) assertEmpty(t, i) } @@ -519,7 +519,7 @@ func Test_Mpp_BadSecondPart_ThirdPartCompletes(t *testing.T) { assert.Equal(t, defaultFee, *res1.FeeMsat) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res2.Action) - assert.Equal(t, common.FAILURE_AMOUNT_BELOW_MINIMUM, res2.FailureCode) + assert.ElementsMatch(t, common.FAILURE_AMOUNT_BELOW_MINIMUM, res2.FailureMessage) assert.Equal(t, common.INTERCEPT_RESUME_WITH_ONION, res3.Action) assert.Equal(t, defaultConfig().HtlcMinimumMsat, res3.AmountMsat) @@ -541,7 +541,7 @@ func Test_Mpp_CltvDeltaBelowMinimum(t *testing.T) { res := i.Intercept(createPart(&part{cltvDelta: 145})) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureMessage[:2]) assertEmpty(t, i) } @@ -569,7 +569,7 @@ func Test_Mpp_ParamsExpired(t *testing.T) { res := i.Intercept(createPart(nil)) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureMessage) assertEmpty(t, i) } @@ -605,9 +605,9 @@ func Test_Mpp_ParamsExpireInFlight(t *testing.T) { wg.Wait() assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res1.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res1.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res1.FailureMessage) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res2.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res2.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res2.FailureMessage) assertEmpty(t, i) } @@ -698,7 +698,7 @@ func Test_Mpp_ChannelAlreadyOpened_Complete_Fails(t *testing.T) { res := i.Intercept(createPart(nil)) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.Equal(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assert.ElementsMatch(t, common.FAILURE_UNKNOWN_NEXT_PEER, res.FailureMessage) assertEmpty(t, i) } diff --git a/main.go b/main.go index d0bfa7e8..519a85df 100644 --- a/main.go +++ b/main.go @@ -111,6 +111,25 @@ func main() { var interceptors []interceptor.HtlcInterceptor for _, node := range nodes { var htlcInterceptor interceptor.HtlcInterceptor + lsps2Config := &lsps2.InterceptorConfig{ + NodeId: node.NodeId, + ChainHash: node.ChainHash, + AdditionalChannelCapacitySat: uint64(node.NodeConfig.AdditionalChannelCapacity), + MinConfs: node.NodeConfig.MinConfs, + TargetConf: node.NodeConfig.TargetConf, + FeeStrategy: feeStrategy, + MinPaymentSizeMsat: node.NodeConfig.MinPaymentSizeMsat, + MaxPaymentSizeMsat: node.NodeConfig.MaxPaymentSizeMsat, + TimeLockDelta: node.NodeConfig.TimeLockDelta, + HtlcMinimumMsat: node.NodeConfig.MinHtlcMsat, + MppTimeout: time.Second * 90, + } + msgServer := lsps0.NewServer() + protocolServer := lsps0.NewProtocolServer([]uint32{2}) + lsps2Server := lsps2.NewLsps2Server(openingService, nodesService, node, lsps2Store) + lsps0.RegisterProtocolServer(msgServer, protocolServer) + lsps2.RegisterLsps2Server(msgServer, lsps2Server) + if node.NodeConfig.Lnd != nil { client, err := lnd.NewLndClient(node.NodeConfig.Lnd) if err != nil { @@ -119,11 +138,20 @@ func main() { client.StartListeners() fwsync := lnd.NewForwardingHistorySync(client, interceptStore, forwardingStore) - interceptor := interceptor.NewInterceptHandler(client, node.NodeConfig, interceptStore, openingService, feeEstimator, feeStrategy, notificationService) - htlcInterceptor, err = lnd.NewLndHtlcInterceptor(node.NodeConfig, client, fwsync, interceptor) + legacyHandler := interceptor.NewInterceptHandler(client, node, interceptStore, openingService, feeEstimator, feeStrategy, notificationService) + lsps2Handler := lsps2.NewInterceptHandler(lsps2Store, openingService, client, feeEstimator, lsps2Config) + go lsps2Handler.Start(ctx) + combinedHandler := common.NewCombinedHandler(lsps2Handler, legacyHandler) + htlcInterceptor, err = lnd.NewLndHtlcInterceptor(node.NodeConfig, client, fwsync, combinedHandler) if err != nil { log.Fatalf("failed to initialize LND interceptor: %v", err) } + + msgClient := lnd.NewCustomMsgClient(node.NodeConfig.Cln, client) + go msgClient.Start() + msgClient.WaitStarted() + defer msgClient.Stop() + go msgServer.Serve(msgClient) } if node.NodeConfig.Cln != nil { @@ -132,18 +160,8 @@ func main() { log.Fatalf("failed to initialize CLN client: %v", err) } - legacyHandler := interceptor.NewInterceptHandler(client, node.NodeConfig, interceptStore, openingService, feeEstimator, feeStrategy, notificationService) - lsps2Handler := lsps2.NewInterceptHandler(lsps2Store, openingService, client, feeEstimator, &lsps2.InterceptorConfig{ - AdditionalChannelCapacitySat: uint64(node.NodeConfig.AdditionalChannelCapacity), - MinConfs: node.NodeConfig.MinConfs, - TargetConf: node.NodeConfig.TargetConf, - FeeStrategy: feeStrategy, - MinPaymentSizeMsat: node.NodeConfig.MinPaymentSizeMsat, - MaxPaymentSizeMsat: node.NodeConfig.MaxPaymentSizeMsat, - TimeLockDelta: node.NodeConfig.TimeLockDelta, - HtlcMinimumMsat: node.NodeConfig.MinHtlcMsat, - MppTimeout: time.Second * 90, - }) + legacyHandler := interceptor.NewInterceptHandler(client, node, interceptStore, openingService, feeEstimator, feeStrategy, notificationService) + lsps2Handler := lsps2.NewInterceptHandler(lsps2Store, openingService, client, feeEstimator, lsps2Config) go lsps2Handler.Start(ctx) combinedHandler := common.NewCombinedHandler(lsps2Handler, legacyHandler) htlcInterceptor, err = cln.NewClnHtlcInterceptor(node.NodeConfig, client, combinedHandler) @@ -153,11 +171,6 @@ func main() { msgClient := cln.NewCustomMsgClient(node.NodeConfig.Cln, client) go msgClient.Start() - msgServer := lsps0.NewServer() - protocolServer := lsps0.NewProtocolServer([]uint32{2}) - lsps2Server := lsps2.NewLsps2Server(openingService, nodesService, node, lsps2Store) - lsps0.RegisterProtocolServer(msgServer, protocolServer) - lsps2.RegisterLsps2Server(msgServer, lsps2Server) msgClient.WaitStarted() defer msgClient.Stop() go msgServer.Serve(msgClient) @@ -297,7 +310,13 @@ func initializeNodes(configs []*config.NodeConfig) ([]*common.Node, error) { node.NodeConfig.NodePubkey = info.Pubkey } + nodeid, err := hex.DecodeString(info.Pubkey) + if err != nil { + return nil, fmt.Errorf("failed to decode node id '%s'", info.Pubkey) + } + node.ChainHash = *info.ChainHash node.Tokens = config.Tokens + node.NodeId = nodeid nodes = append(nodes, node) }