Skip to content

Commit

Permalink
refactor: isolate get_test_context to reduce circular imports
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx committed Aug 6, 2024
1 parent 2c39c79 commit 913428a
Show file tree
Hide file tree
Showing 27 changed files with 111 additions and 118 deletions.
3 changes: 2 additions & 1 deletion src/algopy_testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from algopy_testing import arc4, gtxn, itxn
from algopy_testing._arc4_factory import ARC4Factory
from algopy_testing._context_storage import algopy_testing_context, get_test_context
from algopy_testing._itxn_loader import ITxnGroupLoader, ITxnLoader
from algopy_testing.context import AlgopyTestContext, algopy_testing_context, get_test_context
from algopy_testing.context import AlgopyTestContext
from algopy_testing.decorators.subroutine import subroutine
from algopy_testing.enums import OnCompleteAction, TransactionType
from algopy_testing.models import (
Expand Down
43 changes: 43 additions & 0 deletions src/algopy_testing/_context_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import typing
from contextlib import contextmanager
from contextvars import ContextVar

if typing.TYPE_CHECKING:
from collections.abc import Generator

import algopy

from algopy_testing import AlgopyTestContext

_var: ContextVar[AlgopyTestContext] = ContextVar("_var")


def get_test_context() -> AlgopyTestContext:
try:
result = _var.get()
except LookupError:
raise ValueError(
"Test context is not initialized! Use `with algopy_testing_context()` to "
"access the context manager."
) from None
return result


@contextmanager
def algopy_testing_context(
*,
default_creator: algopy.Account | None = None,
) -> Generator[AlgopyTestContext, None, None]:
from algopy_testing.context import AlgopyTestContext

token = _var.set(
AlgopyTestContext(
default_creator=default_creator,
)
)
try:
yield _var.get()
finally:
_var.reset(token)
3 changes: 1 addition & 2 deletions src/algopy_testing/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import algosdk
from Cryptodome.Hash import SHA512

from algopy_testing._context_storage import get_test_context
from algopy_testing.constants import (
ARC4_RETURN_PREFIX,
BITS_IN_BYTE,
Expand Down Expand Up @@ -1204,8 +1205,6 @@ def emit_swapped(self, a: arc4.UInt64, b: arc4.UInt64) -> None:
""" # noqa: E501
import algopy

from algopy_testing.context import get_test_context

context = get_test_context()
active_txn = context.get_active_transaction()

Expand Down
41 changes: 4 additions & 37 deletions src/algopy_testing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import typing
from collections import ChainMap, defaultdict
from contextlib import contextmanager
from contextvars import ContextVar

# Define the union type
from typing import Any, Unpack
from typing import Unpack

import algosdk

Expand Down Expand Up @@ -80,7 +77,7 @@ def __init__(
self,
*,
default_creator: algopy.Account | None = None,
template_vars: dict[str, Any] | None = None,
template_vars: dict[str, typing.Any] | None = None,
) -> None:
import algopy

Expand All @@ -105,7 +102,7 @@ def __init__(
self._inner_transaction_groups: list[Sequence[InnerTransactionResultType]] = []
self._constructing_inner_transaction_group: list[InnerTransactionResultType] = []
self._constructing_inner_transaction: InnerTransactionResultType | None = None
self._template_vars: dict[str, Any] = template_vars or {}
self._template_vars: dict[str, typing.Any] = template_vars or {}
self._blocks: dict[int, dict[str, int]] = {}
self._boxes: dict[bytes, bytes] = {}
self._lsigs: dict[algopy.LogicSig, Callable[[], algopy.UInt64 | bool]] = {}
Expand Down Expand Up @@ -225,7 +222,7 @@ def set_active_contract(self, contract: algopy.Contract | algopy.ARC4Contract) -
self._global_fields["current_application_address"] = app.address
self._global_fields["current_application_id"] = app

def set_template_var(self, name: str, value: Any) -> None:
def set_template_var(self, name: str, value: typing.Any) -> None:
"""Set a template variable for the current context.
:param name: The name of the template variable.
Expand Down Expand Up @@ -1159,36 +1156,6 @@ def reset(self) -> None:
self._app_id = iter(range(1, 2**64))


_var: ContextVar[AlgopyTestContext] = ContextVar("_var")


def get_test_context() -> AlgopyTestContext:
try:
result = _var.get()
except LookupError:
raise ValueError(
"Test context is not initialized! Use `with algopy_testing_context()` to "
"access the context manager."
) from None
return result


@contextmanager
def algopy_testing_context(
*,
default_creator: algopy.Account | None = None,
) -> Generator[AlgopyTestContext, None, None]:
token = _var.set(
AlgopyTestContext(
default_creator=default_creator,
)
)
try:
yield _var.get()
finally:
_var.reset(token)


def _assert_address_is_valid(address: str) -> None:
assert algosdk.encoding.is_valid_address(address), "Invalid Algorand address supplied!"

Expand Down
3 changes: 2 additions & 1 deletion src/algopy_testing/decorators/abimethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import algosdk

import algopy_testing
from algopy_testing._context_storage import get_test_context
from algopy_testing.constants import ALWAYS_APPROVE_TEAL_PROGRAM
from algopy_testing.models.txn_fields import ApplicationCallFields
from algopy_testing.utils import (
Expand Down Expand Up @@ -81,7 +82,7 @@ def abimethod( # noqa: PLR0913

@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
context = algopy_testing.get_test_context()
context = get_test_context()
if context._active_transaction_index is not None:
return fn(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions src/algopy_testing/decorators/baremethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import functools
import typing

from algopy_testing._context_storage import get_test_context

if typing.TYPE_CHECKING:
import algopy

Expand Down Expand Up @@ -67,8 +69,6 @@ def baremethod(
def decorator(fn: typing.Callable[_P, _R]) -> typing.Callable[_P, _R]:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
from algopy_testing import get_test_context

context = get_test_context()
if context._active_transaction_index is not None:
return fn(*args, **kwargs)
Expand Down
5 changes: 1 addition & 4 deletions src/algopy_testing/gtxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing

from algopy_testing._context_storage import get_test_context
from algopy_testing.enums import TransactionType
from algopy_testing.models.txn_fields import TransactionFieldsBase

Expand Down Expand Up @@ -40,8 +41,6 @@ def new(cls) -> typing.Self:

@property
def key_txn(self) -> TransactionBase:
from algopy_testing.context import get_test_context

if self._is_context_mapped:
return self

Expand All @@ -56,8 +55,6 @@ def key_txn(self) -> TransactionBase:

@property
def fields(self) -> dict[str, object]:
from algopy_testing.context import get_test_context

context = get_test_context()
try:
return context._gtxns[self.key_txn]
Expand Down
16 changes: 7 additions & 9 deletions src/algopy_testing/itxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

import algosdk

import algopy_testing
from algopy_testing.context import get_test_context
from algopy_testing._context_storage import get_test_context
from algopy_testing.enums import TransactionType
from algopy_testing.models import Account
from algopy_testing.models import Account, Asset
from algopy_testing.models.txn_fields import (
TransactionFieldsBase,
get_txn_defaults,
Expand Down Expand Up @@ -50,7 +49,7 @@


class _BaseInnerTransactionResult(TransactionFieldsBase):
txn_type: algopy_testing.TransactionType = TransactionType.Payment
txn_type: TransactionType = TransactionType.Payment

def __init__(self, **fields: typing.Any):
self._fields = fields
Expand All @@ -61,7 +60,8 @@ def fields(self) -> dict[str, object]:

@property
def _logs(self) -> list[bytes]:
context = algopy_testing.get_test_context()

context = get_test_context()
try:
return context._application_logs[int(self.app_id.id)]
except KeyError:
Expand Down Expand Up @@ -134,7 +134,7 @@ def set(self, **fields: typing.Any) -> None:
_narrow_covariant_types(fields)

def submit(self) -> typing.Any:
context = algopy_testing.get_test_context()
context = get_test_context()
result = _get_itxn_result(self)
context._append_inner_transaction_group([result])
return result
Expand Down Expand Up @@ -178,8 +178,6 @@ def submit_txns(
:returns: A tuple of the resulting inner transactions
"""
from algopy_testing import get_test_context

context = get_test_context()

if len(transactions) > algosdk.constants.TX_GROUP_LIMIT:
Expand Down Expand Up @@ -210,7 +208,7 @@ def _on_keyreg(fields: dict[str, typing.Any]) -> dict[str, typing.Any]:

def _on_asset_config(fields: dict[str, typing.Any]) -> dict[str, typing.Any]:
# if it is a txn to create an asset then ensure this is reflected in the context
if fields.get("config_asset") == algopy_testing.Asset():
if fields.get("config_asset") == Asset():
context = get_test_context()
# TODO: refine
created_asset = context.any_asset(
Expand Down
16 changes: 9 additions & 7 deletions src/algopy_testing/models/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import algosdk

import algopy_testing
from algopy_testing.primitives.bytes import Bytes
from algopy_testing._context_storage import get_test_context
from algopy_testing.primitives import Bytes, UInt64
from algopy_testing.utils import as_bytes

if typing.TYPE_CHECKING:
Expand All @@ -32,12 +32,12 @@ class AccountFields(typing.TypedDict, total=False):


def get_empty_account() -> AccountContextData:
zero = algopy_testing.UInt64(0)
zero = UInt64()
return AccountContextData(
fields={
"balance": zero,
"min_balance": zero,
"auth_address": algopy_testing.Account(),
"auth_address": Account(),
"total_num_uint": zero,
"total_num_byte_slice": zero,
"total_extra_app_pages": zero,
Expand Down Expand Up @@ -82,13 +82,15 @@ def __init__(self, value: str | Bytes = algosdk.constants.ZERO_ADDRESS, /):

@property
def data(self) -> AccountContextData:
context = algopy_testing.get_test_context()
context = get_test_context()
return context._account_data[self.public_key]

def is_opted_in(self, asset_or_app: algopy.Asset | algopy.Application, /) -> bool:
if isinstance(asset_or_app, algopy_testing.Asset):
from algopy_testing.models import Application, Asset

if isinstance(asset_or_app, Asset):
return asset_or_app.id in self.data.opted_asset_balances
elif isinstance(asset_or_app, algopy_testing.Application):
elif isinstance(asset_or_app, Application):
return asset_or_app.id in self.data.opted_apps

raise TypeError(
Expand Down
7 changes: 4 additions & 3 deletions src/algopy_testing/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import inspect
import typing

import algopy_testing
from algopy_testing._context_storage import get_test_context
from algopy_testing.primitives import UInt64
from algopy_testing.utils import as_int64

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -31,11 +32,11 @@ def __init__(self, application_id: algopy.UInt64 | int = 0, /):

@property
def id(self) -> algopy.UInt64:
return algopy_testing.UInt64(self._id)
return UInt64(self._id)

@property
def fields(self) -> ApplicationFields:
context = algopy_testing.get_test_context()
context = get_test_context()
try:
return context._application_data[self._id]
except KeyError:
Expand Down
6 changes: 2 additions & 4 deletions src/algopy_testing/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypedDict, TypeVar

from algopy_testing._context_storage import get_test_context

if TYPE_CHECKING:
import algopy

Expand Down Expand Up @@ -35,8 +37,6 @@ def __init__(self, asset_id: algopy.UInt64 | int = 0):
self.id = asset_id if isinstance(asset_id, UInt64) else UInt64(asset_id)

def balance(self, account: algopy.Account) -> algopy.UInt64:
from algopy_testing.context import get_test_context

context = get_test_context()
if account not in context._account_data:
raise ValueError(
Expand Down Expand Up @@ -65,8 +65,6 @@ def frozen(self, _account: algopy.Account) -> bool:
)

def __getattr__(self, name: str) -> object:
from algopy_testing.context import get_test_context

context = get_test_context()
if int(self.id) not in context._asset_data:
# check if its not 0 (which means its not
Expand Down
2 changes: 1 addition & 1 deletion src/algopy_testing/models/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import typing

import algopy_testing
from algopy_testing._context_storage import get_test_context
from algopy_testing.constants import MAX_BOX_SIZE
from algopy_testing.context import get_test_context
from algopy_testing.utils import as_bytes, as_string

_TKey = typing.TypeVar("_TKey")
Expand Down
5 changes: 3 additions & 2 deletions src/algopy_testing/models/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, final

import algopy_testing
from algopy_testing._context_storage import get_test_context

if TYPE_CHECKING:
import algopy
Expand All @@ -27,7 +28,7 @@ class _StateTotals:

class _ContractMeta(type):
def __call__(cls, *args: Any, **kwargs: dict[str, Any]) -> object:
context = algopy_testing.get_test_context()
context = get_test_context()
instance = super().__call__(*args, **kwargs)

if context and isinstance(instance, Contract):
Expand Down Expand Up @@ -136,7 +137,7 @@ def __getattribute__(self, name: str) -> Any:
if name in ("approval_program", "clear_state_program"):

def wrapper(*args: Any, **kwargs: dict[str, Any]) -> Any:
context = algopy_testing.get_test_context()
context = get_test_context()
# TODO: this should also set up the current txn like abimethod does
context.set_active_contract(self)
try:
Expand Down
Loading

0 comments on commit 913428a

Please sign in to comment.