diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index b1e3d86520..f684234199 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -66,6 +66,12 @@ database](https://github.com/lightningnetwork/lnd/pull/9147) * Implement query methods (QueryPayments,FetchPayment) for the [payments db SQL Backend](https://github.com/lightningnetwork/lnd/pull/10287) + * Implement insert methods for the [payments db + SQL Backend](https://github.com/lightningnetwork/lnd/pull/10291) + * Implement third(final) Part of SQL backend [payment + functions](https://github.com/lightningnetwork/lnd/pull/10368) + * Finalize SQL payments implementation [enabling unit and itests + for SQL backend](https://github.com/lightningnetwork/lnd/pull/10292) ## Code Health diff --git a/itest/lnd_payment_test.go b/itest/lnd_payment_test.go index 37aff05226..f683cd44a8 100644 --- a/itest/lnd_payment_test.go +++ b/itest/lnd_payment_test.go @@ -504,61 +504,86 @@ func testListPayments(ht *lntest.HarnessTest) { expected bool } - // Create test cases to check the timestamp filters. - createCases := func(createTimeSeconds uint64) []testCase { + // Create test cases with proper rounding for start and end dates. + createCases := func(startTimeSeconds, + endTimeSeconds uint64) []testCase { + return []testCase{ { // Use a start date same as the creation date - // should return us the item. + // (truncated) should return us the item. name: "exact start date", - startDate: createTimeSeconds, + startDate: startTimeSeconds, expected: true, }, { // Use an earlier start date should return us // the item. name: "earlier start date", - startDate: createTimeSeconds - 1, + startDate: startTimeSeconds - 1, expected: true, }, { // Use a future start date should return us // nothing. name: "future start date", - startDate: createTimeSeconds + 1, + startDate: startTimeSeconds + 1, expected: false, }, { // Use an end date same as the creation date - // should return us the item. + // (ceiling) should return us the item. name: "exact end date", - endDate: createTimeSeconds, + endDate: endTimeSeconds, expected: true, }, { // Use an end date in the future should return // us the item. name: "future end date", - endDate: createTimeSeconds + 1, + endDate: endTimeSeconds + 1, expected: true, }, { // Use an earlier end date should return us // nothing. - name: "earlier end date", - endDate: createTimeSeconds - 1, + name: "earlier end date", + // The native sql backend has a higher + // precision than the kv backend, the native sql + // backend uses microseconds, the kv backend + // when filtering uses seconds so we need to + // subtract 2 seconds to ensure the payment is + // not included. + // We could also truncate before inserting + // into the sql db but I rather relax this test + // here. + endDate: endTimeSeconds - 2, expected: false, }, } } - // Get the payment creation time in seconds. - paymentCreateSeconds := uint64( - p.CreationTimeNs / time.Second.Nanoseconds(), + // Get the payment creation time in seconds, using different approaches + // for start and end date comparisons to avoid rounding issues. + creationTime := time.Unix(0, p.CreationTimeNs) + + // For start date comparisons: use truncation (floor) to include + // payments from the beginning of that second. + paymentCreateSecondsStart := uint64( + creationTime.Truncate(time.Second).Unix(), + ) + + // For end date comparisons: use ceiling to include payments up to the + // end of that second. + paymentCreateSecondsEnd := uint64( + (p.CreationTimeNs + time.Second.Nanoseconds() - 1) / + time.Second.Nanoseconds(), ) // Create test cases from the payment creation time. - testCases := createCases(paymentCreateSeconds) + testCases := createCases( + paymentCreateSecondsStart, paymentCreateSecondsEnd, + ) // We now check the timestamp filters in `ListPayments`. for _, tc := range testCases { @@ -578,7 +603,9 @@ func testListPayments(ht *lntest.HarnessTest) { } // Create test cases from the invoice creation time. - testCases = createCases(uint64(invoice.CreationDate)) + testCases = createCases( + uint64(invoice.CreationDate), uint64(invoice.CreationDate), + ) // We now do the same check for `ListInvoices`. for _, tc := range testCases { diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index d19127f32b..bb2f0800ca 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1762,6 +1762,8 @@ func (r *RouterBackend) MarshallPayment(payment *paymentsdb.MPPayment) ( // If any of the htlcs have settled, extract a valid // preimage. if htlc.Settle != nil { + // For AMP payments all hashes will be different so we + // will depict the last htlc preimage. preimage = htlc.Settle.Preimage fee += htlc.Route.TotalFees() } diff --git a/payments/db/errors.go b/payments/db/errors.go index fee71b05f5..ad4a71b211 100644 --- a/payments/db/errors.go +++ b/payments/db/errors.go @@ -83,6 +83,11 @@ var ( // paths or vice versa. ErrMixedBlindedAndNonBlindedPayments = errors.New("mixed blinded and " + "non-blinded payments") + // ErrBlindedPaymentMissingTotalAmount is returned if we try to + // register a blinded payment attempt where the final hop doesn't set + // the total amount. + ErrBlindedPaymentMissingTotalAmount = errors.New("blinded payment " + + "final hop must set total amount") // ErrMPPPaymentAddrMismatch is returned if we try to register an MPP // shard where the payment address doesn't match existing shards. diff --git a/payments/db/interface.go b/payments/db/interface.go index c41dc371f8..7fefad0891 100644 --- a/payments/db/interface.go +++ b/payments/db/interface.go @@ -61,6 +61,17 @@ type PaymentControl interface { InitPayment(lntypes.Hash, *PaymentCreationInfo) error // RegisterAttempt atomically records the provided HTLCAttemptInfo. + // + // IMPORTANT: Callers MUST serialize calls to RegisterAttempt for the + // same payment hash. Concurrent calls will result in race conditions + // where both calls read the same initial payment state, validate + // against stale data, and could cause overpayment. For example: + // - Both goroutines fetch payment with 400 sats sent + // - Both validate sending 650 sats won't overpay (within limit) + // - Both commit successfully + // - Result: 1700 sats sent, exceeding the payment amount + // The payment router/controller layer is responsible for ensuring + // serialized access per payment hash. RegisterAttempt(lntypes.Hash, *HTLCAttemptInfo) (*MPPayment, error) // SettleAttempt marks the given attempt settled with the preimage. If diff --git a/payments/db/kv_store.go b/payments/db/kv_store.go index 62f0b83867..84946841b9 100644 --- a/payments/db/kv_store.go +++ b/payments/db/kv_store.go @@ -291,6 +291,8 @@ func (p *KVStore) InitPayment(paymentHash lntypes.Hash, // DeleteFailedAttempts deletes all failed htlcs for a payment if configured // by the KVStore db. func (p *KVStore) DeleteFailedAttempts(hash lntypes.Hash) error { + // TODO(ziggie): Refactor to not mix application logic with database + // logic. This decision should be made in the application layer. if !p.keepFailedPaymentAttempts { const failedHtlcsOnly = true err := p.DeletePayment(hash, failedHtlcsOnly) diff --git a/payments/db/kv_store_test.go b/payments/db/kv_store_test.go index 2c2895175a..f0c2b148fd 100644 --- a/payments/db/kv_store_test.go +++ b/payments/db/kv_store_test.go @@ -1,7 +1,10 @@ +//go:build !test_db_sqlite && !test_db_postgres + package paymentsdb import ( "bytes" + "crypto/sha256" "encoding/binary" "io" "math" @@ -17,7 +20,6 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -65,10 +67,15 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { var numSuccess, numInflight int for _, p := range payments { - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash( + t, 0, genSessionKey(t), rhash, + ) + require.NoError(t, err) // Sends base htlc message which initiate StatusInFlight. err = paymentDB.InitPayment(info.PaymentIdentifier, info) @@ -246,83 +253,6 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { require.Equal(t, 1, indexCount) } -type htlcStatus struct { - *HTLCAttemptInfo - settle *lntypes.Preimage - failure *HTLCFailReason -} - -// fetchPaymentIndexEntry gets the payment hash for the sequence number provided -// from our payment indexes bucket. -func fetchPaymentIndexEntry(t *testing.T, p *KVStore, - sequenceNumber uint64) (*lntypes.Hash, error) { - - t.Helper() - - var hash lntypes.Hash - - if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { - indexBucket := tx.ReadBucket(paymentsIndexBucket) - key := make([]byte, 8) - byteOrder.PutUint64(key, sequenceNumber) - - indexValue := indexBucket.Get(key) - if indexValue == nil { - return ErrNoSequenceNrIndex - } - - r := bytes.NewReader(indexValue) - - var err error - hash, err = deserializePaymentIndex(r) - - return err - }, func() { - hash = lntypes.Hash{} - }); err != nil { - return nil, err - } - - return &hash, nil -} - -// assertPaymentIndex looks up the index for a payment in the db and checks -// that its payment hash matches the expected hash passed in. -func assertPaymentIndex(t *testing.T, p DB, expectedHash lntypes.Hash) { - t.Helper() - - // Only the kv implementation uses the index so we exit early if the - // payment db is not a kv implementation. This helps us to reuse the - // same test for both implementations. - kvPaymentDB, ok := p.(*KVStore) - if !ok { - return - } - - // Lookup the payment so that we have its sequence number and check - // that is has correctly been indexed in the payment indexes bucket. - pmt, err := kvPaymentDB.FetchPayment(expectedHash) - require.NoError(t, err) - - hash, err := fetchPaymentIndexEntry(t, kvPaymentDB, pmt.SequenceNum) - require.NoError(t, err) - assert.Equal(t, expectedHash, *hash) -} - -// assertNoIndex checks that an index for the sequence number provided does not -// exist. -func assertNoIndex(t *testing.T, p DB, seqNr uint64) { - t.Helper() - - kvPaymentDB, ok := p.(*KVStore) - if !ok { - return - } - - _, err := fetchPaymentIndexEntry(t, kvPaymentDB, seqNr) - require.Equal(t, ErrNoSequenceNrIndex, err) -} - func makeFakeInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo) { @@ -478,7 +408,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { paymentDB := NewKVTestDB(t) // Generate a test payment which does not have duplicates. - noDuplicates, _, _, err := genInfo(t) + noDuplicates, _, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -494,7 +424,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { require.NoError(t, err) // Generate a test payment which we will add duplicates to. - hasDuplicates, _, preimg, err := genInfo(t) + hasDuplicates, preimg, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -652,7 +582,7 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) // Generate fake information for the duplicate payment. - info, _, _, err := genInfo(t) + info, _, err := genInfo(t) require.NoError(t, err) // Write the payment info to disk under the creation info key. This code @@ -684,17 +614,19 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) } -// TestQueryPayments tests retrieval of payments with forwards and reversed -// queries. -// -// TODO(ziggie): Make this test db agnostic. -func TestQueryPayments(t *testing.T) { - // Define table driven test for QueryPayments. +// TestKVStoreQueryPaymentsDuplicates tests the KV store's legacy duplicate +// payment handling. This tests the specific case where duplicate payments +// are stored in a nested bucket within the parent payment bucket. +func TestKVStoreQueryPaymentsDuplicates(t *testing.T) { + t.Parallel() + // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. // Note that the payment with index 7 has the same payment hash as 6, // and is stored in a nested bucket within payment 6 rather than being - // its own entry in the payments bucket. We do this to test retrieval - // of legacy payments. + // its own entry in the payments bucket. This tests retrieval of legacy + // duplicate payments which is KV-store specific. + // These test cases focus on validating that duplicate payments (seq 7, + // nested under payment 6) are correctly returned in queries. tests := []struct { name string query Query @@ -706,31 +638,20 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs []uint64 }{ { - name: "IndexOffset at the end of the payments range", - query: Query{ - IndexOffset: 7, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "query in forwards order, start at beginning", + name: "query includes duplicate payment in forward " + + "order", query: Query{ - IndexOffset: 0, - MaxPayments: 2, + IndexOffset: 5, + MaxPayments: 3, Reversed: false, IncludeIncomplete: true, }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, + firstIndex: 6, + lastIndex: 7, + expectedSeqNrs: []uint64{6, 7}, }, { - name: "query in forwards order, start at end, overflow", + name: "query duplicate payment at end", query: Query{ IndexOffset: 6, MaxPayments: 2, @@ -742,44 +663,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{7}, }, { - name: "start at offset index outside of payments", - query: Query{ - IndexOffset: 20, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "overflow in forwards order", - query: Query{ - IndexOffset: 4, - MaxPayments: math.MaxUint64, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 5, - lastIndex: 7, - expectedSeqNrs: []uint64{5, 6, 7}, - }, - { - name: "start at offset index outside of payments, " + - "reversed order", - query: Query{ - IndexOffset: 9, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 6, - lastIndex: 7, - expectedSeqNrs: []uint64{6, 7}, - }, - { - name: "query in reverse order, start at end", + name: "query includes duplicate in reverse order", query: Query{ IndexOffset: 0, MaxPayments: 2, @@ -791,36 +675,11 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{6, 7}, }, { - name: "query in reverse order, starting in middle", - query: Query{ - IndexOffset: 4, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "query in reverse order, starting in middle, " + - "with underflow", - query: Query{ - IndexOffset: 4, - MaxPayments: 5, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "all payments in reverse, order maintained", + name: "query all payments includes duplicate", query: Query{ IndexOffset: 0, - MaxPayments: 7, - Reversed: true, + MaxPayments: math.MaxUint64, + Reversed: false, IncludeIncomplete: true, }, firstIndex: 1, @@ -828,7 +687,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{1, 3, 4, 5, 6, 7}, }, { - name: "exclude incomplete payments", + name: "exclude incomplete includes duplicate", query: Query{ IndexOffset: 0, MaxPayments: 7, @@ -839,96 +698,6 @@ func TestQueryPayments(t *testing.T) { lastIndex: 7, expectedSeqNrs: []uint64{7}, }, - { - name: "query payments at index gap", - query: Query{ - IndexOffset: 1, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 7, - expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "query payments reverse before index gap", - query: Query{ - IndexOffset: 3, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments reverse on index gap", - query: Query{ - IndexOffset: 2, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments forward on index gap", - query: Query{ - IndexOffset: 2, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 4, - expectedSeqNrs: []uint64{3, 4}, - }, - { - name: "query in forwards order, with start creation " + - "time", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 5, - }, - firstIndex: 5, - lastIndex: 6, - expectedSeqNrs: []uint64{5, 6}, - }, - { - name: "query in forwards order, with start creation " + - "time at end, overflow", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 7, - }, - firstIndex: 7, - lastIndex: 7, - expectedSeqNrs: []uint64{7}, - }, - { - name: "query with start and end creation time", - query: Query{ - IndexOffset: 9, - MaxPayments: math.MaxUint64, - Reversed: true, - IncludeIncomplete: true, - CreationDateStart: 3, - CreationDateEnd: 5, - }, - firstIndex: 3, - lastIndex: 5, - expectedSeqNrs: []uint64{3, 4, 5}, - }, } for _, tt := range tests { @@ -960,7 +729,7 @@ func TestQueryPayments(t *testing.T) { for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. - info, _, preimg, err := genInfo(t) + info, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to create test "+ "payment: %v", err) diff --git a/payments/db/payment.go b/payments/db/payment.go index 147ccdb1e7..ddceedfb0f 100644 --- a/payments/db/payment.go +++ b/payments/db/payment.go @@ -744,6 +744,13 @@ func verifyAttempt(payment *MPPayment, attempt *HTLCAttemptInfo) error { // in the split payment is correct. isBlinded := len(attempt.Route.FinalHop().EncryptedData) != 0 + // For blinded payments, the last hop must set the total amount. + if isBlinded { + if attempt.Route.FinalHop().TotalAmtMsat == 0 { + return ErrBlindedPaymentMissingTotalAmount + } + } + // Make sure any existing shards match the new one with regards // to MPP options. mpp := attempt.Route.FinalHop().MPP diff --git a/payments/db/payment_test.go b/payments/db/payment_test.go index a7369c14b8..55dabe69a0 100644 --- a/payments/db/payment_test.go +++ b/payments/db/payment_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math" "reflect" "testing" "time" @@ -58,7 +59,10 @@ var ( ChannelID: 12345, OutgoingTimeLock: 111, AmtToForward: 555, - LegacyPayload: true, + + // Only tlv payloads are now supported in LND therefore we set + // LegacyPayload to false. + LegacyPayload: false, } testRoute = route.Route{ @@ -99,6 +103,14 @@ var ( } ) +// htlcStatus is a helper structure used in tests to track the status of an HTLC +// attempt, including whether it was settled or failed. +type htlcStatus struct { + *HTLCAttemptInfo + settle *lntypes.Preimage + failure *HTLCFailReason +} + // payment is a helper structure that holds basic information on a test payment, // such as the payment id, the status and the total number of HTLCs attempted. type payment struct { @@ -116,13 +128,20 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { attemptID := uint64(0) for i := 0; i < len(payments); i++ { - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Set the payment id accordingly in the payments slice. payments[i].id = info.PaymentIdentifier - attempt.AttemptID = attemptID + attempt, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + attemptID++ // Init the payment. @@ -148,7 +167,10 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { // Depending on the test case, fail or succeed the next // attempt. - attempt.AttemptID = attemptID + attempt, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) attemptID++ _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) @@ -334,7 +356,7 @@ func assertDBPayments(t *testing.T, paymentDB DB, payments []*payment) { } // genPreimage generates a random preimage. -func genPreimage(t *testing.T) ([32]byte, error) { +func genPreimage(t *testing.T) (lntypes.Preimage, error) { t.Helper() var preimage [32]byte @@ -345,31 +367,85 @@ func genPreimage(t *testing.T) ([32]byte, error) { return preimage, nil } -// genInfo generates a payment creation info, an attempt info and a preimage. -func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, - lntypes.Preimage, error) { +// genSessionKey generates a new random private key for use as a session key. +func genSessionKey(t *testing.T) *btcec.PrivateKey { + t.Helper() - preimage, err := genPreimage(t) - if err != nil { - return nil, nil, preimage, fmt.Errorf("unable to "+ - "generate preimage: %v", err) + key, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return key +} + +// genPaymentCreationInfo generates a payment creation info. +func genPaymentCreationInfo(t *testing.T, + paymentHash lntypes.Hash) *PaymentCreationInfo { + + t.Helper() + + // Add constant first hop custom records for testing for testing + // purposes. + firstHopCustomRecords := lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 1: []byte("test_record_1"), + lnwire.MinCustomRecordsTlvType + 2: []byte("test_record_2"), + lnwire.MinCustomRecordsTlvType + 3: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, + }, } + return &PaymentCreationInfo{ + PaymentIdentifier: paymentHash, + Value: testRoute.ReceiverAmt(), + CreationTime: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + FirstHopCustomRecords: firstHopCustomRecords, + } +} + +// genPreimageAndHash generates a random preimage and its corresponding hash. +func genPreimageAndHash(t *testing.T) (lntypes.Preimage, lntypes.Hash, error) { + t.Helper() + + preimage, err := genPreimage(t) + require.NoError(t, err) + rhash := sha256.Sum256(preimage[:]) var hash lntypes.Hash copy(hash[:], rhash[:]) + return preimage, hash, nil +} + +// genAttemptWithPreimage generates an HTLC attempt and returns both the +// attempt and preimage. +func genAttemptWithHash(t *testing.T, attemptID uint64, + sessionKey *btcec.PrivateKey, hash lntypes.Hash) (*HTLCAttemptInfo, + error) { + + t.Helper() + attempt, err := NewHtlcAttempt( - 0, priv, *testRoute.Copy(), time.Time{}, &hash, + attemptID, sessionKey, *testRoute.Copy(), time.Time{}, + &hash, ) - require.NoError(t, err) + if err != nil { + return nil, err + } - return &PaymentCreationInfo{ - PaymentIdentifier: rhash, - Value: testRoute.ReceiverAmt(), - CreationTime: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("hola"), - }, &attempt.HTLCAttemptInfo, preimage, nil + return &attempt.HTLCAttemptInfo, nil +} + +// genInfo generates a payment creation info and the corresponding preimage. +func genInfo(t *testing.T) (*PaymentCreationInfo, lntypes.Preimage, error) { + preimage, _, err := genPreimageAndHash(t) + if err != nil { + return nil, preimage, err + } + + rhash := sha256.Sum256(preimage[:]) + creationInfo := genPaymentCreationInfo(t, rhash) + + return creationInfo, preimage, nil } // TestDeleteFailedAttempts checks that DeleteFailedAttempts properly removes @@ -388,7 +464,7 @@ func TestDeleteFailedAttempts(t *testing.T) { // testDeleteFailedAttempts tests the DeleteFailedAttempts method with the // given keepFailedPaymentAttempts flag as argument. func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { - paymentDB := NewTestDB( + paymentDB, _ := NewTestDB( t, WithKeepFailedPaymentAttempts(keepFailedPaymentAttempts), ) @@ -479,9 +555,19 @@ func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { func TestMPPRecordValidation(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, _ := NewTestDB(t) - info, attempt, _, err := genInfo(t) + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + + attemptID := uint64(0) + + attempt, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) require.NoError(t, err, "unable to generate htlc message") // Init the payment. @@ -502,29 +588,45 @@ func TestMPPRecordValidation(t *testing.T) { require.NoError(t, err, "unable to send htlc message") // Now try to register a non-MPP attempt, which should fail. - b := *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = nil - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + attemptID++ + attempt2, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + attempt2.Route.FinalHop().MPP = nil + + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrMPPayment) // Try to register attempt one with a different payment address. - b.Route.FinalHop().MPP = record.NewMPP( + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{2}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrMPPPaymentAddrMismatch) // Try registering one with a different total amount. - b.Route.FinalHop().MPP = record.NewMPP( + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value/2, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrMPPTotalAmountMismatch) // Create and init a new payment. This time we'll check that we cannot // register an MPP attempt if we already registered a non-MPP one. - info, attempt, _, err = genInfo(t) + preimg, err = genPreimage(t) + require.NoError(t, err) + + rhash = sha256.Sum256(preimg[:]) + info = genPaymentCreationInfo(t, rhash) + + attemptID++ + attempt, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + require.NoError(t, err, "unable to generate htlc message") err = paymentDB.InitPayment(info.PaymentIdentifier, info) @@ -535,13 +637,17 @@ func TestMPPRecordValidation(t *testing.T) { require.NoError(t, err, "unable to send htlc message") // Attempt to register an MPP attempt, which should fail. - b = *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = record.NewMPP( + attemptID++ + attempt2, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrNonMPPayment) } @@ -550,7 +656,7 @@ func TestMPPRecordValidation(t *testing.T) { func TestDeleteSinglePayment(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, _ := NewTestDB(t) // Register four payments: // All payments will have one failed HTLC attempt and one HTLC attempt @@ -1454,10 +1560,13 @@ func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { func TestSuccessesWithoutInFlight(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, _ := NewTestDB(t) - info, _, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Attempt to complete the payment should fail. _, err = paymentDB.SettleAttempt( @@ -1474,10 +1583,13 @@ func TestSuccessesWithoutInFlight(t *testing.T) { func TestFailsWithoutInFlight(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, _ := NewTestDB(t) - info, _, _, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Calling Fail should return an error. _, err = paymentDB.Fail( @@ -1491,7 +1603,7 @@ func TestFailsWithoutInFlight(t *testing.T) { func TestDeletePayments(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, _ := NewTestDB(t) // Register three payments: // 1. A payment with two failed attempts. @@ -1549,17 +1661,22 @@ func TestDeletePayments(t *testing.T) { func TestSwitchDoubleSend(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, harness := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash(t, 0, genSessionKey(t), rhash) + require.NoError(t, err) // Sends base htlc message which initiate base status and move it to // StatusInFlight and verifies that it was changed. err = paymentDB.InitPayment(info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + harness.AssertPaymentIndex(t, info.PaymentIdentifier) assertDBPaymentstatus( t, paymentDB, info.PaymentIdentifier, StatusInitiated, ) @@ -1622,16 +1739,21 @@ func TestSwitchDoubleSend(t *testing.T) { func TestSwitchFail(t *testing.T) { t.Parallel() - paymentDB := NewTestDB(t) + paymentDB, harness := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash(t, 0, genSessionKey(t), rhash) + require.NoError(t, err) // Sends base htlc message which initiate StatusInFlight. err = paymentDB.InitPayment(info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + harness.AssertPaymentIndex(t, info.PaymentIdentifier) assertDBPaymentstatus( t, paymentDB, info.PaymentIdentifier, StatusInitiated, ) @@ -1665,8 +1787,8 @@ func TestSwitchFail(t *testing.T) { // Check that our index has been updated, and the old index has been // removed. - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) - assertNoIndex(t, paymentDB, payment.SequenceNum) + harness.AssertPaymentIndex(t, info.PaymentIdentifier) + harness.AssertNoIndex(t, payment.SequenceNum) assertDBPaymentstatus( t, paymentDB, info.PaymentIdentifier, StatusInitiated, @@ -1703,7 +1825,11 @@ func TestSwitchFail(t *testing.T) { assertPaymentInfo(t, paymentDB, info.PaymentIdentifier, info, nil, htlc) // Record another attempt. - attempt.AttemptID = 1 + attempt, err = genAttemptWithHash( + t, 1, genSessionKey(t), rhash, + ) + require.NoError(t, err) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") assertDBPaymentstatus( @@ -1779,20 +1905,19 @@ func TestMultiShard(t *testing.T) { } runSubTest := func(t *testing.T, test testCase) { - paymentDB := NewTestDB(t) + paymentDB, harness := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Init the payment, moving it to the StatusInFlight state. err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } + require.NoError(t, err) - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + harness.AssertPaymentIndex(t, info.PaymentIdentifier) assertDBPaymentstatus( t, paymentDB, info.PaymentIdentifier, StatusInitiated, ) @@ -1805,19 +1930,23 @@ func TestMultiShard(t *testing.T) { // attempts's value to one third of the payment amount, and // populate the MPP options. shardAmt := info.Value / 3 - attempt.Route.FinalHop().AmtToForward = shardAmt - attempt.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{1}, - ) var attempts []*HTLCAttemptInfo for i := uint64(0); i < 3; i++ { - a := *attempt - a.AttemptID = i - attempts = append(attempts, &a) + a, err := genAttemptWithHash( + t, i, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + a.Route.FinalHop().AmtToForward = shardAmt + a.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + attempts = append(attempts, a) _, err = paymentDB.RegisterAttempt( - info.PaymentIdentifier, &a, + info.PaymentIdentifier, a, ) if err != nil { t.Fatalf("unable to send htlc message: %v", err) @@ -1828,7 +1957,7 @@ func TestMultiShard(t *testing.T) { ) htlc := &htlcStatus{ - HTLCAttemptInfo: &a, + HTLCAttemptInfo: a, } assertPaymentInfo( t, paymentDB, info.PaymentIdentifier, info, nil, @@ -1839,9 +1968,17 @@ func TestMultiShard(t *testing.T) { // For a fourth attempt, check that attempting to // register it will fail since the total sent amount // will be too large. - b := *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + b, err := genAttemptWithHash( + t, 3, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + b.Route.FinalHop().AmtToForward = shardAmt + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, b) require.ErrorIs(t, err, ErrValueExceedsAmt) // Fail the second attempt. @@ -1938,9 +2075,17 @@ func TestMultiShard(t *testing.T) { // Try to register yet another attempt. This should fail now // that the payment has reached a terminal condition. - b = *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + b, err = genAttemptWithHash( + t, 3, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + b.Route.FinalHop().AmtToForward = shardAmt + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, b) if test.settleFirst { require.ErrorIs( t, err, ErrPaymentPendingSettled, @@ -2039,8 +2184,8 @@ func TestMultiShard(t *testing.T) { ) // Finally assert we cannot register more attempts. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.Equal(t, registerErr, err) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, b) + require.ErrorIs(t, err, registerErr) } for _, test := range tests { @@ -2052,3 +2197,477 @@ func TestMultiShard(t *testing.T) { }) } } + +// TestQueryPayments tests retrieval of payments with forwards and reversed +// queries. +func TestQueryPayments(t *testing.T) { + // Define table driven test for QueryPayments. + // Test payments have sequence indices [1, 3, 4, 5, 6]. + // Note that payment with index 2 is deleted to create a gap in the + // sequence numbers. + tests := []struct { + name string + query Query + firstIndex uint64 + lastIndex uint64 + + // expectedSeqNrs contains the set of sequence numbers we expect + // our query to return. + expectedSeqNrs []uint64 + }{ + { + name: "IndexOffset at the end of the payments range", + query: Query{ + IndexOffset: 6, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "query in forwards order, start at beginning", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in forwards order, start at end, overflow", + query: Query{ + IndexOffset: 5, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "start at offset index outside of payments", + query: Query{ + IndexOffset: 20, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "overflow in forwards order", + query: Query{ + IndexOffset: 4, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "start at offset index outside of payments, " + + "reversed order", + query: Query{ + IndexOffset: 9, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in reverse order, start at end", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in reverse order, starting in middle", + query: Query{ + IndexOffset: 4, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in reverse order, starting in middle, " + + "with underflow", + query: Query{ + IndexOffset: 4, + MaxPayments: 5, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "all payments in reverse, order maintained", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 6, + expectedSeqNrs: []uint64{1, 3, 4, 5, 6}, + }, + { + name: "exclude incomplete payments", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: false, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "query payments at index gap", + query: Query{ + IndexOffset: 1, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 6, + expectedSeqNrs: []uint64{3, 4, 5, 6}, + }, + { + name: "query payments reverse before index gap", + query: Query{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, + { + name: "query in forwards order, with start creation " + + "time", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 5, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in forwards order, with start creation " + + "time at end, overflow", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 6, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "query with start and end creation time", + query: Query{ + IndexOffset: 9, + MaxPayments: math.MaxUint64, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + CreationDateEnd: 5, + }, + firstIndex: 3, + lastIndex: 5, + expectedSeqNrs: []uint64{3, 4, 5}, + }, + { + name: "query with only end creation time", + query: Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + CreationDateEnd: 4, + }, + firstIndex: 1, + lastIndex: 4, + expectedSeqNrs: []uint64{1, 3, 4}, + }, + { + name: "query reversed with creation date start", + query: Query{ + IndexOffset: 0, + MaxPayments: 3, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + }, + firstIndex: 4, + lastIndex: 6, + expectedSeqNrs: []uint64{4, 5, 6}, + }, + { + name: "count total with forward pagination", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CountTotal: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "count total with reverse pagination", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + CountTotal: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "count total with filters", + query: Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: false, + CountTotal: true, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + paymentDB, harness := NewTestDB(t) + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := paymentDB.QueryPayments(ctx, tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) + + // Populate the database with a set of test payments. + // We create 6 payments, deleting the payment at index + // 2 so that we cover the case where sequence numbers + // are missing. + numberOfPayments := 6 + + // Store payment info for all payments so we can delete + // one after all are created. + var paymentInfos []*PaymentCreationInfo + + // First, create all payments. + for i := range numberOfPayments { + // Generate a test payment. + info, _, err := genInfo(t) + require.NoError(t, err) + + // Override creation time to allow for testing + // of CreationDateStart and CreationDateEnd. + info.CreationTime = time.Unix(int64(i+1), 0) + + paymentInfos = append(paymentInfos, info) + + // Create a new payment entry in the database. + err = paymentDB.InitPayment( + info.PaymentIdentifier, info, + ) + require.NoError(t, err) + } + + // Now delete the payment at index 1 (the second + // payment). + pmt, err := paymentDB.FetchPayment( + paymentInfos[1].PaymentIdentifier, + ) + require.NoError(t, err) + + // We delete the whole payment. + err = paymentDB.DeletePayment( + paymentInfos[1].PaymentIdentifier, false, + ) + require.NoError(t, err) + + // Verify the payment is deleted. + _, err = paymentDB.FetchPayment( + paymentInfos[1].PaymentIdentifier, + ) + require.ErrorIs( + t, err, ErrPaymentNotInitiated, + ) + + // Verify the index is removed (KV store only). + harness.AssertNoIndex( + t, pmt.SequenceNum, + ) + + // For the last payment, settle it so we have at least + // one completed payment for the "exclude incomplete" + // test case. + lastPaymentInfo := paymentInfos[numberOfPayments-1] + attempt, err := NewHtlcAttempt( + 1, priv, testRoute, + time.Unix(100, 0), + &lastPaymentInfo.PaymentIdentifier, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt( + lastPaymentInfo.PaymentIdentifier, + &attempt.HTLCAttemptInfo, + ) + require.NoError(t, err) + + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + _, err = paymentDB.SettleAttempt( + lastPaymentInfo.PaymentIdentifier, + attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err) + + // Fetch all payments in the database. + resp, err = paymentDB.QueryPayments( + ctx, Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + IncludeIncomplete: true, + }, + ) + require.NoError(t, err) + + allPayments := resp.Payments + + if len(allPayments) != 5 { + t.Fatalf("Number of payments received does "+ + "not match expected one. Got %v, "+ + "want %v.", len(allPayments), 5) + } + + querySlice, err := paymentDB.QueryPayments( + ctx, tt.query, + ) + require.NoError(t, err) + + if tt.firstIndex != querySlice.FirstIndexOffset || + tt.lastIndex != querySlice.LastIndexOffset { + + t.Errorf("First or last index does not match "+ + "expected index. Want (%d, %d), "+ + "got (%d, %d).", + tt.firstIndex, tt.lastIndex, + querySlice.FirstIndexOffset, + querySlice.LastIndexOffset) + } + + if len(querySlice.Payments) != len(tt.expectedSeqNrs) { + t.Errorf("expected: %v payments, got: %v", + len(tt.expectedSeqNrs), + len(querySlice.Payments)) + } + + for i, seqNr := range tt.expectedSeqNrs { + q := querySlice.Payments[i] + if seqNr != q.SequenceNum { + t.Errorf("sequence numbers do not "+ + "match, got %v, want %v", + q.SequenceNum, seqNr) + } + } + + // Verify CountTotal is set correctly when requested. + if tt.query.CountTotal { + // We should have 5 total payments + // (6 created - 1 deleted). + expectedTotal := uint64(5) + if querySlice.TotalCount != expectedTotal { + t.Errorf("expected total count %v, "+ + "got %v", expectedTotal, + querySlice.TotalCount) + } + } else if querySlice.TotalCount != 0 { + t.Errorf("expected total count 0 when "+ + "CountTotal=false, got %v", + querySlice.TotalCount) + } + }) + } +} diff --git a/payments/db/sql_converters.go b/payments/db/sql_converters.go index fd0cad2dcd..66f3b1d3ad 100644 --- a/payments/db/sql_converters.go +++ b/payments/db/sql_converters.go @@ -27,8 +27,10 @@ func dbPaymentToCreationInfo(paymentIdentifier []byte, amountMsat int64, copy(identifier[:], paymentIdentifier) return &PaymentCreationInfo{ - PaymentIdentifier: identifier, - Value: lnwire.MilliSatoshi(amountMsat), + PaymentIdentifier: identifier, + Value: lnwire.MilliSatoshi(amountMsat), + // The creation time is stored in the database as UTC but here + // we convert it to local time. CreationTime: createdAt.Local(), PaymentRequest: intentPayload, FirstHopCustomRecords: firstHopCustomRecords, @@ -205,7 +207,8 @@ func dbDataToRoute(hops []sqlc.FetchHopsForAttemptsRow, ) } - // Add blinding point if present (only for introduction node). + // Add blinding point if present (only for introduction node + // in blinded route). if len(hop.BlindingPoint) > 0 { pubKey, err := btcec.ParsePubKey(hop.BlindingPoint) if err != nil { diff --git a/payments/db/sql_store.go b/payments/db/sql_store.go index 1b7bfbacb7..47a7177986 100644 --- a/payments/db/sql_store.go +++ b/payments/db/sql_store.go @@ -1,15 +1,18 @@ package paymentsdb import ( + "bytes" "context" "database/sql" "errors" "fmt" "math" + "strconv" "time" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" "github.com/lightningnetwork/lnd/sqldb/sqlc" ) @@ -37,24 +40,51 @@ const ( // SQLQueries is a subset of the sqlc.Querier interface that can be used to // execute queries against the SQL payments tables. // -//nolint:ll +//nolint:ll,interfacebloat type SQLQueries interface { /* Payment DB read operations. */ FilterPayments(ctx context.Context, query sqlc.FilterPaymentsParams) ([]sqlc.FilterPaymentsRow, error) FetchPayment(ctx context.Context, paymentIdentifier []byte) (sqlc.FetchPaymentRow, error) - FetchPaymentsByIDs(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchPaymentsByIDsRow, error) CountPayments(ctx context.Context) (int64, error) FetchHtlcAttemptsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptsForPaymentsRow, error) - FetchAllInflightAttempts(ctx context.Context) ([]sqlc.PaymentHtlcAttempt, error) + FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptResolutionsForPaymentsRow, error) + FetchAllInflightAttempts(ctx context.Context, arg sqlc.FetchAllInflightAttemptsParams) ([]sqlc.FetchAllInflightAttemptsRow, error) FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.FetchHopsForAttemptsRow, error) FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentIDs []int64) ([]sqlc.PaymentFirstHopCustomRecord, error) FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.PaymentAttemptFirstHopCustomRecord, error) FetchHopLevelCustomRecords(ctx context.Context, hopIDs []int64) ([]sqlc.PaymentHopCustomRecord, error) + + /* + Payment DB write operations. + */ + InsertPaymentIntent(ctx context.Context, arg sqlc.InsertPaymentIntentParams) (int64, error) + InsertPayment(ctx context.Context, arg sqlc.InsertPaymentParams) (int64, error) + InsertPaymentFirstHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentFirstHopCustomRecordParams) error + + InsertHtlcAttempt(ctx context.Context, arg sqlc.InsertHtlcAttemptParams) (int64, error) + InsertRouteHop(ctx context.Context, arg sqlc.InsertRouteHopParams) (int64, error) + InsertRouteHopMpp(ctx context.Context, arg sqlc.InsertRouteHopMppParams) error + InsertRouteHopAmp(ctx context.Context, arg sqlc.InsertRouteHopAmpParams) error + InsertRouteHopBlinded(ctx context.Context, arg sqlc.InsertRouteHopBlindedParams) error + + InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentAttemptFirstHopCustomRecordParams) error + InsertPaymentHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentHopCustomRecordParams) error + + SettleAttempt(ctx context.Context, arg sqlc.SettleAttemptParams) error + FailAttempt(ctx context.Context, arg sqlc.FailAttemptParams) error + + FailPayment(ctx context.Context, arg sqlc.FailPaymentParams) (sql.Result, error) + + DeletePayment(ctx context.Context, paymentID int64) error + + // DeleteFailedAttempts removes all failed HTLCs from the db for a + // given payment. + DeleteFailedAttempts(ctx context.Context, paymentID int64) error } // BatchedSQLQueries is a version of the SQLQueries that's capable @@ -66,10 +96,6 @@ type BatchedSQLQueries interface { // SQLStore represents a storage backend. type SQLStore struct { - // TODO(ziggie): Remove the KVStore once all the interface functions are - // implemented. - KVStore - cfg *SQLStoreConfig db BatchedSQLQueries @@ -151,7 +177,8 @@ type paymentsBatchData struct { } // loadPaymentCustomRecords loads payment-level custom records for a given -// set of payment IDs. +// set of payment IDs. It uses a batch query to fetch all custom records for +// the given payment IDs. func (s *SQLStore) loadPaymentCustomRecords(ctx context.Context, db SQLQueries, paymentIDs []int64, batchData *paymentsBatchData) error { @@ -184,7 +211,8 @@ func (s *SQLStore) loadPaymentCustomRecords(ctx context.Context, } // loadHtlcAttempts loads HTLC attempts for all payments and returns all -// attempt indices. +// attempt indices. It uses a batch query to fetch all attempts for the given +// payment IDs. func (s *SQLStore) loadHtlcAttempts(ctx context.Context, db SQLQueries, paymentIDs []int64, batchData *paymentsBatchData) ([]int64, error) { @@ -216,6 +244,7 @@ func (s *SQLStore) loadHtlcAttempts(ctx context.Context, db SQLQueries, } // loadHopsForAttempts loads hops for all attempts and returns all hop IDs. +// It uses a batch query to fetch all hops for the given attempt indices. func (s *SQLStore) loadHopsForAttempts(ctx context.Context, db SQLQueries, attemptIndices []int64, batchData *paymentsBatchData) ([]int64, error) { @@ -247,7 +276,8 @@ func (s *SQLStore) loadHopsForAttempts(ctx context.Context, db SQLQueries, return hopIDs, err } -// loadHopCustomRecords loads hop-level custom records for all hops. +// loadHopCustomRecords loads hop-level custom records for all hops. It uses +// a batch query to fetch all custom records for the given hop IDs. func (s *SQLStore) loadHopCustomRecords(ctx context.Context, db SQLQueries, hopIDs []int64, batchData *paymentsBatchData) error { @@ -280,7 +310,8 @@ func (s *SQLStore) loadHopCustomRecords(ctx context.Context, db SQLQueries, } // loadRouteCustomRecords loads route-level first hop custom records for all -// attempts. +// attempts. It uses a batch query to fetch all custom records for the given +// attempt indices. func (s *SQLStore) loadRouteCustomRecords(ctx context.Context, db SQLQueries, attemptIndices []int64, batchData *paymentsBatchData) error { @@ -308,7 +339,110 @@ func (s *SQLStore) loadRouteCustomRecords(ctx context.Context, db SQLQueries, ) } +// paymentStatusBatchData holds lightweight resolution data for computing +// payment status efficiently during deletion operations. +type paymentStatusBatchData struct { + // resolutionTypes maps payment ID to a list of resolution types + // for that payment's HTLC attempts. + resolutionTypes map[int64][]sql.NullInt32 +} + +// loadPaymentResolutionsBatchData loads only HTLC resolution types for multiple +// payments. This is a lightweight alternative to loadPaymentsBatchData that's +// optimized for operations that only need to determine payment status. +func (s *SQLStore) loadPaymentResolutionsBatchData(ctx context.Context, + db SQLQueries, paymentIDs []int64) (*paymentStatusBatchData, error) { + + batchData := &paymentStatusBatchData{ + resolutionTypes: make(map[int64][]sql.NullInt32), + } + + if len(paymentIDs) == 0 { + return batchData, nil + } + + // Fetch resolution types for all payments in a single batch query. + resolutions, err := db.FetchHtlcAttemptResolutionsForPayments( + ctx, paymentIDs, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch HTLC resolutions: %w", + err) + } + + // Group resolutions by payment ID. + for _, res := range resolutions { + batchData.resolutionTypes[res.PaymentID] = append( + batchData.resolutionTypes[res.PaymentID], + res.ResolutionType, + ) + } + + return batchData, nil +} + +// loadPaymentResolutions is a single-payment wrapper around +// loadPaymentResolutionsBatchData for convenience and to prevent duplicate +// queries. +func (s *SQLStore) loadPaymentResolutions(ctx context.Context, db SQLQueries, + paymentID int64) ([]sql.NullInt32, error) { + + batchData, err := s.loadPaymentResolutionsBatchData( + ctx, db, []int64{paymentID}, + ) + if err != nil { + return nil, err + } + + return batchData.resolutionTypes[paymentID], nil +} + +// computePaymentStatus determines the payment status from resolution types +// and failure reason without building the complete MPPayment structure. +// This is a lightweight version that builds minimal HTLCAttempt structures +// and delegates to decidePaymentStatus for consistency. +func computePaymentStatus(resolutionTypes []sql.NullInt32, + failReason sql.NullInt32) (PaymentStatus, error) { + + // Build minimal HTLCAttempt slice with only resolution info. + htlcs := make([]HTLCAttempt, len(resolutionTypes)) + for i, resType := range resolutionTypes { + if !resType.Valid { + // NULL resolution_type means in-flight (no Settle, no + // Failure). + continue + } + + switch HTLCAttemptResolutionType(resType.Int32) { + case HTLCAttemptResolutionSettled: + // Mark as settled (preimage details not needed for + // status). + htlcs[i].Settle = &HTLCSettleInfo{} + + case HTLCAttemptResolutionFailed: + // Mark as failed (failure details not needed for + // status). + htlcs[i].Failure = &HTLCFailInfo{} + + default: + return 0, fmt.Errorf("unknown resolution type: %v", + resType.Int32) + } + } + + // Convert fail reason to FailureReason pointer. + var failureReason *FailureReason + if failReason.Valid { + reason := FailureReason(failReason.Int32) + failureReason = &reason + } + + // Use the existing status decision logic. + return decidePaymentStatus(htlcs, failureReason) +} + // loadPaymentsBatchData loads all related data for multiple payments in batch. +// It uses a batch queries to fetch all data for the given payment IDs. func (s *SQLStore) loadPaymentsBatchData(ctx context.Context, db SQLQueries, paymentIDs []int64) (*paymentsBatchData, error) { @@ -655,6 +789,24 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response, }, nil } +// fetchPaymentByHash fetches a payment by its hash from the database. It is a +// convenience wrapper around the FetchPayment method and checks for +// no rows error and returns ErrPaymentNotInitiated if no payment is found. +func fetchPaymentByHash(ctx context.Context, db SQLQueries, + paymentHash lntypes.Hash) (sqlc.FetchPaymentRow, error) { + + dbPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return dbPayment, fmt.Errorf("failed to fetch payment: %w", err) + } + + if errors.Is(err, sql.ErrNoRows) { + return dbPayment, ErrPaymentNotInitiated + } + + return dbPayment, nil +} + // FetchPayment retrieves a complete payment record from the database by its // payment hash. The returned MPPayment includes all payment metadata such as // creation info, payment status, current state, all HTLC attempts (both @@ -670,13 +822,9 @@ func (s *SQLStore) FetchPayment(paymentHash lntypes.Hash) (*MPPayment, error) { var mpPayment *MPPayment err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - dbPayment, err := db.FetchPayment(ctx, paymentHash[:]) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("failed to fetch payment: %w", err) - } - - if errors.Is(err, sql.ErrNoRows) { - return ErrPaymentNotInitiated + dbPayment, err := fetchPaymentByHash(ctx, db, paymentHash) + if err != nil { + return err } mpPayment, err = s.fetchPaymentWithCompleteData( @@ -695,3 +843,1054 @@ func (s *SQLStore) FetchPayment(paymentHash lntypes.Hash) (*MPPayment, error) { return mpPayment, nil } + +// FetchInFlightPayments retrieves all payments that have HTLC attempts +// currently in flight (not yet settled or failed). These are payments with at +// least one HTLC attempt that has been registered but has no resolution record. +// +// The SQLStore implementation provides a significant performance improvement +// over the KVStore implementation by using targeted SQL queries instead of +// scanning all payments. +// +// This method is part of the PaymentReader interface, which is embedded in the +// DB interface. It's typically called during node startup to resume monitoring +// of pending payments and ensure HTLCs are properly tracked. +// +// TODO(ziggie): Consider changing the interface to use a callback or iterator +// pattern instead of returning all payments at once. This would allow +// processing payments one at a time without holding them all in memory +// simultaneously: +// - Callback: func FetchInFlightPayments(ctx, func(*MPPayment) error) error +// - Iterator: func FetchInFlightPayments(ctx) (PaymentIterator, error) +// +// While inflight payments are typically a small subset, this would improve +// memory efficiency for nodes with unusually high numbers of concurrent +// payments and would better leverage the existing pagination infrastructure. +func (s *SQLStore) FetchInFlightPayments() ([]*MPPayment, + error) { + + ctx := context.TODO() + + var mpPayments []*MPPayment + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // Track which payment IDs we've already processed across all + // pages to avoid loading the same payment multiple times when + // multiple inflight attempts belong to the same payment. + processedPayments := make(map[int64]*MPPayment) + + extractCursor := func( + row sqlc.FetchAllInflightAttemptsRow) int64 { + + return row.AttemptIndex + } + + // collectFunc extracts the payment ID from each attempt row. + collectFunc := func(row sqlc.FetchAllInflightAttemptsRow) ( + int64, error) { + + return row.PaymentID, nil + } + + // batchDataFunc loads payment data for a batch of payment IDs, + // but only for IDs we haven't processed yet. + batchDataFunc := func(ctx context.Context, + paymentIDs []int64) (*paymentsBatchData, error) { + + // Filter out already-processed payment IDs. + uniqueIDs := make([]int64, 0, len(paymentIDs)) + for _, id := range paymentIDs { + _, processed := processedPayments[id] + if !processed { + uniqueIDs = append(uniqueIDs, id) + } + } + + // If uniqueIDs is empty, the batch load will return + // empty batch data. + return s.loadPaymentsBatchData(ctx, db, uniqueIDs) + } + + // processAttempt processes each attempt. We only build and + // store the payment once per unique payment ID. + processAttempt := func(ctx context.Context, + row sqlc.FetchAllInflightAttemptsRow, + batchData *paymentsBatchData) error { + + // Skip if we've already processed this payment. + _, processed := processedPayments[row.PaymentID] + if processed { + return nil + } + + // Extract payment record directly from the row. + //nolint:ll + dbPayment := sqlc.FetchPaymentRow{ + Payment: sqlc.Payment{ + ID: row.PaymentID, + AmountMsat: row.AmountMsat, + CreatedAt: row.CreatedAt, + PaymentIdentifier: row.PaymentIdentifier, + FailReason: row.FailReason, + }, + IntentType: row.IntentType, + IntentPayload: row.IntentPayload, + } + + // Build the payment from batch data. + mpPayment, err := s.buildPaymentFromBatchData( + dbPayment, batchData, + ) + if err != nil { + return fmt.Errorf("failed to build payment: %w", + err) + } + + // Store in our processed map. + processedPayments[row.PaymentID] = mpPayment + + return nil + } + + queryFunc := func(ctx context.Context, lastAttemptIndex int64, + limit int32) ([]sqlc.FetchAllInflightAttemptsRow, + error) { + + return db.FetchAllInflightAttempts(ctx, + sqlc.FetchAllInflightAttemptsParams{ + AttemptIndex: lastAttemptIndex, + Limit: limit, + }, + ) + } + + err := sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, s.cfg.QueryCfg, int64(-1), queryFunc, + extractCursor, collectFunc, batchDataFunc, + processAttempt, + ) + if err != nil { + return err + } + + // Convert map to slice. + mpPayments = make([]*MPPayment, 0, len(processedPayments)) + for _, payment := range processedPayments { + mpPayments = append(mpPayments, payment) + } + + return nil + }, func() { + mpPayments = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch inflight "+ + "payments: %w", err) + } + + return mpPayments, nil +} + +// DeleteFailedAttempts removes all failed HTLC attempts from the database for +// the specified payment, while preserving the payment record itself and any +// successful or in-flight attempts. +// +// The method performs the following validations before deletion: +// - StatusInitiated: Can delete failed attempts +// - StatusInFlight: Cannot delete, returns ErrPaymentInFlight (active HTLCs +// still on the network) +// - StatusSucceeded: Can delete failed attempts (payment completed) +// - StatusFailed: Can delete failed attempts (payment permanently failed) +// +// If the keepFailedPaymentAttempts configuration flag is enabled, this method +// returns immediately without deleting anything, allowing failed attempts to +// be retained for debugging or auditing purposes. +// +// This method is idempotent - calling it multiple times on the same payment +// has no adverse effects. +// +// This method is part of the PaymentControl interface, which is embedded in +// the PaymentWriter interface and ultimately the DB interface. It represents +// the final step (step 5) in the payment lifecycle control flow and should be +// called after a payment reaches a terminal state (succeeded or permanently +// failed) to clean up historical failed attempts. +func (s *SQLStore) DeleteFailedAttempts(paymentHash lntypes.Hash) error { + // In case we are configured to keep failed payment attempts, we exit + // early. + // + // TODO(ziggie): Refactor to not mix application logic with database + // logic. This decision should be made in the application layer. + if s.keepFailedPaymentAttempts { + return nil + } + + ctx := context.TODO() + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + dbPayment, err := fetchPaymentByHash(ctx, db, paymentHash) + if err != nil { + return err + } + + paymentStatus, err := s.computePaymentStatusFromDB( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to compute payment "+ + "status: %w", err) + } + + if err := paymentStatus.removable(); err != nil { + return fmt.Errorf("cannot delete failed "+ + "attempts for payment %v: %w", paymentHash, err) + } + + // Then we delete the failed attempts for this payment. + return db.DeleteFailedAttempts(ctx, dbPayment.GetPayment().ID) + }, sqldb.NoOpReset) + if err != nil { + return fmt.Errorf("failed to delete failed attempts for "+ + "payment %v: %w", paymentHash, err) + } + + return nil +} + +// computePaymentStatusFromDB computes the payment status by fetching minimal +// data from the database. This is a lightweight query optimized for SQL that +// doesn't load route data, making it significantly more efficient than +// FetchPayment when only the status is needed. +func (s *SQLStore) computePaymentStatusFromDB(ctx context.Context, + db SQLQueries, dbPayment sqlc.PaymentAndIntent) (PaymentStatus, error) { + + payment := dbPayment.GetPayment() + + // Use the batch-optimized wrapper to fetch resolution types. + resolutionTypes, err := s.loadPaymentResolutions(ctx, db, payment.ID) + if err != nil { + return 0, fmt.Errorf("failed to load payment resolutions: %w", + err) + } + + // Use the lightweight status computation. + status, err := computePaymentStatus(resolutionTypes, payment.FailReason) + if err != nil { + return 0, fmt.Errorf("failed to compute payment status: %w", + err) + } + + return status, nil +} + +// DeletePayment removes a payment or its failed HTLC attempts from the +// database based on the failedAttemptsOnly flag. +// +// If failedAttemptsOnly is true, this method deletes only the failed HTLC +// attempts for the payment while preserving the payment record itself and any +// successful or in-flight attempts. This is useful for cleaning up historical +// failed attempts after a payment reaches a terminal state. +// +// If failedAttemptsOnly is false, this method deletes the entire payment +// record including all payment metadata, payment creation info, all HTLC +// attempts (both failed and successful), and associated data such as payment +// intents and custom records. +// +// Before deletion, this method validates the payment status to ensure it's +// safe to delete: +// - StatusInitiated: Can be deleted (no HTLCs sent yet) +// - StatusInFlight: Cannot be deleted, returns ErrPaymentInFlight (active +// HTLCs on the network) +// - StatusSucceeded: Can be deleted (payment completed successfully) +// - StatusFailed: Can be deleted (payment has failed permanently) +// +// Returns an error if the payment has in-flight HTLCs or if the payment +// doesn't exist. +// +// This method is part of the PaymentWriter interface, which is embedded in +// the DB interface. +func (s *SQLStore) DeletePayment(paymentHash lntypes.Hash, + failedHtlcsOnly bool) error { + + ctx := context.TODO() + + // In case we are configured to keep failed payment attempts, we exit + // early. + // + // TODO(ziggie): Refactor to not mix application logic with database + // logic. This decision should be made in the application layer. + if s.keepFailedPaymentAttempts { + return nil + } + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + dbPayment, err := fetchPaymentByHash(ctx, db, paymentHash) + if err != nil { + return err + } + + paymentStatus, err := s.computePaymentStatusFromDB( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to compute payment "+ + "status: %w", err) + } + + if err := paymentStatus.removable(); err != nil { + return fmt.Errorf("payment %v cannot be deleted: %w", + paymentHash, err) + } + + // If we are only deleting failed HTLCs, we delete them. + if failedHtlcsOnly { + return db.DeleteFailedAttempts( + ctx, dbPayment.GetPayment().ID, + ) + } + + // In case we are not deleting failed HTLCs, we delete the + // payment which will cascade delete all related data. + return db.DeletePayment(ctx, dbPayment.GetPayment().ID) + }, sqldb.NoOpReset) + if err != nil { + return fmt.Errorf("failed to delete failed attempts for "+ + "payment %v: %w", paymentHash, err) + } + + return nil +} + +// InitPayment creates a new payment record in the database with the given +// payment hash and creation info. +// +// Before creating the payment, this method checks if a payment with the same +// hash already exists and validates whether initialization is allowed based on +// the existing payment's status: +// - StatusInitiated: Returns ErrPaymentExists (payment already created, +// HTLCs may be in flight) +// - StatusInFlight: Returns ErrPaymentInFlight (payment currently being +// attempted) +// - StatusSucceeded: Returns ErrAlreadyPaid (payment already succeeded) +// - StatusFailed: Allows retry by deleting the old payment record and +// creating a new one +// +// If no existing payment is found, a new payment record is created with +// StatusInitiated and stored with all associated metadata. +// +// This method is part of the PaymentControl interface, which is embedded in +// the PaymentWriter interface and ultimately the DB interface, representing +// the first step in the payment lifecycle control flow. +func (s *SQLStore) InitPayment(paymentHash lntypes.Hash, + paymentCreationInfo *PaymentCreationInfo) error { + + ctx := context.TODO() + + // Create the payment in the database. + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + existingPayment, err := db.FetchPayment(ctx, paymentHash[:]) + switch { + // A payment with this hash already exists. We need to check its + // status to see if we can re-initialize. + case err == nil: + paymentStatus, err := s.computePaymentStatusFromDB( + ctx, db, existingPayment, + ) + if err != nil { + return fmt.Errorf("failed to compute payment "+ + "status: %w", err) + } + + // Check if the payment is initializable otherwise + // we'll return early. + if err := paymentStatus.initializable(); err != nil { + return fmt.Errorf("payment is not "+ + "initializable: %w", err) + } + + // If the initializable check above passes, then the + // existing payment has failed. So we delete it and + // all of its previous artifacts. We rely on + // cascading deletes to clean up the rest. + err = db.DeletePayment(ctx, existingPayment.Payment.ID) + if err != nil { + return fmt.Errorf("failed to delete "+ + "payment: %w", err) + } + + // An unexpected error occurred while fetching the payment. + case !errors.Is(err, sql.ErrNoRows): + // Some other error occurred + return fmt.Errorf("failed to check existing "+ + "payment: %w", err) + + // The payment does not yet exist, so we can proceed. + default: + } + + // Insert the payment first to get its ID. + paymentID, err := db.InsertPayment( + ctx, sqlc.InsertPaymentParams{ + AmountMsat: int64( + paymentCreationInfo.Value, + ), + CreatedAt: paymentCreationInfo. + CreationTime.UTC(), + PaymentIdentifier: paymentHash[:], + }, + ) + if err != nil { + return fmt.Errorf("failed to insert payment: %w", err) + } + + // If there's a payment request, insert the payment intent. + if len(paymentCreationInfo.PaymentRequest) > 0 { + _, err = db.InsertPaymentIntent( + ctx, sqlc.InsertPaymentIntentParams{ + PaymentID: paymentID, + IntentType: int16( + PaymentIntentTypeBolt11, + ), + IntentPayload: paymentCreationInfo. + PaymentRequest, + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment intent: %w", err) + } + } + + firstHopCustomRecords := paymentCreationInfo. + FirstHopCustomRecords + + for key, value := range firstHopCustomRecords { + err = db.InsertPaymentFirstHopCustomRecord( + ctx, + sqlc.InsertPaymentFirstHopCustomRecordParams{ + PaymentID: paymentID, + Key: int64(key), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment first hop custom "+ + "record: %w", err) + } + } + + return nil + }, sqldb.NoOpReset) + if err != nil { + return fmt.Errorf("failed to initialize payment: %w", err) + } + + return nil +} + +// insertRouteHops inserts all route hop data for a given set of hops. +func (s *SQLStore) insertRouteHops(ctx context.Context, db SQLQueries, + hops []*route.Hop, attemptID uint64) error { + + for i, hop := range hops { + // Insert the basic route hop data and get the generated ID. + hopID, err := db.InsertRouteHop(ctx, sqlc.InsertRouteHopParams{ + HtlcAttemptIndex: int64(attemptID), + HopIndex: int32(i), + PubKey: hop.PubKeyBytes[:], + Scid: strconv.FormatUint( + hop.ChannelID, 10, + ), + OutgoingTimeLock: int32(hop.OutgoingTimeLock), + AmtToForward: int64(hop.AmtToForward), + MetaData: hop.Metadata, + }) + if err != nil { + return fmt.Errorf("failed to insert route hop: %w", err) + } + + // Insert the per-hop custom records. + if len(hop.CustomRecords) > 0 { + for key, value := range hop.CustomRecords { + err = db.InsertPaymentHopCustomRecord( + ctx, + sqlc.InsertPaymentHopCustomRecordParams{ + HopID: hopID, + Key: int64(key), + Value: value, + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment hop custom "+ + "record: %w", err) + } + } + } + + // Insert MPP data if present. + if hop.MPP != nil { + paymentAddr := hop.MPP.PaymentAddr() + err = db.InsertRouteHopMpp( + ctx, sqlc.InsertRouteHopMppParams{ + HopID: hopID, + PaymentAddr: paymentAddr[:], + TotalMsat: int64(hop.MPP.TotalMsat()), + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop MPP: %w", err) + } + } + + // Insert AMP data if present. + if hop.AMP != nil { + rootShare := hop.AMP.RootShare() + setID := hop.AMP.SetID() + err = db.InsertRouteHopAmp( + ctx, sqlc.InsertRouteHopAmpParams{ + HopID: hopID, + RootShare: rootShare[:], + SetID: setID[:], + ChildIndex: int32(hop.AMP.ChildIndex()), + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop AMP: %w", err) + } + } + + // Insert blinded route data if present. Every hop in the + // blinded path must have an encrypted data record. + if hop.EncryptedData != nil { + // The introduction point has a blinding point set. + var blindingPointBytes []byte + if hop.BlindingPoint != nil { + blindingPointBytes = hop.BlindingPoint. + SerializeCompressed() + } + + // The total amount is only set for the final hop in a + // blinded path. + totalAmtMsat := sql.NullInt64{} + if i == len(hops)-1 { + totalAmtMsat = sql.NullInt64{ + Int64: int64(hop.TotalAmtMsat), + Valid: true, + } + } + + err = db.InsertRouteHopBlinded(ctx, + sqlc.InsertRouteHopBlindedParams{ + HopID: hopID, + EncryptedData: hop.EncryptedData, + BlindingPoint: blindingPointBytes, + BlindedPathTotalAmt: totalAmtMsat, + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop blinded: %w", err) + } + } + } + + return nil +} + +// RegisterAttempt atomically records a new HTLC attempt for the specified +// payment. The attempt includes the attempt ID, session key, route information +// (hops, timelocks, amounts), and optional data such as MPP/AMP parameters, +// blinded route data, and custom records. +// +// Returns the updated MPPayment with the new attempt appended to the HTLCs +// slice, and the payment state recalculated. Returns an error if the payment +// doesn't exist or validation fails. +// +// This method is part of the PaymentControl interface, which is embedded in +// the PaymentWriter interface and ultimately the DB interface. It represents +// step 2 in the payment lifecycle control flow, called after InitPayment and +// potentially multiple times for multi-path payments. +func (s *SQLStore) RegisterAttempt(paymentHash lntypes.Hash, + attempt *HTLCAttemptInfo) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // Make sure the payment exists. + dbPayment, err := fetchPaymentByHash(ctx, db, paymentHash) + if err != nil { + return err + } + + // We fetch the complete payment to determine if the payment is + // registrable. + // + // TODO(ziggie): We could improve the query here since only + // the last hop data is needed here not the complete payment + // data. + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := mpPayment.Registrable(); err != nil { + return fmt.Errorf("htlc attempt not registrable: %w", + err) + } + + // Verify the attempt is compatible with the existing payment. + if err := verifyAttempt(mpPayment, attempt); err != nil { + return fmt.Errorf("failed to verify attempt: %w", err) + } + + // Register the plain HTLC attempt next. + sessionKey := attempt.SessionKey() + sessionKeyBytes := sessionKey.Serialize() + + _, err = db.InsertHtlcAttempt(ctx, sqlc.InsertHtlcAttemptParams{ + PaymentID: dbPayment.GetPayment().ID, + AttemptIndex: int64(attempt.AttemptID), + SessionKey: sessionKeyBytes, + AttemptTime: attempt.AttemptTime, + PaymentHash: paymentHash[:], + FirstHopAmountMsat: int64( + attempt.Route.FirstHopAmount.Val.Int(), + ), + RouteTotalTimeLock: int32(attempt.Route.TotalTimeLock), + RouteTotalAmount: int64(attempt.Route.TotalAmount), + RouteSourceKey: attempt.Route.SourcePubKey[:], + }) + if err != nil { + return fmt.Errorf("failed to insert HTLC "+ + "attempt: %w", err) + } + + // Insert the route level first hop custom records. + attemptFirstHopCustomRecords := attempt.Route. + FirstHopWireCustomRecords + + for key, value := range attemptFirstHopCustomRecords { + //nolint:ll + err = db.InsertPaymentAttemptFirstHopCustomRecord( + ctx, + sqlc.InsertPaymentAttemptFirstHopCustomRecordParams{ + HtlcAttemptIndex: int64(attempt.AttemptID), + Key: int64(key), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment attempt first hop custom "+ + "record: %w", err) + } + } + + // Insert the route hops. + err = s.insertRouteHops( + ctx, db, attempt.Route.Hops, attempt.AttemptID, + ) + if err != nil { + return fmt.Errorf("failed to insert route hops: %w", + err) + } + + // Add the attempt to the payment without fetching it from the + // DB again. + mpPayment.HTLCs = append(mpPayment.HTLCs, HTLCAttempt{ + HTLCAttemptInfo: *attempt, + }) + + if err := mpPayment.SetState(); err != nil { + return fmt.Errorf("failed to set payment state: %w", + err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to register attempt: %w", err) + } + + return mpPayment, nil +} + +// SettleAttempt marks the specified HTLC attempt as successfully settled, +// recording the payment preimage and settlement time. The preimage serves as +// cryptographic proof of payment and is atomically saved to the database. +// +// This method is part of the PaymentControl interface, which is embedded in +// the PaymentWriter interface and ultimately the DB interface. It represents +// step 3a in the payment lifecycle control flow (step 3b is FailAttempt), +// called after RegisterAttempt when an HTLC successfully completes. +func (s *SQLStore) SettleAttempt(paymentHash lntypes.Hash, + attemptID uint64, settleInfo *HTLCSettleInfo) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + dbPayment, err := fetchPaymentByHash(ctx, db, paymentHash) + if err != nil { + return err + } + + paymentStatus, err := s.computePaymentStatusFromDB( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to compute payment "+ + "status: %w", err) + } + + if err := paymentStatus.updatable(); err != nil { + return fmt.Errorf("payment is not updatable: %w", err) + } + + err = db.SettleAttempt(ctx, sqlc.SettleAttemptParams{ + AttemptIndex: int64(attemptID), + ResolutionTime: time.Now(), + ResolutionType: int32(HTLCAttemptResolutionSettled), + SettlePreimage: settleInfo.Preimage[:], + }) + if err != nil { + return fmt.Errorf("failed to settle attempt: %w", err) + } + + // Fetch the complete payment after we settled the attempt. + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to settle attempt: %w", err) + } + + return mpPayment, nil +} + +// FailAttempt marks the specified HTLC attempt as failed, recording the +// failure reason, failure time, optional failure message, and the index of the +// node in the route that generated the failure. This information is atomically +// saved to the database for debugging and route optimization purposes. +// +// For single-path payments, failing the only attempt may lead to the payment +// being retried or ultimately failed via the Fail method. For multi-shard +// (MPP/AMP) payments, individual shard failures don't necessarily fail the +// entire payment; additional attempts can be registered until sufficient shards +// succeed or the payment is permanently failed. +// +// Returns the updated MPPayment with the attempt marked as failed and the +// payment state recalculated. The payment status remains StatusInFlight if +// other attempts are still in flight, or may transition based on the overall +// payment state. +// +// This method is part of the PaymentControl interface, which is embedded in +// the PaymentWriter interface and ultimately the DB interface. It represents +// step 3b in the payment lifecycle control flow (step 3a is SettleAttempt), +// called after RegisterAttempt when an HTLC fails. +func (s *SQLStore) FailAttempt(paymentHash lntypes.Hash, + attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // Make sure the payment exists. + dbPayment, err := fetchPaymentByHash(ctx, db, paymentHash) + if err != nil { + return err + } + + paymentStatus, err := s.computePaymentStatusFromDB( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to compute payment "+ + "status: %w", err) + } + + // We check if the payment is updatable before failing the + // attempt. + if err := paymentStatus.updatable(); err != nil { + return fmt.Errorf("payment is not updatable: %w", err) + } + + var failureMsg bytes.Buffer + if failInfo.Message != nil { + err := lnwire.EncodeFailureMessage( + &failureMsg, failInfo.Message, 0, + ) + if err != nil { + return fmt.Errorf("failed to encode "+ + "failure message: %w", err) + } + } + + err = db.FailAttempt(ctx, sqlc.FailAttemptParams{ + AttemptIndex: int64(attemptID), + ResolutionTime: time.Now(), + ResolutionType: int32(HTLCAttemptResolutionFailed), + FailureSourceIndex: sqldb.SQLInt32( + failInfo.FailureSourceIndex, + ), + HtlcFailReason: sqldb.SQLInt32(failInfo.Reason), + FailureMsg: failureMsg.Bytes(), + }) + if err != nil { + return fmt.Errorf("failed to fail attempt: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with"+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fail attempt: %w", err) + } + + return mpPayment, nil +} + +// Fail records the ultimate reason why a payment failed. This method stores +// the failure reason for record keeping but does not enforce that all HTLC +// attempts are resolved - HTLCs may still be in flight when this is called. +// +// The payment's actual status transition to StatusFailed is determined by the +// payment state calculation, which considers both the recorded failure reason +// and the current state of all HTLC attempts. The status will transition to +// StatusFailed once all HTLCs are resolved and/or a failure reason is recorded. +// +// NOTE: According to the interface contract, this should only be called when +// all active attempts are already failed. However, the implementation allows +// concurrent calls and does not validate this precondition, enabling the last +// failing attempt to record the failure reason without synchronization. +// +// This method is part of the PaymentControl interface, which is embedded in +// the PaymentWriter interface and ultimately the DB interface. It represents +// step 4 in the payment lifecycle control flow. +func (s *SQLStore) Fail(paymentHash lntypes.Hash, + reason FailureReason) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + result, err := db.FailPayment(ctx, sqlc.FailPaymentParams{ + PaymentIdentifier: paymentHash[:], + FailReason: sqldb.SQLInt32(reason), + }) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return ErrPaymentNotInitiated + } + + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fail payment: %w", err) + } + + return mpPayment, nil +} + +// DeletePayments performs a batch deletion of payments or their failed HTLC +// attempts from the database based on the specified flags. This is a bulk +// operation that iterates through all payments and selectively deletes based +// on the criteria. +// The behavior is controlled by two flags: +// +// If failedAttemptsOnly is true, only failed HTLC attempts are deleted while +// preserving the payment records and any successful or in-flight attempts. +// The return value is always 0 when deleting attempts only. +// +// If failedAttemptsOnly is false, entire payment records are deleted including +// all associated data (HTLCs, metadata, intents). The return value is the +// number of payments deleted. +// +// The failedOnly flag further filters which payments are processed: +// - failedOnly=true, failedAttemptsOnly=true: Delete failed attempts for +// StatusFailed payments only +// - failedOnly=false, failedAttemptsOnly=true: Delete failed attempts for +// all removable payments +// - failedOnly=true, failedAttemptsOnly=false: Delete entire payment records +// for StatusFailed payments only +// - failedOnly=false, failedAttemptsOnly=false: Delete all removable payment +// records (StatusInitiated, StatusSucceeded, StatusFailed) +// +// Safety checks applied to all operations: +// - Payments with StatusInFlight are always skipped (cannot be safely deleted +// while HTLCs are on the network) +// - The payment status must pass the removable() check +// +// Returns the number of complete payments deleted (0 if only deleting failed +// attempts). This is useful for cleanup operations, administrative maintenance, +// or freeing up database storage. +// +// This method is part of the PaymentWriter interface, which is embedded in +// the DB interface. +// +// TODO(ziggie): batch this call instead in the background so for dbs with +// many payments it doesn't block the main thread. +func (s *SQLStore) DeletePayments(failedOnly, failedHtlcsOnly bool) (int, + error) { + + var numPayments int + ctx := context.TODO() + + extractCursor := func( + row sqlc.FilterPaymentsRow) int64 { + + return row.Payment.ID + } + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // collectFunc extracts the payment ID from each payment row. + collectFunc := func(row sqlc.FilterPaymentsRow) (int64, + error) { + + return row.Payment.ID, nil + } + + // batchDataFunc loads only HTLC resolution types for a batch + // of payments, which is sufficient to determine payment status. + batchDataFunc := func(ctx context.Context, paymentIDs []int64) ( + *paymentStatusBatchData, error) { + + return s.loadPaymentResolutionsBatchData( + ctx, db, paymentIDs, + ) + } + + // processPayment processes each payment with the lightweight + // batch-loaded resolution data. + processPayment := func(ctx context.Context, + dbPayment sqlc.FilterPaymentsRow, + batchData *paymentStatusBatchData) error { + + payment := dbPayment.Payment + + // Compute the payment status from resolution types and + // failure reason without building the complete payment. + resolutionTypes := batchData.resolutionTypes[payment.ID] + status, err := computePaymentStatus( + resolutionTypes, payment.FailReason, + ) + if err != nil { + return fmt.Errorf("failed to compute payment "+ + "status: %w", err) + } + + // Payments which are not final yet cannot be deleted. + // we skip them. + if err := status.removable(); err != nil { + return nil + } + + // If we are only deleting failed payments, we skip + // if the payment is not failed. + if failedOnly && status != StatusFailed { + return nil + } + + // If we are only deleting failed HTLCs, we delete them + // and return early. + if failedHtlcsOnly { + return db.DeleteFailedAttempts( + ctx, payment.ID, + ) + } + + // Otherwise we delete the payment. + err = db.DeletePayment(ctx, payment.ID) + if err != nil { + return fmt.Errorf("failed to delete "+ + "payment: %w", err) + } + + numPayments++ + + return nil + } + + queryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.FilterPaymentsRow, error) { + + filterParams := sqlc.FilterPaymentsParams{ + NumLimit: limit, + // For now there are only BOLT 11 payment + // intents. + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + IndexOffsetGet: sqldb.SQLInt64( + lastID, + ), + } + + return db.FilterPayments(ctx, filterParams) + } + + return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, s.cfg.QueryCfg, int64(-1), queryFunc, + extractCursor, collectFunc, batchDataFunc, + processPayment, + ) + }, func() { + numPayments = 0 + }) + if err != nil { + return 0, fmt.Errorf("failed to delete payments "+ + "(failedOnly: %v, failedHtlcsOnly: %v): %w", + failedOnly, failedHtlcsOnly, err) + } + + return numPayments, nil +} diff --git a/payments/db/test_harness.go b/payments/db/test_harness.go new file mode 100644 index 0000000000..11f88c3f83 --- /dev/null +++ b/payments/db/test_harness.go @@ -0,0 +1,26 @@ +package paymentsdb + +import ( + "testing" + + "github.com/lightningnetwork/lnd/lntypes" +) + +// TestHarness provides implementation-specific test utilities for the payments +// database. Different database backends (KV, SQL) have different internal +// structures and indexing mechanisms, so this interface allows tests to verify +// implementation-specific behavior without coupling the test logic to a +// particular backend. +type TestHarness interface { + // AssertPaymentIndex checks that a payment is correctly indexed. + // For KV: verifies the payment index bucket entry exists and points + // to the correct payment hash. + // For SQL: no-op (SQL doesn't use a separate index bucket). + AssertPaymentIndex(t *testing.T, expectedHash lntypes.Hash) + + // AssertNoIndex checks that an index for a sequence number doesn't + // exist. + // For KV: verifies the index bucket entry is deleted. + // For SQL: no-op. + AssertNoIndex(t *testing.T, seqNr uint64) +} diff --git a/payments/db/test_kvdb.go b/payments/db/test_kvdb.go index e0ee1738d7..ed1710b14f 100644 --- a/payments/db/test_kvdb.go +++ b/payments/db/test_kvdb.go @@ -1,14 +1,20 @@ +//go:build !test_db_sqlite && !test_db_postgres + package paymentsdb import ( + "bytes" "testing" + "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, opts ...OptionModifier) DB { +func NewTestDB(t *testing.T, opts ...OptionModifier) (DB, TestHarness) { backend, backendCleanup, err := kvdb.GetTestBackend( t.TempDir(), "paymentsDB", ) @@ -19,7 +25,7 @@ func NewTestDB(t *testing.T, opts ...OptionModifier) DB { paymentDB, err := NewKVStore(backend, opts...) require.NoError(t, err) - return paymentDB + return paymentDB, &kvTestHarness{db: paymentDB} } // NewKVTestDB is a helper function that creates an BBolt database for testing @@ -38,3 +44,68 @@ func NewKVTestDB(t *testing.T, opts ...OptionModifier) *KVStore { return paymentDB } + +// kvTestHarness is the KV-specific test harness implementation. +type kvTestHarness struct { + db *KVStore +} + +// AssertPaymentIndex looks up the index for a payment in the db and checks +// that its payment hash matches the expected hash passed in. +func (h *kvTestHarness) AssertPaymentIndex(t *testing.T, + expectedHash lntypes.Hash) { + + t.Helper() + + // Lookup the payment so that we have its sequence number and check + // that it has correctly been indexed in the payment indexes bucket. + pmt, err := h.db.FetchPayment(expectedHash) + require.NoError(t, err) + + hash, err := h.fetchPaymentIndexEntry(t, pmt.SequenceNum) + require.NoError(t, err) + assert.Equal(t, expectedHash, *hash) +} + +// AssertNoIndex checks that an index for the sequence number provided does not +// exist. +func (h *kvTestHarness) AssertNoIndex(t *testing.T, seqNr uint64) { + t.Helper() + + _, err := h.fetchPaymentIndexEntry(t, seqNr) + require.Equal(t, ErrNoSequenceNrIndex, err) +} + +// fetchPaymentIndexEntry gets the payment hash for the sequence number +// provided from the payment indexes bucket. +func (h *kvTestHarness) fetchPaymentIndexEntry(t *testing.T, + sequenceNumber uint64) (*lntypes.Hash, error) { + + t.Helper() + + var hash lntypes.Hash + + if err := kvdb.View(h.db.db, func(tx walletdb.ReadTx) error { + indexBucket := tx.ReadBucket(paymentsIndexBucket) + key := make([]byte, 8) + byteOrder.PutUint64(key, sequenceNumber) + + indexValue := indexBucket.Get(key) + if indexValue == nil { + return ErrNoSequenceNrIndex + } + + r := bytes.NewReader(indexValue) + + var err error + hash, err = deserializePaymentIndex(r) + + return err + }, func() { + hash = lntypes.Hash{} + }); err != nil { + return nil, err + } + + return &hash, nil +} diff --git a/payments/db/test_postgres.go b/payments/db/test_postgres.go new file mode 100644 index 0000000000..bd22703f1f --- /dev/null +++ b/payments/db/test_postgres.go @@ -0,0 +1,95 @@ +//go:build test_db_postgres && !test_db_sqlite + +package paymentsdb + +import ( + "database/sql" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB, opts ...OptionModifier) (DB, TestHarness) { + db := NewTestDBWithFixture(t, nil, opts...) + return db, &noopTestHarness{} +} + +// NewTestDBFixture creates a new sqldb.TestPgFixture for testing purposes. +func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture { + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + return pgFixture +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture, opts ...OptionModifier) DB { + + var querier BatchedSQLQueries + if pgFixture == nil { + querier = newBatchQuerier(t) + } else { + querier = newBatchQuerierWithFixture(t, pgFixture) + } + + store, err := NewSQLStore( + &SQLStoreConfig{ + QueryCfg: sqldb.DefaultPostgresConfig(), + }, querier, opts..., + ) + require.NoError(t, err) + + return store +} + +// newBatchQuerier creates a new BatchedSQLQueries instance for testing +// using a PostgreSQL database fixture. +func newBatchQuerier(t testing.TB) BatchedSQLQueries { + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return newBatchQuerierWithFixture(t, pgFixture) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a PostgreSQL database fixture. +func newBatchQuerierWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture) BatchedSQLQueries { + + db := sqldb.NewTestPostgresDB(t, pgFixture).BaseDB + + return sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLQueries { + return db.WithTx(tx) + }, + ) +} + +// noopTestHarness is the SQL test harness implementation. Since SQL doesn't +// use a separate payment index bucket like KV, these assertions are no-ops. +type noopTestHarness struct{} + +// AssertPaymentIndex is a no-op for SQL implementations. +func (h *noopTestHarness) AssertPaymentIndex(t *testing.T, + expectedHash lntypes.Hash) { + + // No-op: SQL doesn't use a separate index bucket. +} + +// AssertNoIndex is a no-op for SQL implementations. +func (h *noopTestHarness) AssertNoIndex(t *testing.T, seqNr uint64) { + // No-op: SQL doesn't use a separate index bucket. +} diff --git a/payments/db/test_sqlite.go b/payments/db/test_sqlite.go new file mode 100644 index 0000000000..99d1047805 --- /dev/null +++ b/payments/db/test_sqlite.go @@ -0,0 +1,74 @@ +//go:build !test_db_postgres && test_db_sqlite + +package paymentsdb + +import ( + "database/sql" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB, opts ...OptionModifier) (DB, TestHarness) { + db := NewTestDBWithFixture(t, nil, opts...) + return db, &noopTestHarness{} +} + +// NewTestDBFixture is a no-op for the sqlite build. +func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture { + return nil +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture, + opts ...OptionModifier) DB { + + store, err := NewSQLStore( + &SQLStoreConfig{ + QueryCfg: sqldb.DefaultSQLiteConfig(), + }, newBatchQuerier(t), opts..., + ) + require.NoError(t, err) + return store +} + +// newBatchQuerier creates a new BatchedSQLQueries instance for testing +// using a SQLite database. +func newBatchQuerier(t testing.TB) BatchedSQLQueries { + return newBatchQuerierWithFixture(t, nil) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a SQLite database. +func newBatchQuerierWithFixture(t testing.TB, + _ *sqldb.TestPgFixture) BatchedSQLQueries { + + db := sqldb.NewTestSqliteDB(t).BaseDB + + return sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLQueries { + return db.WithTx(tx) + }, + ) +} + +// noopTestHarness is the SQL test harness implementation. Since SQL doesn't +// use a separate payment index bucket like KV, these assertions are no-ops. +type noopTestHarness struct{} + +// AssertPaymentIndex is a no-op for SQL implementations. +func (h *noopTestHarness) AssertPaymentIndex(t *testing.T, + expectedHash lntypes.Hash) { + + // No-op: SQL doesn't use a separate index bucket. +} + +// AssertNoIndex is a no-op for SQL implementations. +func (h *noopTestHarness) AssertNoIndex(t *testing.T, seqNr uint64) { + // No-op: SQL doesn't use a separate index bucket. +} diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go index 7888f81f45..e9a3cb4af2 100644 --- a/sqldb/sqlc/db_custom.go +++ b/sqldb/sqlc/db_custom.go @@ -221,25 +221,3 @@ func (r FetchPaymentRow) GetPaymentIntent() PaymentIntent { IntentPayload: r.IntentPayload, } } - -// GetPayment returns the Payment associated with this interface. -// -// NOTE: This method is part of the PaymentAndIntent interface. -func (r FetchPaymentsByIDsRow) GetPayment() Payment { - return r.Payment -} - -// GetPaymentIntent returns the PaymentIntent associated with this payment. -// If the payment has no intent (IntentType is NULL), this returns a zero-value -// PaymentIntent. -// -// NOTE: This method is part of the PaymentAndIntent interface. -func (r FetchPaymentsByIDsRow) GetPaymentIntent() PaymentIntent { - if !r.IntentType.Valid { - return PaymentIntent{} - } - return PaymentIntent{ - IntentType: r.IntentType.Int16, - IntentPayload: r.IntentPayload, - } -} diff --git a/sqldb/sqlc/migrations/000009_payments.up.sql b/sqldb/sqlc/migrations/000009_payments.up.sql index 0d85b497b0..65094a15e4 100644 --- a/sqldb/sqlc/migrations/000009_payments.up.sql +++ b/sqldb/sqlc/migrations/000009_payments.up.sql @@ -2,43 +2,12 @@ -- Payment System Schema Migration -- ───────────────────────────────────────────── -- This migration creates the complete payment system schema including: --- - Payment intents (BOLT 11/12 invoices, offers) +-- - Payment intents (only BOLT 11 invoices for now) -- - Payment attempts and HTLC tracking -- - Route hops and custom TLV records -- - Resolution tracking for settled/failed payments -- ───────────────────────────────────────────── --- ───────────────────────────────────────────── --- Payment Intents Table --- ───────────────────────────────────────────── --- Stores the descriptor of what the payment is paying for. --- Depending on the type, the payload might contain: --- - BOLT 11 invoice data --- - BOLT 12 offer data --- - NULL for legacy hash-only/keysend style payments --- ───────────────────────────────────────────── - -CREATE TABLE IF NOT EXISTS payment_intents ( - -- Primary key for the intent record - id INTEGER PRIMARY KEY, - - -- The type of intent (e.g. 0 = bolt11_invoice, 1 = bolt12_offer) - -- Uses SMALLINT (int16) for efficient storage of enum values - intent_type SMALLINT NOT NULL, - - -- The serialized payload for the payment intent - -- Content depends on type - could be invoice, offer, or NULL - intent_payload BLOB -); - --- Index for efficient querying by intent type -CREATE INDEX IF NOT EXISTS idx_payment_intents_type -ON payment_intents(intent_type); - --- Unique constraint for deduplication of payment intents -CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_intents_unique -ON payment_intents(intent_type, intent_payload); - -- ───────────────────────────────────────────── -- Payments Table -- ───────────────────────────────────────────── @@ -55,10 +24,6 @@ CREATE TABLE IF NOT EXISTS payments ( -- Primary key for the payment record id INTEGER PRIMARY KEY, - -- Optional reference to the payment intent this payment was derived from - -- Links to BOLT 11 invoice, BOLT 12 offer, etc. - intent_id BIGINT REFERENCES payment_intents (id), - -- The amount of the payment in millisatoshis amount_msat BIGINT NOT NULL, @@ -70,20 +35,59 @@ CREATE TABLE IF NOT EXISTS payments ( -- For AMP: the setID -- For future intent types: any unique payment-level key payment_identifier BLOB NOT NULL, - + -- The reason for payment failure (only set if payment has failed) -- Integer enum type indicating failure reason fail_reason INTEGER, -- Ensure payment identifiers are unique across all payments - CONSTRAINT idx_payments_payment_identifier_unique + CONSTRAINT idx_payments_payment_identifier_unique UNIQUE (payment_identifier) ); -- Index for efficient querying by creation time (for chronological ordering) -CREATE INDEX IF NOT EXISTS idx_payments_created_at +CREATE INDEX IF NOT EXISTS idx_payments_created_at ON payments(created_at); +-- ───────────────────────────────────────────── +-- Payment Intents Table +-- ───────────────────────────────────────────── +-- Stores the descriptor of what the payment is paying for. +-- Depending on the type, the payload might contain: +-- - BOLT 11 invoice data +-- - BOLT 12 offer data +-- - NULL for legacy hash-only/keysend style payments +-- ───────────────────────────────────────────── + +CREATE TABLE IF NOT EXISTS payment_intents ( + -- Primary key for the intent record + id INTEGER PRIMARY KEY, + + -- Reference to the payment this intent belongs to (one-to-one relationship) + -- When the payment is deleted, the intent is automatically deleted + payment_id BIGINT NOT NULL REFERENCES payments (id) ON DELETE CASCADE, + + -- The type of intent (e.g. 0 = bolt11_invoice, 1 = bolt12_invoice) + -- Uses SMALLINT (int16) for efficient storage of enum values + intent_type SMALLINT NOT NULL, + + -- The serialized payload for the payment intent + -- Content depends on type - could be invoice, offer, or NULL + intent_payload BLOB, + + -- Ensure one-to-one relationship: each payment has at most one intent. + -- Currently we only support one intent per payment this makes sure we do + -- not accidentally pay the same request multiple times. This currently + -- only has bolt 11 payment requests/invoices. But in the future this can + -- also include BOLT 12 offers/invoices. + CONSTRAINT idx_payment_intents_payment_id_unique + UNIQUE (payment_id) +); + +-- Index for efficient querying by intent type +CREATE INDEX IF NOT EXISTS idx_payment_intents_type +ON payment_intents(intent_type); + -- ───────────────────────────────────────────── -- Payment HTLC Attempts Table -- ───────────────────────────────────────────── diff --git a/sqldb/sqlc/models.go b/sqldb/sqlc/models.go index 6a4c9dd33b..97df2d6f6c 100644 --- a/sqldb/sqlc/models.go +++ b/sqldb/sqlc/models.go @@ -205,7 +205,6 @@ type MigrationTracker struct { type Payment struct { ID int64 - IntentID sql.NullInt64 AmountMsat int64 CreatedAt time.Time PaymentIdentifier []byte @@ -258,6 +257,7 @@ type PaymentHtlcAttemptResolution struct { type PaymentIntent struct { ID int64 + PaymentID int64 IntentType int16 IntentPayload []byte } diff --git a/sqldb/sqlc/payments.sql.go b/sqldb/sqlc/payments.sql.go index 83c1c7f1f0..4ef52558e0 100644 --- a/sqldb/sqlc/payments.sql.go +++ b/sqldb/sqlc/payments.sql.go @@ -23,6 +23,81 @@ func (q *Queries) CountPayments(ctx context.Context) (int64, error) { return count, err } +const deleteFailedAttempts = `-- name: DeleteFailedAttempts :exec +DELETE FROM payment_htlc_attempts WHERE payment_id = $1 AND attempt_index IN ( + SELECT attempt_index FROM payment_htlc_attempt_resolutions WHERE resolution_type = 2 +) +` + +// Delete all failed HTLC attempts for the given payment. Resolution type 2 +// indicates a failed attempt. +func (q *Queries) DeleteFailedAttempts(ctx context.Context, paymentID int64) error { + _, err := q.db.ExecContext(ctx, deleteFailedAttempts, paymentID) + return err +} + +const deletePayment = `-- name: DeletePayment :exec +DELETE FROM payments WHERE id = $1 +` + +func (q *Queries) DeletePayment(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deletePayment, id) + return err +} + +const failAttempt = `-- name: FailAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + failure_source_index, + htlc_fail_reason, + failure_msg +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) +` + +type FailAttemptParams struct { + AttemptIndex int64 + ResolutionTime time.Time + ResolutionType int32 + FailureSourceIndex sql.NullInt32 + HtlcFailReason sql.NullInt32 + FailureMsg []byte +} + +func (q *Queries) FailAttempt(ctx context.Context, arg FailAttemptParams) error { + _, err := q.db.ExecContext(ctx, failAttempt, + arg.AttemptIndex, + arg.ResolutionTime, + arg.ResolutionType, + arg.FailureSourceIndex, + arg.HtlcFailReason, + arg.FailureMsg, + ) + return err +} + +const failPayment = `-- name: FailPayment :execresult +UPDATE payments SET fail_reason = $1 WHERE payment_identifier = $2 +` + +type FailPaymentParams struct { + FailReason sql.NullInt32 + PaymentIdentifier []byte +} + +func (q *Queries) FailPayment(ctx context.Context, arg FailPaymentParams) (sql.Result, error) { + return q.db.ExecContext(ctx, failPayment, arg.FailReason, arg.PaymentIdentifier) +} + const fetchAllInflightAttempts = `-- name: FetchAllInflightAttempts :many SELECT ha.id, @@ -34,25 +109,60 @@ SELECT ha.first_hop_amount_msat, ha.route_total_time_lock, ha.route_total_amount, - ha.route_source_key + ha.route_source_key, + p.amount_msat, + p.created_at, + p.payment_identifier, + p.fail_reason, + pi.intent_type, + pi.intent_payload FROM payment_htlc_attempts ha +INNER JOIN payments p ON p.id = ha.payment_id +LEFT JOIN payment_intents pi ON pi.payment_id = p.id WHERE NOT EXISTS ( SELECT 1 FROM payment_htlc_attempt_resolutions hr WHERE hr.attempt_index = ha.attempt_index ) +AND ha.attempt_index > $1 ORDER BY ha.attempt_index ASC +LIMIT $2 ` -// Fetch all inflight attempts across all payments -func (q *Queries) FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAttempt, error) { - rows, err := q.db.QueryContext(ctx, fetchAllInflightAttempts) +type FetchAllInflightAttemptsParams struct { + AttemptIndex int64 + Limit int32 +} + +type FetchAllInflightAttemptsRow struct { + ID int64 + AttemptIndex int64 + PaymentID int64 + SessionKey []byte + AttemptTime time.Time + PaymentHash []byte + FirstHopAmountMsat int64 + RouteTotalTimeLock int32 + RouteTotalAmount int64 + RouteSourceKey []byte + AmountMsat int64 + CreatedAt time.Time + PaymentIdentifier []byte + FailReason sql.NullInt32 + IntentType sql.NullInt16 + IntentPayload []byte +} + +// Fetch all inflight attempts with their payment data using pagination. +// Returns attempt data joined with payment and intent data to avoid separate queries. +func (q *Queries) FetchAllInflightAttempts(ctx context.Context, arg FetchAllInflightAttemptsParams) ([]FetchAllInflightAttemptsRow, error) { + rows, err := q.db.QueryContext(ctx, fetchAllInflightAttempts, arg.AttemptIndex, arg.Limit) if err != nil { return nil, err } defer rows.Close() - var items []PaymentHtlcAttempt + var items []FetchAllInflightAttemptsRow for rows.Next() { - var i PaymentHtlcAttempt + var i FetchAllInflightAttemptsRow if err := rows.Scan( &i.ID, &i.AttemptIndex, @@ -64,6 +174,12 @@ func (q *Queries) FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAt &i.RouteTotalTimeLock, &i.RouteTotalAmount, &i.RouteSourceKey, + &i.AmountMsat, + &i.CreatedAt, + &i.PaymentIdentifier, + &i.FailReason, + &i.IntentType, + &i.IntentPayload, ); err != nil { return nil, err } @@ -222,6 +338,56 @@ func (q *Queries) FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices [ return items, nil } +const fetchHtlcAttemptResolutionsForPayments = `-- name: FetchHtlcAttemptResolutionsForPayments :many +SELECT + ha.payment_id, + hr.resolution_type +FROM payment_htlc_attempts ha +LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_index +WHERE ha.payment_id IN (/*SLICE:payment_ids*/?) +` + +type FetchHtlcAttemptResolutionsForPaymentsRow struct { + PaymentID int64 + ResolutionType sql.NullInt32 +} + +// Batch query to fetch only HTLC resolution status for multiple payments. +// We don't need to order by payment_id and attempt_time because we will +// group the resolutions by payment_id in the background. +func (q *Queries) FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIds []int64) ([]FetchHtlcAttemptResolutionsForPaymentsRow, error) { + query := fetchHtlcAttemptResolutionsForPayments + var queryParams []interface{} + if len(paymentIds) > 0 { + for _, v := range paymentIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:payment_ids*/?", makeQueryParams(len(queryParams), len(paymentIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:payment_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHtlcAttemptResolutionsForPaymentsRow + for rows.Next() { + var i FetchHtlcAttemptResolutionsForPaymentsRow + if err := rows.Scan(&i.PaymentID, &i.ResolutionType); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const fetchHtlcAttemptsForPayments = `-- name: FetchHtlcAttemptsForPayments :many SELECT ha.id, @@ -317,11 +483,11 @@ func (q *Queries) FetchHtlcAttemptsForPayments(ctx context.Context, paymentIds [ const fetchPayment = `-- name: FetchPayment :one SELECT - p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + p.id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, i.intent_type AS "intent_type", i.intent_payload AS "intent_payload" FROM payments p -LEFT JOIN payment_intents i ON i.id = p.intent_id +LEFT JOIN payment_intents i ON i.payment_id = p.id WHERE p.payment_identifier = $1 ` @@ -336,7 +502,6 @@ func (q *Queries) FetchPayment(ctx context.Context, paymentIdentifier []byte) (F var i FetchPaymentRow err := row.Scan( &i.Payment.ID, - &i.Payment.IntentID, &i.Payment.AmountMsat, &i.Payment.CreatedAt, &i.Payment.PaymentIdentifier, @@ -396,64 +561,6 @@ func (q *Queries) FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, pa return items, nil } -const fetchPaymentsByIDs = `-- name: FetchPaymentsByIDs :many -SELECT - p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, - i.intent_type AS "intent_type", - i.intent_payload AS "intent_payload" -FROM payments p -LEFT JOIN payment_intents i ON i.id = p.intent_id -WHERE p.id IN (/*SLICE:payment_ids*/?) -` - -type FetchPaymentsByIDsRow struct { - Payment Payment - IntentType sql.NullInt16 - IntentPayload []byte -} - -func (q *Queries) FetchPaymentsByIDs(ctx context.Context, paymentIds []int64) ([]FetchPaymentsByIDsRow, error) { - query := fetchPaymentsByIDs - var queryParams []interface{} - if len(paymentIds) > 0 { - for _, v := range paymentIds { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:payment_ids*/?", makeQueryParams(len(queryParams), len(paymentIds)), 1) - } else { - query = strings.Replace(query, "/*SLICE:payment_ids*/?", "NULL", 1) - } - rows, err := q.db.QueryContext(ctx, query, queryParams...) - if err != nil { - return nil, err - } - defer rows.Close() - var items []FetchPaymentsByIDsRow - for rows.Next() { - var i FetchPaymentsByIDsRow - if err := rows.Scan( - &i.Payment.ID, - &i.Payment.IntentID, - &i.Payment.AmountMsat, - &i.Payment.CreatedAt, - &i.Payment.PaymentIdentifier, - &i.Payment.FailReason, - &i.IntentType, - &i.IntentPayload, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const fetchRouteLevelFirstHopCustomRecords = `-- name: FetchRouteLevelFirstHopCustomRecords :many SELECT l.id, @@ -510,11 +617,11 @@ const filterPayments = `-- name: FilterPayments :many */ SELECT - p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + p.id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, i.intent_type AS "intent_type", i.intent_payload AS "intent_payload" FROM payments p -LEFT JOIN payment_intents i ON i.id = p.intent_id +LEFT JOIN payment_intents i ON i.payment_id = p.id WHERE ( p.id > $1 OR $1 IS NULL @@ -572,7 +679,6 @@ func (q *Queries) FilterPayments(ctx context.Context, arg FilterPaymentsParams) var i FilterPaymentsRow if err := rows.Scan( &i.Payment.ID, - &i.Payment.IntentID, &i.Payment.AmountMsat, &i.Payment.CreatedAt, &i.Payment.PaymentIdentifier, @@ -592,3 +698,353 @@ func (q *Queries) FilterPayments(ctx context.Context, arg FilterPaymentsParams) } return items, nil } + +const insertHtlcAttempt = `-- name: InsertHtlcAttempt :one +INSERT INTO payment_htlc_attempts ( + payment_id, + attempt_index, + session_key, + attempt_time, + payment_hash, + first_hop_amount_msat, + route_total_time_lock, + route_total_amount, + route_source_key) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9) +RETURNING id +` + +type InsertHtlcAttemptParams struct { + PaymentID int64 + AttemptIndex int64 + SessionKey []byte + AttemptTime time.Time + PaymentHash []byte + FirstHopAmountMsat int64 + RouteTotalTimeLock int32 + RouteTotalAmount int64 + RouteSourceKey []byte +} + +func (q *Queries) InsertHtlcAttempt(ctx context.Context, arg InsertHtlcAttemptParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertHtlcAttempt, + arg.PaymentID, + arg.AttemptIndex, + arg.SessionKey, + arg.AttemptTime, + arg.PaymentHash, + arg.FirstHopAmountMsat, + arg.RouteTotalTimeLock, + arg.RouteTotalAmount, + arg.RouteSourceKey, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertPayment = `-- name: InsertPayment :one +INSERT INTO payments ( + amount_msat, + created_at, + payment_identifier, + fail_reason) +VALUES ( + $1, + $2, + $3, + NULL +) +RETURNING id +` + +type InsertPaymentParams struct { + AmountMsat int64 + CreatedAt time.Time + PaymentIdentifier []byte +} + +// Insert a new payment and return its ID. +// When creating a payment we don't have a fail reason because we start the +// payment process. +func (q *Queries) InsertPayment(ctx context.Context, arg InsertPaymentParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertPayment, arg.AmountMsat, arg.CreatedAt, arg.PaymentIdentifier) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertPaymentAttemptFirstHopCustomRecord = `-- name: InsertPaymentAttemptFirstHopCustomRecord :exec +INSERT INTO payment_attempt_first_hop_custom_records ( + htlc_attempt_index, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentAttemptFirstHopCustomRecordParams struct { + HtlcAttemptIndex int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg InsertPaymentAttemptFirstHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentAttemptFirstHopCustomRecord, arg.HtlcAttemptIndex, arg.Key, arg.Value) + return err +} + +const insertPaymentFirstHopCustomRecord = `-- name: InsertPaymentFirstHopCustomRecord :exec +INSERT INTO payment_first_hop_custom_records ( + payment_id, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentFirstHopCustomRecordParams struct { + PaymentID int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentFirstHopCustomRecord(ctx context.Context, arg InsertPaymentFirstHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentFirstHopCustomRecord, arg.PaymentID, arg.Key, arg.Value) + return err +} + +const insertPaymentHopCustomRecord = `-- name: InsertPaymentHopCustomRecord :exec +INSERT INTO payment_hop_custom_records ( + hop_id, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentHopCustomRecordParams struct { + HopID int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentHopCustomRecord(ctx context.Context, arg InsertPaymentHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentHopCustomRecord, arg.HopID, arg.Key, arg.Value) + return err +} + +const insertPaymentIntent = `-- name: InsertPaymentIntent :one +INSERT INTO payment_intents ( + payment_id, + intent_type, + intent_payload) +VALUES ( + $1, + $2, + $3 +) +RETURNING id +` + +type InsertPaymentIntentParams struct { + PaymentID int64 + IntentType int16 + IntentPayload []byte +} + +// Insert a payment intent for a given payment and return its ID. +func (q *Queries) InsertPaymentIntent(ctx context.Context, arg InsertPaymentIntentParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertPaymentIntent, arg.PaymentID, arg.IntentType, arg.IntentPayload) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertRouteHop = `-- name: InsertRouteHop :one +INSERT INTO payment_route_hops ( + htlc_attempt_index, + hop_index, + pub_key, + scid, + outgoing_time_lock, + amt_to_forward, + meta_data +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING id +` + +type InsertRouteHopParams struct { + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte +} + +func (q *Queries) InsertRouteHop(ctx context.Context, arg InsertRouteHopParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertRouteHop, + arg.HtlcAttemptIndex, + arg.HopIndex, + arg.PubKey, + arg.Scid, + arg.OutgoingTimeLock, + arg.AmtToForward, + arg.MetaData, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertRouteHopAmp = `-- name: InsertRouteHopAmp :exec +INSERT INTO payment_route_hop_amp ( + hop_id, + root_share, + set_id, + child_index +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertRouteHopAmpParams struct { + HopID int64 + RootShare []byte + SetID []byte + ChildIndex int32 +} + +func (q *Queries) InsertRouteHopAmp(ctx context.Context, arg InsertRouteHopAmpParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopAmp, + arg.HopID, + arg.RootShare, + arg.SetID, + arg.ChildIndex, + ) + return err +} + +const insertRouteHopBlinded = `-- name: InsertRouteHopBlinded :exec +INSERT INTO payment_route_hop_blinded ( + hop_id, + encrypted_data, + blinding_point, + blinded_path_total_amt +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertRouteHopBlindedParams struct { + HopID int64 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) InsertRouteHopBlinded(ctx context.Context, arg InsertRouteHopBlindedParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopBlinded, + arg.HopID, + arg.EncryptedData, + arg.BlindingPoint, + arg.BlindedPathTotalAmt, + ) + return err +} + +const insertRouteHopMpp = `-- name: InsertRouteHopMpp :exec +INSERT INTO payment_route_hop_mpp ( + hop_id, + payment_addr, + total_msat +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertRouteHopMppParams struct { + HopID int64 + PaymentAddr []byte + TotalMsat int64 +} + +func (q *Queries) InsertRouteHopMpp(ctx context.Context, arg InsertRouteHopMppParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopMpp, arg.HopID, arg.PaymentAddr, arg.TotalMsat) + return err +} + +const settleAttempt = `-- name: SettleAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + settle_preimage +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type SettleAttemptParams struct { + AttemptIndex int64 + ResolutionTime time.Time + ResolutionType int32 + SettlePreimage []byte +} + +func (q *Queries) SettleAttempt(ctx context.Context, arg SettleAttemptParams) error { + _, err := q.db.ExecContext(ctx, settleAttempt, + arg.AttemptIndex, + arg.ResolutionTime, + arg.ResolutionType, + arg.SettlePreimage, + ) + return err +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index fd0d3eaff5..d2e291eed9 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -21,24 +21,34 @@ type Querier interface { DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error DeleteChannels(ctx context.Context, ids []int64) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error + // Delete all failed HTLC attempts for the given payment. Resolution type 2 + // indicates a failed attempt. + DeleteFailedAttempts(ctx context.Context, paymentID int64) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) DeleteNode(ctx context.Context, id int64) error DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error + DeletePayment(ctx context.Context, id int64) error DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error) DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) + FailAttempt(ctx context.Context, arg FailAttemptParams) error + FailPayment(ctx context.Context, arg FailPaymentParams) (sql.Result, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) - // Fetch all inflight attempts across all payments - FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAttempt, error) + // Fetch all inflight attempts with their payment data using pagination. + // Returns attempt data joined with payment and intent data to avoid separate queries. + FetchAllInflightAttempts(ctx context.Context, arg FetchAllInflightAttemptsParams) ([]FetchAllInflightAttemptsRow, error) FetchHopLevelCustomRecords(ctx context.Context, hopIds []int64) ([]PaymentHopCustomRecord, error) FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]FetchHopsForAttemptsRow, error) + // Batch query to fetch only HTLC resolution status for multiple payments. + // We don't need to order by payment_id and attempt_time because we will + // group the resolutions by payment_id in the background. + FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIds []int64) ([]FetchHtlcAttemptResolutionsForPaymentsRow, error) FetchHtlcAttemptsForPayments(ctx context.Context, paymentIds []int64) ([]FetchHtlcAttemptsForPaymentsRow, error) FetchPayment(ctx context.Context, paymentIdentifier []byte) (FetchPaymentRow, error) FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentIds []int64) ([]PaymentFirstHopCustomRecord, error) - FetchPaymentsByIDs(ctx context.Context, paymentIds []int64) ([]FetchPaymentsByIDsRow, error) FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]PaymentAttemptFirstHopCustomRecord, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) @@ -112,6 +122,7 @@ type Querier interface { // UpsertEdgePolicy query is used because of the constraint in that query that // requires a policy update to have a newer last_update than the existing one). InsertEdgePolicyMig(ctx context.Context, arg InsertEdgePolicyMigParams) (int64, error) + InsertHtlcAttempt(ctx context.Context, arg InsertHtlcAttemptParams) (int64, error) InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error) @@ -125,6 +136,19 @@ type Querier interface { // is used because of the constraint in that query that requires a node update // to have a newer last_update than the existing node). InsertNodeMig(ctx context.Context, arg InsertNodeMigParams) (int64, error) + // Insert a new payment and return its ID. + // When creating a payment we don't have a fail reason because we start the + // payment process. + InsertPayment(ctx context.Context, arg InsertPaymentParams) (int64, error) + InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg InsertPaymentAttemptFirstHopCustomRecordParams) error + InsertPaymentFirstHopCustomRecord(ctx context.Context, arg InsertPaymentFirstHopCustomRecordParams) error + InsertPaymentHopCustomRecord(ctx context.Context, arg InsertPaymentHopCustomRecordParams) error + // Insert a payment intent for a given payment and return its ID. + InsertPaymentIntent(ctx context.Context, arg InsertPaymentIntentParams) (int64, error) + InsertRouteHop(ctx context.Context, arg InsertRouteHopParams) (int64, error) + InsertRouteHopAmp(ctx context.Context, arg InsertRouteHopAmpParams) error + InsertRouteHopBlinded(ctx context.Context, arg InsertRouteHopBlindedParams) error + InsertRouteHopMpp(ctx context.Context, arg InsertRouteHopMppParams) error IsClosedChannel(ctx context.Context, scid []byte) (bool, error) IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error) IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error) @@ -144,6 +168,7 @@ type Querier interface { OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error SetMigration(ctx context.Context, arg SetMigrationParams) error + SettleAttempt(ctx context.Context, arg SettleAttemptParams) error UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error) UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error) diff --git a/sqldb/sqlc/queries/payments.sql b/sqldb/sqlc/queries/payments.sql index ce43a3e297..b826e7e3fc 100644 --- a/sqldb/sqlc/queries/payments.sql +++ b/sqldb/sqlc/queries/payments.sql @@ -9,7 +9,7 @@ SELECT i.intent_type AS "intent_type", i.intent_payload AS "intent_payload" FROM payments p -LEFT JOIN payment_intents i ON i.id = p.intent_id +LEFT JOIN payment_intents i ON i.payment_id = p.id WHERE ( p.id > sqlc.narg('index_offset_get') OR sqlc.narg('index_offset_get') IS NULL @@ -37,18 +37,9 @@ SELECT i.intent_type AS "intent_type", i.intent_payload AS "intent_payload" FROM payments p -LEFT JOIN payment_intents i ON i.id = p.intent_id +LEFT JOIN payment_intents i ON i.payment_id = p.id WHERE p.payment_identifier = $1; --- name: FetchPaymentsByIDs :many -SELECT - sqlc.embed(p), - i.intent_type AS "intent_type", - i.intent_payload AS "intent_payload" -FROM payments p -LEFT JOIN payment_intents i ON i.id = p.intent_id -WHERE p.id IN (sqlc.slice('payment_ids')/*SLICE:payment_ids*/); - -- name: CountPayments :one SELECT COUNT(*) FROM payments; @@ -75,8 +66,20 @@ LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_i WHERE ha.payment_id IN (sqlc.slice('payment_ids')/*SLICE:payment_ids*/) ORDER BY ha.payment_id ASC, ha.attempt_time ASC; +-- name: FetchHtlcAttemptResolutionsForPayments :many +-- Batch query to fetch only HTLC resolution status for multiple payments. +-- We don't need to order by payment_id and attempt_time because we will +-- group the resolutions by payment_id in the background. +SELECT + ha.payment_id, + hr.resolution_type +FROM payment_htlc_attempts ha +LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_index +WHERE ha.payment_id IN (sqlc.slice('payment_ids')/*SLICE:payment_ids*/); + -- name: FetchAllInflightAttempts :many --- Fetch all inflight attempts across all payments +-- Fetch all inflight attempts with their payment data using pagination. +-- Returns attempt data joined with payment and intent data to avoid separate queries. SELECT ha.id, ha.attempt_index, @@ -87,13 +90,23 @@ SELECT ha.first_hop_amount_msat, ha.route_total_time_lock, ha.route_total_amount, - ha.route_source_key + ha.route_source_key, + p.amount_msat, + p.created_at, + p.payment_identifier, + p.fail_reason, + pi.intent_type, + pi.intent_payload FROM payment_htlc_attempts ha +INNER JOIN payments p ON p.id = ha.payment_id +LEFT JOIN payment_intents pi ON pi.payment_id = p.id WHERE NOT EXISTS ( SELECT 1 FROM payment_htlc_attempt_resolutions hr WHERE hr.attempt_index = ha.attempt_index ) -ORDER BY ha.attempt_index ASC; +AND ha.attempt_index > $1 +ORDER BY ha.attempt_index ASC +LIMIT $2; -- name: FetchHopsForAttempts :many SELECT @@ -151,3 +164,198 @@ FROM payment_hop_custom_records l WHERE l.hop_id IN (sqlc.slice('hop_ids')/*SLICE:hop_ids*/) ORDER BY l.hop_id ASC, l.key ASC; + +-- name: DeletePayment :exec +DELETE FROM payments WHERE id = $1; + +-- name: DeleteFailedAttempts :exec +-- Delete all failed HTLC attempts for the given payment. Resolution type 2 +-- indicates a failed attempt. +DELETE FROM payment_htlc_attempts WHERE payment_id = $1 AND attempt_index IN ( + SELECT attempt_index FROM payment_htlc_attempt_resolutions WHERE resolution_type = 2 +); + +-- name: InsertPaymentIntent :one +-- Insert a payment intent for a given payment and return its ID. +INSERT INTO payment_intents ( + payment_id, + intent_type, + intent_payload) +VALUES ( + @payment_id, + @intent_type, + @intent_payload +) +RETURNING id; + +-- name: InsertPayment :one +-- Insert a new payment and return its ID. +-- When creating a payment we don't have a fail reason because we start the +-- payment process. +INSERT INTO payments ( + amount_msat, + created_at, + payment_identifier, + fail_reason) +VALUES ( + @amount_msat, + @created_at, + @payment_identifier, + NULL +) +RETURNING id; + +-- name: InsertPaymentFirstHopCustomRecord :exec +INSERT INTO payment_first_hop_custom_records ( + payment_id, + key, + value +) +VALUES ( + @payment_id, + @key, + @value +); + +-- name: InsertHtlcAttempt :one +INSERT INTO payment_htlc_attempts ( + payment_id, + attempt_index, + session_key, + attempt_time, + payment_hash, + first_hop_amount_msat, + route_total_time_lock, + route_total_amount, + route_source_key) +VALUES ( + @payment_id, + @attempt_index, + @session_key, + @attempt_time, + @payment_hash, + @first_hop_amount_msat, + @route_total_time_lock, + @route_total_amount, + @route_source_key) +RETURNING id; + +-- name: InsertPaymentAttemptFirstHopCustomRecord :exec +INSERT INTO payment_attempt_first_hop_custom_records ( + htlc_attempt_index, + key, + value +) +VALUES ( + @htlc_attempt_index, + @key, + @value +); + +-- name: InsertRouteHop :one +INSERT INTO payment_route_hops ( + htlc_attempt_index, + hop_index, + pub_key, + scid, + outgoing_time_lock, + amt_to_forward, + meta_data +) +VALUES ( + @htlc_attempt_index, + @hop_index, + @pub_key, + @scid, + @outgoing_time_lock, + @amt_to_forward, + @meta_data +) +RETURNING id; + +-- name: InsertRouteHopMpp :exec +INSERT INTO payment_route_hop_mpp ( + hop_id, + payment_addr, + total_msat +) +VALUES ( + @hop_id, + @payment_addr, + @total_msat +); + +-- name: InsertRouteHopAmp :exec +INSERT INTO payment_route_hop_amp ( + hop_id, + root_share, + set_id, + child_index +) +VALUES ( + @hop_id, + @root_share, + @set_id, + @child_index +); + +-- name: InsertRouteHopBlinded :exec +INSERT INTO payment_route_hop_blinded ( + hop_id, + encrypted_data, + blinding_point, + blinded_path_total_amt +) +VALUES ( + @hop_id, + @encrypted_data, + @blinding_point, + @blinded_path_total_amt +); + +-- name: InsertPaymentHopCustomRecord :exec +INSERT INTO payment_hop_custom_records ( + hop_id, + key, + value +) +VALUES ( + @hop_id, + @key, + @value +); + +-- name: SettleAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + settle_preimage +) +VALUES ( + @attempt_index, + @resolution_time, + @resolution_type, + @settle_preimage +); + +-- name: FailAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + failure_source_index, + htlc_fail_reason, + failure_msg +) +VALUES ( + @attempt_index, + @resolution_time, + @resolution_type, + @failure_source_index, + @htlc_fail_reason, + @failure_msg +); + +-- name: FailPayment :execresult +UPDATE payments SET fail_reason = $1 WHERE payment_identifier = $2;