diff --git a/multiversx_sdk/network_providers/account_awaiter.py b/multiversx_sdk/network_providers/account_awaiter.py new file mode 100644 index 00000000..259808f5 --- /dev/null +++ b/multiversx_sdk/network_providers/account_awaiter.py @@ -0,0 +1,96 @@ +import logging +import time +from typing import Callable, Optional, Protocol, Union + +from multiversx_sdk.core.address import Address +from multiversx_sdk.network_providers.constants import ( + DEFAULT_ACCOUNT_AWAITING_PATIENCE_IN_MILLISECONDS, + DEFAULT_ACCOUNT_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS, + DEFAULT_ACCOUNT_AWAITING_TIMEOUT_IN_MILLISECONDS) +from multiversx_sdk.network_providers.errors import \ + ExpectedAccountConditionNotReachedError +from multiversx_sdk.network_providers.resources import AccountOnNetwork + +ONE_SECOND_IN_MILLISECONDS = 1000 + +logger = logging.getLogger("account_awaiter") + + +class IAccountFetcher(Protocol): + def get_account(self, address: Address) -> AccountOnNetwork: + ... + + +class AccountAwaiter: + """AccountAwaiter allows one to await until a specific event occurs on a given address.""" + + def __init__(self, + fetcher: IAccountFetcher, + polling_interval_in_milliseconds: Optional[int] = None, + timeout_interval_in_milliseconds: Optional[int] = None, + patience_time_in_milliseconds: Optional[int] = None) -> None: + """ + Args: + fetcher (IAccountFetcher): Used to fetch the account of the network. + polling_interval_in_milliseconds (Optional[int]): The polling interval, in milliseconds. + timeout_interval_in_milliseconds (Optional[int]): The timeout, in milliseconds. + patience_time_in_milliseconds (Optional[int]): The patience, an extra time (in milliseconds) to wait, after the account has reached its desired condition. + """ + self.fetcher = fetcher + + if polling_interval_in_milliseconds is None: + self.polling_interval_in_milliseconds = DEFAULT_ACCOUNT_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS + else: + self.polling_interval_in_milliseconds = polling_interval_in_milliseconds + + if timeout_interval_in_milliseconds is None: + self.timeout_interval_in_milliseconds = DEFAULT_ACCOUNT_AWAITING_TIMEOUT_IN_MILLISECONDS + else: + self.timeout_interval_in_milliseconds = timeout_interval_in_milliseconds + + if patience_time_in_milliseconds is None: + self.patience_time_in_milliseconds = DEFAULT_ACCOUNT_AWAITING_PATIENCE_IN_MILLISECONDS + else: + self.patience_time_in_milliseconds = patience_time_in_milliseconds + + def await_on_condition(self, address: Address, condition: Callable[[AccountOnNetwork], bool]) -> AccountOnNetwork: + """Waits until the condition is satisfied.""" + def do_fetch(): + return self.fetcher.get_account(address) + + return self._await_conditionally( + is_satisfied=condition, + do_fetch=do_fetch, + error=ExpectedAccountConditionNotReachedError() + ) + + def _await_conditionally(self, + is_satisfied: Callable[[AccountOnNetwork], bool], + do_fetch: Callable[[], AccountOnNetwork], + error: Exception) -> AccountOnNetwork: + is_condition_satisfied = False + fetched_data: Union[AccountOnNetwork, None] = None + max_number_of_retries = self.timeout_interval_in_milliseconds // self.polling_interval_in_milliseconds + + number_of_retries = 0 + while number_of_retries < max_number_of_retries: + try: + fetched_data = do_fetch() + is_condition_satisfied = is_satisfied(fetched_data) + + if is_condition_satisfied: + break + except Exception as ex: + raise ex + + number_of_retries += 1 + time.sleep(self.polling_interval_in_milliseconds / ONE_SECOND_IN_MILLISECONDS) + + if fetched_data is None or not is_condition_satisfied: + raise error + + if self.patience_time_in_milliseconds: + time.sleep(self.patience_time_in_milliseconds / ONE_SECOND_IN_MILLISECONDS) + return do_fetch() + + return fetched_data diff --git a/multiversx_sdk/network_providers/account_awaiter_test.py b/multiversx_sdk/network_providers/account_awaiter_test.py new file mode 100644 index 00000000..d0e3752d --- /dev/null +++ b/multiversx_sdk/network_providers/account_awaiter_test.py @@ -0,0 +1,72 @@ +import pytest + +from multiversx_sdk.core.address import Address +from multiversx_sdk.core.transaction import Transaction +from multiversx_sdk.core.transaction_computer import TransactionComputer +from multiversx_sdk.network_providers.account_awaiter import AccountAwaiter +from multiversx_sdk.network_providers.api_network_provider import \ + ApiNetworkProvider +from multiversx_sdk.network_providers.resources import AccountOnNetwork +from multiversx_sdk.testutils.mock_network_provider import ( + MockNetworkProvider, TimelinePointMarkCompleted, TimelinePointWait) +from multiversx_sdk.testutils.utils import create_account_egld_balance +from multiversx_sdk.testutils.wallets import load_wallets + + +class TestTransactionAwaiter: + provider = MockNetworkProvider() + watcher = AccountAwaiter( + fetcher=provider, + polling_interval_in_milliseconds=42, + timeout_interval_in_milliseconds=42 * 42, + patience_time_in_milliseconds=0 + ) + + @pytest.mark.only + def test_await_on_balance_increase(self): + alice = Address.new_from_bech32("erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th") + initial_balance = self.provider.get_account(alice).balance + + # adds 7 EGLD to the account balance + self.provider.mock_account_balance_timeline_by_address( + alice, + [TimelinePointWait(40), TimelinePointWait(40), TimelinePointWait(45), TimelinePointMarkCompleted()] + ) + + def condition(account: AccountOnNetwork): + return account.balance == initial_balance + create_account_egld_balance(7) + + account = self.watcher.await_on_condition(alice, condition) + assert account.balance == create_account_egld_balance(1007) + + @pytest.mark.networkInteraction + def test_on_network(self): + alice = load_wallets()["alice"] + alice_address = Address.new_from_bech32(alice.label) + frank = Address.new_from_bech32("erd1kdl46yctawygtwg2k462307dmz2v55c605737dp3zkxh04sct7asqylhyv") + + api = ApiNetworkProvider("https://devnet-api.multiversx.com") + watcher = AccountAwaiter(fetcher=api) + tx_computer = TransactionComputer() + value = 100_000 + + transaction = Transaction( + sender=alice_address, + receiver=frank, + gas_limit=50000, + chain_id="D", + value=value + ) + transaction.nonce = api.get_account(alice_address).nonce + transaction.signature = alice.secret_key.sign(tx_computer.compute_bytes_for_signing(transaction)) + + initial_balance = api.get_account(frank).balance + print("initial:", initial_balance) + + def condition(account: AccountOnNetwork): + return account.balance == initial_balance + value + + api.send_transaction(transaction) + + account_on_network = watcher.await_on_condition(frank, condition) + assert account_on_network.balance == initial_balance + value diff --git a/multiversx_sdk/network_providers/constants.py b/multiversx_sdk/network_providers/constants.py index c3ea0763..66cc1618 100644 --- a/multiversx_sdk/network_providers/constants.py +++ b/multiversx_sdk/network_providers/constants.py @@ -2,5 +2,9 @@ DEFAULT_TRANSACTION_AWAITING_TIMEOUT_IN_MILLISECONDS = 15 * DEFAULT_TRANSACTION_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS DEFAULT_TRANSACTION_AWAITING_PATIENCE_IN_MILLISECONDS = 3000 +DEFAULT_ACCOUNT_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS = 6000 +DEFAULT_ACCOUNT_AWAITING_TIMEOUT_IN_MILLISECONDS = 15 * DEFAULT_TRANSACTION_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS +DEFAULT_ACCOUNT_AWAITING_PATIENCE_IN_MILLISECONDS = 0 + BASE_USER_AGENT = "multiversx-sdk-py" UNKNOWN_CLIENT_NAME = "unknown" diff --git a/multiversx_sdk/network_providers/errors.py b/multiversx_sdk/network_providers/errors.py index f17af60b..d30232b3 100644 --- a/multiversx_sdk/network_providers/errors.py +++ b/multiversx_sdk/network_providers/errors.py @@ -8,14 +8,14 @@ def __init__(self, url: str, data: Any): self.data = data -class ExpectedTransactionStatusNotReached(Exception): +class ExpectedTransactionStatusNotReachedError(Exception): def __init__(self) -> None: super().__init__("The expected transaction status was not reached") -class IsCompletedFieldMissingOnTransaction(Exception): +class ExpectedAccountConditionNotReachedError(Exception): def __init__(self) -> None: - super().__init__("The transaction awaiter requires the `is_completed` property to be defined on the transaction object. Perhaps you've used `ProxyNetworkProvider.get_transaction()` and in that case you should also pass `with_process_status=True`") + super().__init__("The expected account condition was not reached") class TransactionFetchingError(Exception): diff --git a/multiversx_sdk/network_providers/transaction_awaiter.py b/multiversx_sdk/network_providers/transaction_awaiter.py index 3a83e8af..ff713603 100644 --- a/multiversx_sdk/network_providers/transaction_awaiter.py +++ b/multiversx_sdk/network_providers/transaction_awaiter.py @@ -8,8 +8,7 @@ DEFAULT_TRANSACTION_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS, DEFAULT_TRANSACTION_AWAITING_TIMEOUT_IN_MILLISECONDS) from multiversx_sdk.network_providers.errors import ( - ExpectedTransactionStatusNotReached, IsCompletedFieldMissingOnTransaction, - TransactionFetchingError) + ExpectedTransactionStatusNotReachedError, TransactionFetchingError) ONE_SECOND_IN_MILLISECONDS = 1000 @@ -56,9 +55,6 @@ def __init__(self, def await_completed(self, transaction_hash: Union[str, bytes]) -> TransactionOnNetwork: """Waits until the transaction is completely processed.""" def is_completed(tx: TransactionOnNetwork): - if tx.status.is_completed is None: - raise IsCompletedFieldMissingOnTransaction() - return tx.status.is_completed def do_fetch(): @@ -67,7 +63,7 @@ def do_fetch(): return self._await_conditionally( is_satisfied=is_completed, do_fetch=do_fetch, - error=ExpectedTransactionStatusNotReached() + error=ExpectedTransactionStatusNotReachedError() ) def await_on_condition( @@ -80,7 +76,7 @@ def do_fetch(): return self._await_conditionally( is_satisfied=condition, do_fetch=do_fetch, - error=ExpectedTransactionStatusNotReached() + error=ExpectedTransactionStatusNotReachedError() ) def _await_conditionally(self, diff --git a/multiversx_sdk/network_providers/transaction_awaiter_test.py b/multiversx_sdk/network_providers/transaction_awaiter_test.py index b87a93f5..1938b2b8 100644 --- a/multiversx_sdk/network_providers/transaction_awaiter_test.py +++ b/multiversx_sdk/network_providers/transaction_awaiter_test.py @@ -11,10 +11,10 @@ TransactionAwaiter from multiversx_sdk.testutils.mock_network_provider import ( MockNetworkProvider, TimelinePointMarkCompleted, TimelinePointWait) +from multiversx_sdk.testutils.mock_transaction_on_network import \ + get_empty_transaction_on_network from multiversx_sdk.testutils.wallets import load_wallets -from multiversx_sdk.testutils.mock_transaction_on_network import get_empty_transaction_on_network - class TestTransactionAwaiter: provider = MockNetworkProvider() @@ -33,14 +33,14 @@ def test_await_status_executed(self): self.provider.mock_transaction_timeline_by_hash( tx_hash, - [TimelinePointWait(40), TransactionStatus("pending"), TimelinePointWait(40), TransactionStatus("executed"), TimelinePointMarkCompleted()] + [TimelinePointWait(40), TransactionStatus("pending"), TimelinePointWait( + 40), TransactionStatus("executed"), TimelinePointMarkCompleted()] ) tx_from_network = self.watcher.await_completed(tx_hash) assert tx_from_network.status.is_completed @pytest.mark.networkInteraction - @pytest.mark.skip def test_on_network(self): alice = load_wallets()["alice"] proxy = ProxyNetworkProvider("https://devnet-api.multiversx.com") diff --git a/multiversx_sdk/testutils/mock_network_provider.py b/multiversx_sdk/testutils/mock_network_provider.py index f2645430..3c49baae 100644 --- a/multiversx_sdk/testutils/mock_network_provider.py +++ b/multiversx_sdk/testutils/mock_network_provider.py @@ -124,6 +124,21 @@ def mark_tx_as_completed(transaction: TransactionOnNetwork): thread = threading.Thread(target=fn) thread.start() + def mock_account_balance_timeline_by_address(self, address: Address, timeline_points: list[Any]) -> None: + def fn(): + for point in timeline_points: + if isinstance(point, TimelinePointMarkCompleted): + def mark_account_condition_reached(account: AccountOnNetwork): + account.balance = account.balance + create_account_egld_balance(7) + + self.mock_update_account(address, mark_account_condition_reached) + + elif isinstance(point, TimelinePointWait): + time.sleep(point.milliseconds // 1000) + + thread = threading.Thread(target=fn) + thread.start() + def get_account(self, address: Address) -> AccountOnNetwork: account = self.accounts.get(address.to_bech32(), None)