Skip to content

Commit

Permalink
implement account awaiter
Browse files Browse the repository at this point in the history
  • Loading branch information
popenta committed Nov 18, 2024
1 parent 3037984 commit d976be2
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 14 deletions.
96 changes: 96 additions & 0 deletions multiversx_sdk/network_providers/account_awaiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import logging
import time
from typing import Callable, Optional, Protocol, Union

from multiversx_sdk.core.address import Address
from multiversx_sdk.network_providers.constants import (
DEFAULT_ACCOUNT_AWAITING_PATIENCE_IN_MILLISECONDS,
DEFAULT_ACCOUNT_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS,
DEFAULT_ACCOUNT_AWAITING_TIMEOUT_IN_MILLISECONDS)
from multiversx_sdk.network_providers.errors import \
ExpectedAccountConditionNotReachedError
from multiversx_sdk.network_providers.resources import AccountOnNetwork

ONE_SECOND_IN_MILLISECONDS = 1000

logger = logging.getLogger("account_awaiter")


class IAccountFetcher(Protocol):
def get_account(self, address: Address) -> AccountOnNetwork:
...


class AccountAwaiter:
"""AccountAwaiter allows one to await until a specific event occurs on a given address."""

def __init__(self,
fetcher: IAccountFetcher,
polling_interval_in_milliseconds: Optional[int] = None,
timeout_interval_in_milliseconds: Optional[int] = None,
patience_time_in_milliseconds: Optional[int] = None) -> None:
"""
Args:
fetcher (IAccountFetcher): Used to fetch the account of the network.
polling_interval_in_milliseconds (Optional[int]): The polling interval, in milliseconds.
timeout_interval_in_milliseconds (Optional[int]): The timeout, in milliseconds.
patience_time_in_milliseconds (Optional[int]): The patience, an extra time (in milliseconds) to wait, after the account has reached its desired condition.
"""
self.fetcher = fetcher

if polling_interval_in_milliseconds is None:
self.polling_interval_in_milliseconds = DEFAULT_ACCOUNT_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS
else:
self.polling_interval_in_milliseconds = polling_interval_in_milliseconds

if timeout_interval_in_milliseconds is None:
self.timeout_interval_in_milliseconds = DEFAULT_ACCOUNT_AWAITING_TIMEOUT_IN_MILLISECONDS
else:
self.timeout_interval_in_milliseconds = timeout_interval_in_milliseconds

if patience_time_in_milliseconds is None:
self.patience_time_in_milliseconds = DEFAULT_ACCOUNT_AWAITING_PATIENCE_IN_MILLISECONDS
else:
self.patience_time_in_milliseconds = patience_time_in_milliseconds

def await_on_condition(self, address: Address, condition: Callable[[AccountOnNetwork], bool]) -> AccountOnNetwork:
"""Waits until the condition is satisfied."""
def do_fetch():
return self.fetcher.get_account(address)

return self._await_conditionally(
is_satisfied=condition,
do_fetch=do_fetch,
error=ExpectedAccountConditionNotReachedError()
)

def _await_conditionally(self,
is_satisfied: Callable[[AccountOnNetwork], bool],
do_fetch: Callable[[], AccountOnNetwork],
error: Exception) -> AccountOnNetwork:
is_condition_satisfied = False
fetched_data: Union[AccountOnNetwork, None] = None
max_number_of_retries = self.timeout_interval_in_milliseconds // self.polling_interval_in_milliseconds

number_of_retries = 0
while number_of_retries < max_number_of_retries:
try:
fetched_data = do_fetch()
is_condition_satisfied = is_satisfied(fetched_data)

if is_condition_satisfied:
break
except Exception as ex:
raise ex

number_of_retries += 1
time.sleep(self.polling_interval_in_milliseconds / ONE_SECOND_IN_MILLISECONDS)

if fetched_data is None or not is_condition_satisfied:
raise error

if self.patience_time_in_milliseconds:
time.sleep(self.patience_time_in_milliseconds / ONE_SECOND_IN_MILLISECONDS)
return do_fetch()

return fetched_data
72 changes: 72 additions & 0 deletions multiversx_sdk/network_providers/account_awaiter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest

