From 923d48f9db0ca87bed40c3a145875d5ffe4303a3 Mon Sep 17 00:00:00 2001 From: f321x Date: Fri, 28 Nov 2025 16:22:22 +0100 Subject: [PATCH 1/2] lnworker: differentiate PaymentInfo by direction Allows storing two different payment info of the same payment hash by including the direction into the db key. We create and store PaymentInfo for sending attempts and for requests (receiving), if we try to pay ourself (e.g. through a channel rebalance) the checks in `save_payment_info` would prevent this and throw an exception. By storing the PaymentInfos of outgoing and incoming payments separately in the db this collision is avoided and it makes it easier to reason about which PaymentInfo belongs where. --- electrum/commands.py | 21 ++++----- electrum/gui/qml/qerequestdetails.py | 4 +- electrum/gui/qt/send_tab.py | 3 +- electrum/lnpeer.py | 4 +- electrum/lnutil.py | 1 + electrum/lnworker.py | 70 ++++++++++++++++------------ electrum/plugins/nwc/nwcserver.py | 10 ++-- electrum/submarine_swaps.py | 14 ++++-- electrum/wallet.py | 8 ++-- electrum/wallet_db.py | 21 ++++++++- tests/test_commands.py | 5 +- tests/test_lnpeer.py | 52 ++++++++++----------- 12 files changed, 125 insertions(+), 88 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 6aef50cecda9..6482b0b3dbac 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -70,7 +70,7 @@ ) from .address_synchronizer import TX_HEIGHT_LOCAL from .mnemonic import Mnemonic -from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, MIN_FINAL_CLTV_DELTA_ACCEPTED, +from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, RECEIVED, MIN_FINAL_CLTV_DELTA_ACCEPTED, PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE) from .plugin import run_hook, DeviceMgr, Plugins from .version import ELECTRUM_VERSION @@ -1402,7 +1402,7 @@ async def add_hold_invoice( arg:int:min_final_cltv_expiry_delta:Optional min final cltv expiry delta (default: 294 blocks) """ assert len(payment_hash) == 64, f"Invalid payment hash length: {len(payment_hash)} != 64" - assert payment_hash not in wallet.lnworker.payment_info, "Payment hash already used!" + assert not wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), "Payment hash already used!" assert payment_hash not in wallet.lnworker.dont_expire_htlcs, "Payment hash already used!" assert wallet.lnworker.get_preimage(bfh(payment_hash)) is None, "Already got a preimage for this payment hash!" assert MIN_FINAL_CLTV_DELTA_ACCEPTED < min_final_cltv_expiry_delta < 576, "Use a sane min_final_cltv_expiry_delta value" @@ -1417,7 +1417,7 @@ async def add_hold_invoice( min_final_cltv_delta=min_final_cltv_expiry_delta, exp_delay=expiry, ) - info = wallet.lnworker.get_payment_info(bfh(payment_hash)) + info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) lnaddr, invoice = wallet.lnworker.get_bolt11_invoice( payment_info=info, message=memo, @@ -1443,12 +1443,11 @@ async def settle_hold_invoice(self, preimage: str, wallet: Abstract_Wallet = Non assert len(preimage) == 64, f"Invalid payment_hash length: {len(preimage)} != 64" payment_hash: str = crypto.sha256(bfh(preimage)).hex() assert payment_hash not in wallet.lnworker._preimages, f"Invoice {payment_hash=} already settled" - assert payment_hash in wallet.lnworker.payment_info, \ - f"Couldn't find lightning invoice for {payment_hash=}" + info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) + assert info, f"Couldn't find lightning invoice for {payment_hash=}" assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"Invoice {payment_hash=} not a hold invoice?" assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \ f"MPP incomplete, cannot settle hold invoice {payment_hash} yet" - info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0) wallet.lnworker.save_preimage(bfh(payment_hash), bfh(preimage)) util.trigger_callback('wallet_updated', wallet) @@ -1464,13 +1463,13 @@ async def cancel_hold_invoice(self, payment_hash: str, wallet: Abstract_Wallet = arg:str:payment_hash:Payment hash in hex of the hold invoice """ - assert payment_hash in wallet.lnworker.payment_info, \ + assert wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), \ f"Couldn't find lightning invoice for payment hash {payment_hash}" assert payment_hash not in wallet.lnworker._preimages, "Cannot cancel anymore, preimage already given." assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"{payment_hash=} not a hold invoice?" # set to PR_UNPAID so it can get deleted - wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID) - wallet.lnworker.delete_payment_info(payment_hash) + wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID, direction=RECEIVED) + wallet.lnworker.delete_payment_info(payment_hash, direction=RECEIVED) wallet.set_label(payment_hash, None) del wallet.lnworker.dont_expire_htlcs[payment_hash] while wallet.lnworker.is_complete_mpp(bfh(payment_hash)): @@ -1496,7 +1495,7 @@ async def check_hold_invoice(self, payment_hash: str, wallet: Abstract_Wallet = arg:str:payment_hash:Payment hash in hex of the hold invoice """ assert len(payment_hash) == 64, f"Invalid payment_hash length: {len(payment_hash)} != 64" - info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) + info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) is_complete_mpp: bool = wallet.lnworker.is_complete_mpp(bfh(payment_hash)) amount_sat = (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) // 1000 result = { @@ -1518,7 +1517,7 @@ async def check_hold_invoice(self, payment_hash: str, wallet: Abstract_Wallet = elif wallet.lnworker.get_preimage_hex(payment_hash) is not None: result["status"] = "settled" plist = wallet.lnworker.get_payments(status='settled')[bfh(payment_hash)] - _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(info, plist) + _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(None, plist) result["received_amount_sat"] = amount_msat // 1000 result['preimage'] = wallet.lnworker.get_preimage_hex(payment_hash) if info is not None: diff --git a/electrum/gui/qml/qerequestdetails.py b/electrum/gui/qml/qerequestdetails.py index d0d5f29f4f2e..288ef167c827 100644 --- a/electrum/gui/qml/qerequestdetails.py +++ b/electrum/gui/qml/qerequestdetails.py @@ -8,7 +8,7 @@ from electrum.invoices import ( PR_UNPAID, PR_EXPIRED, PR_UNKNOWN, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED, LN_EXPIRY_NEVER ) -from electrum.lnutil import MIN_FUNDING_SAT +from electrum.lnutil import MIN_FUNDING_SAT, RECEIVED from electrum.lnurl import LNURL3Data, request_lnurl_withdraw_callback, LNURLError from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierType from electrum.i18n import _ @@ -237,7 +237,7 @@ def lnurlRequestWithdrawal(self, amount_sat: int) -> None: address=None, ) req = self._wallet.wallet.get_request(key) - info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash) + info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) _lnaddr, b11_invoice = self._wallet.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=req.get_message(), diff --git a/electrum/gui/qt/send_tab.py b/electrum/gui/qt/send_tab.py index 6c927b11dca1..2cbcdec9cba3 100644 --- a/electrum/gui/qt/send_tab.py +++ b/electrum/gui/qt/send_tab.py @@ -18,6 +18,7 @@ NotEnoughFunds, NoDynamicFeeEstimates, parse_max_spend, UserCancelled, ChoiceItem, UserFacingException, ) +from electrum.lnutil import RECEIVED from electrum.invoices import PR_PAID, Invoice, PR_BROADCASTING, PR_BROADCAST from electrum.transaction import Transaction, PartialTxInput, PartialTxOutput from electrum.network import TxBroadcastError, BestEffortRequestFailed @@ -979,7 +980,7 @@ def request_lnurl_withdraw_dialog(self, lnurl_data: LNURL3Data): address=None, ) req = self.wallet.get_request(key) - info = self.wallet.lnworker.get_payment_info(req.payment_hash) + info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) _lnaddr, b11_invoice = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=req.get_message(), diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 1e8da558b468..591ce47ce383 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2221,7 +2221,7 @@ def _check_unfulfilled_htlc( outer_onion_payment_secret=payment_secret_from_onion, ) - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) if info is None: _log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}") raise exc_incorrect_or_unknown_pd @@ -3115,7 +3115,7 @@ def _check_unfulfilled_htlc_set( return None, None, fwd_cb # -- from here on it's assumed this set is a payment for us (not something to forward) -- - payment_info = self.lnworker.get_payment_info(payment_hash) + payment_info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) if payment_info is None: _log_fail_reason(f"payment info has been deleted") return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 0d30d3718e3b..70256a74d343 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1089,6 +1089,7 @@ def __neg__(self) -> 'HTLCOwner': return HTLCOwner(super().__neg__()) +# part of lightning_payments db keys class Direction(IntEnum): SENT = -1 # in the context of HTLCs: "offered" HTLCs RECEIVED = 1 # in the context of HTLCs: "received" HTLCs diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3074ababa35..f96640cb1a23 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -45,7 +45,8 @@ FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING, FEE_LN_MINIMUM_ETA_TARGET ) -from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice +from .invoices import (Invoice, Request, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, + BaseInvoice) from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress from .bip32 import BIP32Node from .address_synchronizer import TX_HEIGHT_LOCAL @@ -120,7 +121,7 @@ class PaymentInfo: """Information required to handle incoming htlcs for a payment request""" payment_hash: bytes amount_msat: Optional[int] - direction: int + direction: lnutil.Direction status: int min_final_cltv_delta: int expiry_delay: int @@ -142,6 +143,13 @@ def validate(self): def __post_init__(self): self.validate() + @property + def db_key(self) -> str: + return self.calc_db_key(payment_hash_hex=self.payment_hash.hex(), direction=self.direction) + + @classmethod + def calc_db_key(cls, *, payment_hash_hex: str, direction: lnutil.Direction) -> str: + return f"{payment_hash_hex}:{int(direction)}" SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -887,8 +895,8 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv): LNWorker.__init__(self, self.node_keypair, features, config=self.config) self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None - # lightning_payments: RHASH -> amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts - self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]] + # lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts + self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]] self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse @@ -1104,7 +1112,7 @@ def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]: return out def get_payment_value( - self, info: Optional['PaymentInfo'], + self, sent_info: Optional['PaymentInfo'], plist: List[HTLCWithStatus] ) -> Tuple[PaymentDirection, int, Optional[int], int]: """ fee_msat is included in amount_msat""" @@ -1112,7 +1120,7 @@ def get_payment_value( amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist) if all(x.direction == SENT for x in plist): direction = PaymentDirection.SENT - fee_msat = (- info.amount_msat - amount_msat) if info else None + fee_msat = (- sent_info.amount_msat - amount_msat) if sent_info else None elif all(x.direction == RECEIVED for x in plist): direction = PaymentDirection.RECEIVED fee_msat = None @@ -1135,12 +1143,12 @@ def get_lightning_history(self) -> Dict[str, LightningHistoryItem]: if len(plist) == 0: continue key = payment_hash.hex() - info = self.get_payment_info(payment_hash) + sent_info = self.get_payment_info(payment_hash, direction=SENT) # note: just after successfully paying an invoice using MPP, amount and fee values might be shifted # temporarily: the amount only considers 'settled' htlcs (see plist above), but we might also # have some inflight htlcs still. Until all relevant htlcs settle, the amount will be lower than # expected and the fee higher (the inflight htlcs will be effectively counted as fees). - direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist) + direction, amount_msat, fee_msat, timestamp = self.get_payment_value(sent_info, plist) label = self.wallet.get_label_for_rhash(key) if not label and direction == PaymentDirection.FORWARDING: label = _('Forwarding') @@ -1597,7 +1605,7 @@ async def pay_invoice( invoice_features = lnaddr.get_features() r_tags = lnaddr.get_routing_info('r') amount_to_pay = lnaddr.get_amount_msat() - status = self.get_payment_status(payment_hash) + status = self.get_payment_status(payment_hash, direction=SENT) if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) if status == PR_INFLIGHT: @@ -2491,13 +2499,13 @@ def get_preimage_hex(self, payment_hash: str) -> Optional[str]: preimage_bytes = self.get_preimage(bytes.fromhex(payment_hash)) or b"" return preimage_bytes.hex() or None - def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]: + def get_payment_info(self, payment_hash: bytes, *, direction: lnutil.Direction) -> Optional[PaymentInfo]: """returns None if payment_hash is a payment we are forwarding""" - key = payment_hash.hex() + key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash.hex(), direction=direction) with self.lock: if key in self.payment_info: stored_tuple = self.payment_info[key] - amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple + amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple return PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, @@ -2542,7 +2550,7 @@ def unregister_hold_invoice(self, payment_hash: bytes): def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: assert info.status in SAVED_PR_STATUS with self.lock: - if old_info := self.get_payment_info(payment_hash=info.payment_hash): + if old_info := self.get_payment_info(payment_hash=info.payment_hash, direction=info.direction): if info == old_info: return # already saved if info.direction == SENT: @@ -2551,8 +2559,8 @@ def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception(f"payment_hash already in use: {info=} != {old_info=}") - key = info.payment_hash.hex() - self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0 + v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts + self.payment_info[info.db_key] = v if write_to_disk: self.wallet.save_db() @@ -2698,13 +2706,15 @@ def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: self.active_forwardings.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None) - def get_payment_status(self, payment_hash: bytes) -> int: - info = self.get_payment_info(payment_hash) + def get_payment_status(self, payment_hash: bytes, *, direction: lnutil.Direction) -> int: + info = self.get_payment_info(payment_hash, direction=direction) return info.status if info else PR_UNPAID def get_invoice_status(self, invoice: BaseInvoice) -> int: invoice_id = invoice.rhash - status = self.get_payment_status(bfh(invoice_id)) + assert isinstance(invoice, (Request, Invoice)), type(invoice) + direction = RECEIVED if isinstance(invoice, Request) else SENT + status = self.get_payment_status(bfh(invoice_id), direction=direction) if status == PR_UNPAID and invoice_id in self.inflight_payments: return PR_INFLIGHT # status may be PR_FAILED @@ -2718,24 +2728,24 @@ def set_invoice_status(self, key: str, status: int) -> None: elif key in self.inflight_payments: self.inflight_payments.remove(key) if status in SAVED_PR_STATUS: - self.set_payment_status(bfh(key), status) + self.set_payment_status(bfh(key), status, direction=SENT) util.trigger_callback('invoice_status', self.wallet, key, status) self.logger.info(f"set_invoice_status {key}: {status}") # liquidity changed self.clear_invoices_cache() def set_request_status(self, payment_hash: bytes, status: int) -> None: - if self.get_payment_status(payment_hash) == status: + if self.get_payment_status(payment_hash, direction=RECEIVED) == status: return - self.set_payment_status(payment_hash, status) + self.set_payment_status(payment_hash, status, direction=RECEIVED) request_id = payment_hash.hex() req = self.wallet.get_request(request_id) if req is None: return util.trigger_callback('request_status', self.wallet, request_id, status) - def set_payment_status(self, payment_hash: bytes, status: int) -> None: - info = self.get_payment_info(payment_hash) + def set_payment_status(self, payment_hash: bytes, status: int, *, direction: lnutil.Direction) -> None: + info = self.get_payment_info(payment_hash, direction=direction) if info is None: # if we are forwarding return @@ -2930,14 +2940,15 @@ def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=No cltv_delta)])) return routing_hints - def delete_payment_info(self, payment_hash_hex: str): + def delete_payment_info(self, payment_hash_hex: str, *, direction: lnutil.Direction): # This method is called when an invoice or request is deleted by the user. # The GUI only lets the user delete invoices or requests that have not been paid. # Once an invoice/request has been paid, it is part of the history, # and get_lightning_history assumes that payment_info is there. - assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID + assert self.get_payment_status(bytes.fromhex(payment_hash_hex), direction=direction) != PR_PAID with self.lock: - self.payment_info.pop(payment_hash_hex, None) + key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash_hex, direction=direction) + self.payment_info.pop(key, None) def get_balance(self, *, frozen=False) -> Decimal: with self.lock: @@ -3185,7 +3196,7 @@ async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msa amount_msat=amount_msat, exp_delay=3600, ) - info = self.get_payment_info(payment_hash) + info = self.get_payment_info(payment_hash, direction=RECEIVED) lnaddr, invoice = self.get_bolt11_invoice( payment_info=info, message='rebalance', @@ -3804,14 +3815,13 @@ def _maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created(self, pa - Alice sends htlc A->B->C, for 1 sat, with HASH1 - Bob must not release the preimage of HASH1 """ - payment_info = self.get_payment_info(payment_hash) - is_our_payreq = payment_info and payment_info.direction == RECEIVED + payment_info = self.get_payment_info(payment_hash, direction=RECEIVED) # note: If we don't have the preimage for a payment request, then it must be a hold invoice. # Hold invoices are created by other parties (e.g. a counterparty initiating a submarine swap), # and it is the other party choosing the payment_hash. If we failed HTLCs with payment_hashes colliding # with hold invoices, then a party that can make us save a hold invoice for an arbitrary hash could # also make us fail arbitrary HTLCs. - return bool(is_our_payreq and self.get_preimage(payment_hash)) + return bool(payment_info and self.get_preimage(payment_hash)) def create_onion_for_route( self, *, diff --git a/electrum/plugins/nwc/nwcserver.py b/electrum/plugins/nwc/nwcserver.py index c2cd925f59ce..1ed021ff8f27 100644 --- a/electrum/plugins/nwc/nwcserver.py +++ b/electrum/plugins/nwc/nwcserver.py @@ -42,6 +42,7 @@ get_running_loop from electrum.invoices import Invoice, Request, PR_UNKNOWN, PR_PAID, BaseInvoice, PR_INFLIGHT from electrum import constants +from electrum.lnutil import RECEIVED if TYPE_CHECKING: from aiohttp_socks import ProxyConnector @@ -480,7 +481,7 @@ async def handle_make_invoice(self, request_event: nEvent, params: dict): address=None ) req: Request = self.wallet.get_request(key) - info = self.wallet.lnworker.get_payment_info(req.payment_hash) + info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) try: lnaddr, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, @@ -537,7 +538,7 @@ async def handle_lookup_invoice(self, request_event: nEvent, params: dict): b11 = invoice.lightning_invoice elif self.wallet.get_request(invoice.rhash): direction = "incoming" - info = self.wallet.lnworker.get_payment_info(invoice.payment_hash) + info = self.wallet.lnworker.get_payment_info(invoice.payment_hash, direction=RECEIVED) _, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=invoice.message, @@ -747,7 +748,7 @@ def on_event_request_status(self, wallet, key, status): request: Optional[Request] = self.wallet.get_request(key) if not request or not request.is_lightning() or not status == PR_PAID: return - info = self.wallet.lnworker.get_payment_info(request.payment_hash) + info = self.wallet.lnworker.get_payment_info(request.payment_hash, direction=RECEIVED) _, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=request.message, @@ -947,7 +948,8 @@ def get_payment_info(self, payment_hash: str) \ payments = self.wallet.lnworker.get_payments(status='settled') plist = payments.get(payment_hash) if plist: - info = self.wallet.lnworker.get_payment_info(payment_hash) + direction = plist[0].direction + info = self.wallet.lnworker.get_payment_info(payment_hash, direction=direction) if info: dir, amount, fee, ts = self.wallet.lnworker.get_payment_value(info, plist) fee = abs(fee) if fee else None diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 5d6275f8e24b..8ea53df598a3 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -401,7 +401,7 @@ def _fail_swap(self, swap: SwapData, reason: str): if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks: # unregister_hold_invoice will fail pending htlcs if there is no preimage available self.lnworker.unregister_hold_invoice(swap.payment_hash) - self.lnworker.delete_payment_info(swap.payment_hash.hex()) + self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.RECEIVED) self.lnworker.clear_invoices_cache() self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_funded(): @@ -413,9 +413,13 @@ def _fail_swap(self, swap: SwapData, reason: str): self._swaps_by_lockup_address.pop(swap.lockup_address, None) if swap.prepay_hash is not None: self._prepayments.pop(swap.prepay_hash, None) - if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID: - self.lnworker.delete_payment_info(swap.prepay_hash.hex()) + if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.RECEIVED) != PR_PAID: + self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.RECEIVED) self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) + if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.SENT) != PR_PAID: + self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.SENT) + if self.lnworker.get_payment_status(swap.payment_hash, direction=lnutil.SENT) != PR_PAID: + self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.SENT) @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -693,7 +697,7 @@ def add_normal_swap( min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=lnutil.RECEIVED) lnaddr1, invoice = self.lnworker.get_bolt11_invoice( payment_info=info, message='Submarine swap', @@ -712,7 +716,7 @@ def add_normal_swap( min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) - info = self.lnworker.get_payment_info(prepay_hash) + info = self.lnworker.get_payment_info(prepay_hash, direction=lnutil.RECEIVED) lnaddr2, prepay_invoice = self.lnworker.get_bolt11_invoice( payment_info=info, message='Submarine swap prepayment', diff --git a/electrum/wallet.py b/electrum/wallet.py index 32f60119b188..88002e96af76 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -76,7 +76,7 @@ from .contacts import Contacts from .mnemonic import Mnemonic from .lnworker import LNWallet -from .lnutil import MIN_FUNDING_SAT +from .lnutil import MIN_FUNDING_SAT, RECEIVED, SENT from .lntransport import extract_nodeid from .descriptor import Descriptor from .txbatcher import TxBatcher @@ -3014,7 +3014,7 @@ def get_bolt11_invoice(self, req: Request) -> str: return '' amount_msat = req.get_amount_msat() or None assert (amount_msat is None or amount_msat > 0), amount_msat - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) assert info.amount_msat == amount_msat, f"{info.amount_msat=} != {amount_msat=}" lnaddr, invoice = self.lnworker.get_bolt11_invoice( payment_info=info, @@ -3074,7 +3074,7 @@ def delete_request(self, request_id, *, write_to_disk: bool = True): if addr := req.get_address(): self._requests_addr_to_key[addr].discard(request_id) if req.is_lightning() and self.lnworker: - self.lnworker.delete_payment_info(req.rhash) + self.lnworker.delete_payment_info(req.rhash, direction=RECEIVED) if write_to_disk: self.save_db() @@ -3084,7 +3084,7 @@ def delete_invoice(self, invoice_id, *, write_to_disk: bool = True): if inv is None: return if inv.is_lightning() and self.lnworker: - self.lnworker.delete_payment_info(inv.rhash) + self.lnworker.delete_payment_info(inv.rhash, direction=SENT) if write_to_disk: self.save_db() diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index bf22f03fbd21..325eb36f3def 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ def __init__(self, wallet_db: 'WalletDB'): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 64 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -235,6 +235,7 @@ def upgrade(self): self._convert_version_61() self._convert_version_62() self._convert_version_63() + self._convert_version_64() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1269,6 +1270,24 @@ def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Op self.data['seed_version'] = 63 + def _convert_version_64(self): + """Key payment_info by "rhash:direction" instead of just rhash to allow storing a PaymentInfo + for each direction""" + if not self._is_upgrade_method_needed(63, 63): + return + + new_payment_infos = {} + old_payment_infos = self.data.get('lightning_payments', {}) + for payment_hash, old_values in old_payment_infos.items(): + amount_msat, direction, status, min_final_cltv_expiry, expiry, creation_ts = old_values + # drop direction + new_values = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts) + new_key = f"{payment_hash}:{direction}" + new_payment_infos[new_key] = new_values # save new entry + + self.data['lightning_payments'] = new_payment_infos + self.data['seed_version'] = 64 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index 3016263e1b31..969bfb02433c 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,6 +11,7 @@ import electrum from electrum.commands import Commands, eval_bool from electrum import storage, wallet +from electrum.lnutil import RECEIVED from electrum.lnworker import RecvMPPResolution from electrum.wallet import Abstract_Wallet from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED @@ -509,7 +510,7 @@ async def test_hold_invoice_commands(self, mock_save_db): ) invoice = lndecode(invoice=result['invoice']) assert invoice.paymenthash.hex() == payment_hash - assert payment_hash in wallet.lnworker.payment_info + assert wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED) assert payment_hash in wallet.lnworker.dont_expire_htlcs assert invoice.get_amount_sat() == 10000 assert invoice.get_description() == "test" @@ -520,7 +521,7 @@ async def test_hold_invoice_commands(self, mock_save_db): payment_hash=payment_hash, wallet=wallet, ) - assert payment_hash not in wallet.lnworker.payment_info + assert not wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED) assert payment_hash not in wallet.lnworker.dont_expire_htlcs assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == "" assert cancel_result['cancelled'] == payment_hash diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 3fe6a3729d2a..57def01cb495 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -865,10 +865,10 @@ async def _test_simple_payment( p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) results = {} async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await w1.pay_invoice(pay_req) if result is True: - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) results[lnaddr] = PaymentDone() else: results[lnaddr] = PaymentFailure() @@ -988,7 +988,7 @@ async def run_test(test_trampoline): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]] b11 = lnencode(lnaddr, w2.node_keypair.privkey) @@ -1079,7 +1079,7 @@ async def try_pay_invoice_twice(pay_req: Invoice, w1=w1): result, log = await w1.pay_invoice(pay_req) assert result is True # now pay the same invoice again, the payment should be rejected by w2 - w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID) + w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID, direction=lnutil.SENT) result, log = await w1.pay_invoice(pay_req) if not result: # w1.pay_invoice returned a payment failure as the payment got rejected by w2 @@ -1224,8 +1224,8 @@ async def test_payment_recv_mpp_confusion1(self): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED)) route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route p1.pay( @@ -1297,7 +1297,7 @@ async def test_payment_recv_mpp_confusion2(self): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route p1.pay( @@ -1997,11 +1997,11 @@ async def run_test(test_trampoline, test_failure): w2.dont_settle_htlcs[pay_req.rhash] = None async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) if result is True: self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs) - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) return PaymentDone() else: self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) @@ -2067,10 +2067,10 @@ async def run_test(test_trampoline, test_expiry): w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20 async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) if result is True: - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) return PaymentDone() else: self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) @@ -2210,12 +2210,12 @@ def mocked_split_amount_normal(total_amount: int, num_parts: int) -> List[int]: return split_amount_normal(total_amount, num_parts) async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) with mock.patch('electrum.mpp_split.split_amount_normal', side_effect=mocked_split_amount_normal): result, log = await graph.workers['bob'].pay_invoice(pay_req) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) async def f(): async with OldTaskGroup() as group: @@ -2242,10 +2242,10 @@ async def test_payment_multihop(self): graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) raise PaymentDone() async def f(): async with OldTaskGroup() as group: @@ -2309,10 +2309,10 @@ async def test_payment_multihop_temp_node_failure(self): graph.workers['carol'].network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE = True peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertFalse(result) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) raise PaymentDone() async def f(): @@ -2336,11 +2336,11 @@ async def test_payment_multihop_route_around_failure(self): async def pay(lnaddr, pay_req): self.assertEqual(500000000000, graph.channels[('alice', 'bob')].balance(LOCAL)) self.assertEqual(500000000000, graph.channels[('dave', 'bob')].balance(LOCAL)) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=2) self.assertEqual(2, len(log)) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual([graph.channels[('alice', 'carol')].short_channel_id, graph.channels[('carol', 'dave')].short_channel_id], [edge.short_channel_id for edge in log[0].route]) self.assertEqual([graph.channels[('alice', 'bob')].short_channel_id, graph.channels[('bob', 'dave')].short_channel_id], @@ -2436,11 +2436,11 @@ async def test_payment_with_temp_channel_failure_and_liquidity_hints(self): amount_to_pay = 100_000_000 peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=3) self.assertTrue(result) self.assertEqual(2, len(log)) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code) liquidity_hints = graph.workers['alice'].network.path_finder.liquidity_hints @@ -2507,14 +2507,14 @@ async def pay( assert alice_w.network.channel_db is not None lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay) self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure) - self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await alice_w.pay_invoice(pay_req, attempts=attempts) if not bob_forwarding: # reset to previous state, sleep 2s so that the second htlc can time out graph.workers['bob'].enable_htlc_forwarding = True await asyncio.sleep(2) if result: - self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) # check mpp is cleaned up async with OldTaskGroup() as g: for peer in peers: @@ -2642,7 +2642,7 @@ async def _run_trampoline_payment( dest_w = graph.workers[destination_name] async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await sender_w.pay_invoice(pay_req, attempts=attempts) async with OldTaskGroup() as g: for peer in peers: @@ -2653,7 +2653,7 @@ async def pay(lnaddr, pay_req): for peer in peers: self.assertEqual(len(peer.lnworker.active_forwardings), 0) if result: - self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) raise PaymentDone() else: raise NoPathFound() @@ -2875,7 +2875,7 @@ async def test_payment_with_malformed_onion(self): peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertEqual(OnionFailureCode.INVALID_ONION_VERSION, log[0].failure_msg.code) self.assertFalse(result, msg=log) From df612fa010d0bf1aef6309028940075d71ac5392 Mon Sep 17 00:00:00 2001 From: f321x Date: Fri, 28 Nov 2025 17:06:07 +0100 Subject: [PATCH 2/2] lnworker: allow overwriting amount of sent payment info Allows replacing a saved `PaymentInfo` of `SENT` direction if the old one is not yet paid. This allows the user to retry paying a 0 amount invoice with different amount if the previous attempt failed. --- electrum/lnworker.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index f96640cb1a23..9a37fef6fd87 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -2553,9 +2553,15 @@ def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> if old_info := self.get_payment_info(payment_hash=info.payment_hash, direction=info.direction): if info == old_info: return # already saved - if info.direction == SENT: - # allow saving of newer PaymentInfo if it is a sending attempt - old_info = dataclasses.replace(old_info, creation_ts=info.creation_ts) + if info.direction == SENT and old_info.status in (PR_UNPAID, PR_FAILED): + # allow saving of newer PaymentInfo if it is a sending attempt and the previous + # payment failed or was not yet attempted + old_info = dataclasses.replace( + old_info, + creation_ts=info.creation_ts, + status=info.status, + amount_msat=info.amount_msat, # might retrying to pay 0 amount invoice + ) if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception(f"payment_hash already in use: {info=} != {old_info=}")