from multiversx_sdk.core.address import Address
from multiversx_sdk.core.transaction import Transaction
from multiversx_sdk.core.transaction_computer import TransactionComputer
from multiversx_sdk.network_providers.account_awaiter import AccountAwaiter
from multiversx_sdk.network_providers.api_network_provider import \
ApiNetworkProvider
from multiversx_sdk.network_providers.resources import AccountOnNetwork
from multiversx_sdk.testutils.mock_network_provider import (
MockNetworkProvider, TimelinePointMarkCompleted, TimelinePointWait)
from multiversx_sdk.testutils.utils import create_account_egld_balance
from multiversx_sdk.testutils.wallets import load_wallets


class TestTransactionAwaiter:
provider = MockNetworkProvider()
watcher = AccountAwaiter(
fetcher=provider,
polling_interval_in_milliseconds=42,
timeout_interval_in_milliseconds=42 * 42,
patience_time_in_milliseconds=0
)

@pytest.mark.only
def test_await_on_balance_increase(self):
alice = Address.new_from_bech32("erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th")
initial_balance = self.provider.get_account(alice).balance

# adds 7 EGLD to the account balance
self.provider.mock_account_balance_timeline_by_address(
alice,
[TimelinePointWait(40), TimelinePointWait(40), TimelinePointWait(45), TimelinePointMarkCompleted()]
)

def condition(account: AccountOnNetwork):
return account.balance == initial_balance + create_account_egld_balance(7)

account = self.watcher.await_on_condition(alice, condition)
assert account.balance == create_account_egld_balance(1007)

@pytest.mark.networkInteraction
def test_on_network(self):
alice = load_wallets()["alice"]
alice_address = Address.new_from_bech32(alice.label)
frank = Address.new_from_bech32("erd1kdl46yctawygtwg2k462307dmz2v55c605737dp3zkxh04sct7asqylhyv")

api = ApiNetworkProvider("https://devnet-api.multiversx.com")
watcher = AccountAwaiter(fetcher=api)
tx_computer = TransactionComputer()
value = 100_000

transaction = Transaction(
sender=alice_address,
receiver=frank,
gas_limit=50000,
chain_id="D",
value=value
)
transaction.nonce = api.get_account(alice_address).nonce
transaction.signature = alice.secret_key.sign(tx_computer.compute_bytes_for_signing(transaction))

initial_balance = api.get_account(frank).balance
print("initial:", initial_balance)

def condition(account: AccountOnNetwork):
return account.balance == initial_balance + value

api.send_transaction(transaction)

account_on_network = watcher.await_on_condition(frank, condition)
assert account_on_network.balance == initial_balance + value
4 changes: 4 additions & 0 deletions multiversx_sdk/network_providers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@
DEFAULT_TRANSACTION_AWAITING_TIMEOUT_IN_MILLISECONDS = 15 * DEFAULT_TRANSACTION_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS
DEFAULT_TRANSACTION_AWAITING_PATIENCE_IN_MILLISECONDS = 3000

DEFAULT_ACCOUNT_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS = 6000
DEFAULT_ACCOUNT_AWAITING_TIMEOUT_IN_MILLISECONDS = 15 * DEFAULT_TRANSACTION_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS
DEFAULT_ACCOUNT_AWAITING_PATIENCE_IN_MILLISECONDS = 0

BASE_USER_AGENT = "multiversx-sdk-py"
UNKNOWN_CLIENT_NAME = "unknown"
6 changes: 3 additions & 3 deletions multiversx_sdk/network_providers/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ def __init__(self, url: str, data: Any):
self.data = data


class ExpectedTransactionStatusNotReached(Exception):
class ExpectedTransactionStatusNotReachedError(Exception):
def __init__(self) -> None:
super().__init__("The expected transaction status was not reached")


class IsCompletedFieldMissingOnTransaction(Exception):
class ExpectedAccountConditionNotReachedError(Exception):
def __init__(self) -> None:
super().__init__("The transaction awaiter requires the `is_completed` property to be defined on the transaction object. Perhaps you've used `ProxyNetworkProvider.get_transaction()` and in that case you should also pass `with_process_status=True`")
super().__init__("The expected account condition was not reached")


class TransactionFetchingError(Exception):
Expand Down
10 changes: 3 additions & 7 deletions multiversx_sdk/network_providers/transaction_awaiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
DEFAULT_TRANSACTION_AWAITING_POLLING_TIMEOUT_IN_MILLISECONDS,
DEFAULT_TRANSACTION_AWAITING_TIMEOUT_IN_MILLISECONDS)
from multiversx_sdk.network_providers.errors import (
ExpectedTransactionStatusNotReached, IsCompletedFieldMissingOnTransaction,
TransactionFetchingError)
ExpectedTransactionStatusNotReachedError, TransactionFetchingError)

ONE_SECOND_IN_MILLISECONDS = 1000

Expand Down Expand Up @@ -56,9 +55,6 @@ def __init__(self,
def await_completed(self, transaction_hash: Union[str, bytes]) -> TransactionOnNetwork:
"""Waits until the transaction is completely processed."""
def is_completed(tx: TransactionOnNetwork):
if tx.status.is_completed is None:
raise IsCompletedFieldMissingOnTransaction()

return tx.status.is_completed

def do_fetch():
Expand All @@ -67,7 +63,7 @@ def do_fetch():
return self._await_conditionally(
is_satisfied=is_completed,
do_fetch=do_fetch,
error=ExpectedTransactionStatusNotReached()
error=ExpectedTransactionStatusNotReachedError()
)

def await_on_condition(
Expand All @@ -80,7 +76,7 @@ def do_fetch():
return self._await_conditionally(
is_satisfied=condition,
do_fetch=do_fetch,
error=ExpectedTransactionStatusNotReached()
error=ExpectedTransactionStatusNotReachedError()
)

def _await_conditionally(self,
Expand Down
8 changes: 4 additions & 4 deletions multiversx_sdk/network_providers/transaction_awaiter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
TransactionAwaiter
from multiversx_sdk.testutils.mock_network_provider import (
MockNetworkProvider, TimelinePointMarkCompleted, TimelinePointWait)
from multiversx_sdk.testutils.mock_transaction_on_network import \
get_empty_transaction_on_network
from multiversx_sdk.testutils.wallets import load_wallets

from multiversx_sdk.testutils.mock_transaction_on_network import get_empty_transaction_on_network


class TestTransactionAwaiter:
provider = MockNetworkProvider()
Expand All @@ -33,14 +33,14 @@ def test_await_status_executed(self):

self.provider.mock_transaction_timeline_by_hash(
tx_hash,
[TimelinePointWait(40), TransactionStatus("pending"), TimelinePointWait(40), TransactionStatus("executed"), TimelinePointMarkCompleted()]
[TimelinePointWait(40), TransactionStatus("pending"), TimelinePointWait(
40), TransactionStatus("executed"), TimelinePointMarkCompleted()]
)
tx_from_network = self.watcher.await_completed(tx_hash)

assert tx_from_network.status.is_completed

@pytest.mark.networkInteraction
@pytest.mark.skip
def test_on_network(self):
alice = load_wallets()["alice"]
proxy = ProxyNetworkProvider("https://devnet-api.multiversx.com")
Expand Down
15 changes: 15 additions & 0 deletions multiversx_sdk/testutils/mock_network_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ def mark_tx_as_completed(transaction: TransactionOnNetwork):
thread = threading.Thread(target=fn)
thread.start()

def mock_account_balance_timeline_by_address(self, address: Address, timeline_points: list[Any]) -> None:
def fn():
for point in timeline_points:
if isinstance(point, TimelinePointMarkCompleted):
def mark_account_condition_reached(account: AccountOnNetwork):
account.balance = account.balance + create_account_egld_balance(7)

self.mock_update_account(address, mark_account_condition_reached)

elif isinstance(point, TimelinePointWait):
time.sleep(point.milliseconds // 1000)

thread = threading.Thread(target=fn)
thread.start()

def get_account(self, address: Address) -> AccountOnNetwork:
account = self.accounts.get(address.to_bech32(), None)

Expand Down

0 comments on commit d976be2

Please sign in to comment.