Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AssetIndicesMixin asset_names setter and tests #193

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion curvesim/pool/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def coin_addresses(self):

@property
def coin_decimals(self):
"""Addresses for the pool coins."""
"""Decimal precisions for the pool coins."""
if hasattr(self, "metadata"):
return self.metadata["coins"]["decimals"]
return []
Expand Down
17 changes: 16 additions & 1 deletion curvesim/pool/sim_interface/asset_indices.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base SimPool implementation for Curve stableswap pools, both regular and meta."""
from abc import abstractmethod

from curvesim.exceptions import CurvesimValueError
from curvesim.exceptions import CurvesimException, CurvesimValueError
from curvesim.utils import cache


Expand All @@ -22,6 +22,16 @@ def asset_names(self):
"""
raise NotImplementedError

@asset_names.setter
@abstractmethod
def asset_names(self, *asset_lists):
"""
Set list of asset names.

Implementations should disallow setting of duplicate names and inconsistent numbers of names.
"""
raise NotImplementedError

@property
@abstractmethod
def _asset_balances(self):
Expand All @@ -31,6 +41,11 @@ def _asset_balances(self):
@property
def asset_balances(self):
"""Return dict mapping asset names to coin balances."""
if len(self.asset_names) != len(self._asset_balances):
raise CurvesimException(
"Number of symbols and number of balances aren't the same."
)

return dict(zip(self.asset_names, self._asset_balances))

@property
Expand Down
51 changes: 47 additions & 4 deletions curvesim/pool/sim_interface/metapool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class SimCurveMetaPool(SimPool, AssetIndicesMixin, CurveMetaPool):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.asset_names = self.coin_names, self.basepool.coin_names

# The rates check has a couple special cases:
# 1. For metapools, we need to use the basepool rates
# instead of the virtual price for the basepool.
Expand All @@ -33,12 +35,53 @@ def asset_names(self):

For metapools, our convention is to place the basepool LP token last.
"""
meta_coin_names = self.coin_names[:-1]
base_coin_names = self.basepool.coin_names
bp_token_name = self.coin_names[-1]
meta_coin_names = self._metapool_names[:-1]
base_coin_names = self._basepool_names
bp_token_name = self._metapool_names[-1]

return [*meta_coin_names, *base_coin_names, bp_token_name]

@asset_names.setter
@override
def asset_names(self, *asset_lists):
"""
Set list of asset names.

Positional args:
----------------

[0]: list of all metapool asset names.
[1]: list of all basepool asset names.
"""
metapool_names = asset_lists[0]
basepool_names = asset_lists[1]

if len(metapool_names) != len(set(metapool_names)) or len(
basepool_names
) != len(set(basepool_names)):
raise SimPoolError(
"SimCurveMetaPool must have unique asset names for metapool and basepool, separately."
)

if hasattr(self, "asset_names") and (
len(self._metapool_names) != len(metapool_names)
or len(self._basepool_names) != len(basepool_names)
):
raise SimPoolError(
"SimCurveMetaPool must have a consistent number of metapool asset names and \
basepool asset names, separately."
)

if not hasattr(self, "asset_names"):
self._metapool_names = [str()] * len(metapool_names)
self._basepool_names = [str()] * len(basepool_names)

for i in range(len(metapool_names)):
self._metapool_names[i] = metapool_names[i]

for i in range(len(basepool_names)):
self._basepool_names[i] = basepool_names[i]

@property
@override
def _asset_balances(self):
Expand Down Expand Up @@ -131,7 +174,7 @@ def get_in_amount(self, coin_in, coin_out, out_balance_perc):
@override
@cache
def assets(self):
symbols = self.coin_names[:-1] + self.basepool.coin_names
symbols = self._metapool_names[:-1] + self._basepool_names
addresses = self.coin_addresses[:-1] + self.basepool.coin_addresses

return SimAssets(symbols, addresses, self.chain)
31 changes: 29 additions & 2 deletions curvesim/pool/sim_interface/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class SimCurvePool(SimPool, AssetIndicesMixin, CurvePool):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.asset_names = self.coin_names

rates = self.rates # pylint: disable=no-member
for r in rates:
if r != 10**18:
Expand All @@ -21,7 +23,32 @@ def __init__(self, *args, **kwargs):
@cache
def asset_names(self):
"""Return list of asset names."""
return self.coin_names
return self._asset_names

@asset_names.setter
@override
def asset_names(self, *asset_lists):
"""
Set list of asset names.

Positional args:
----------------

[0]: list of all pool asset names.
"""
asset_names = asset_lists[0]

if len(asset_names) != len(set(asset_names)):
raise SimPoolError("SimPool must have unique asset names.")

if hasattr(self, "asset_names") and len(self.asset_names) != len(asset_names):
raise SimPoolError("SimPool must have a consistent number of asset names.")

if not hasattr(self, "asset_names"):
self._asset_names = [str()] * len(asset_names)

for i in range(len(asset_names)):
self._asset_names[i] = asset_names[i]

@property
@override
Expand Down Expand Up @@ -56,4 +83,4 @@ def get_in_amount(self, coin_in, coin_out, out_balance_perc):
@override
@cache
def assets(self):
return SimAssets(self.coin_names, self.coin_addresses, self.chain)
return SimAssets(self.asset_names, self.coin_addresses, self.chain)
116 changes: 95 additions & 21 deletions test/unit/test_coin_indices_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools

from curvesim.pool.sim_interface.asset_indices import AssetIndicesMixin
from curvesim.exceptions import CurvesimValueError, SimPoolError
from curvesim.utils import override

# pylint: disable=redefined-outer-name
Expand All @@ -14,10 +15,30 @@ class FakeSimPool(AssetIndicesMixin):
for testing purposes.
"""

def __init__(self):
self.asset_names = ["SYM_0", "SYM_1", "SYM_2"]

@property
@override
def asset_names(self):
return ["SYM_0", "SYM_1", "SYM_2"]
return self._asset_names

@asset_names.setter
@override
def asset_names(self, *asset_lists):
asset_names = asset_lists[0]

if len(asset_names) != len(set(asset_names)):
raise SimPoolError("SimPool must have unique asset names.")

if hasattr(self, "asset_names") and len(self.asset_names) != len(asset_names):
raise SimPoolError("SimPool must have a consistent number of asset names.")

if not hasattr(self, "asset_names"):
self._asset_names = [str()] * len(asset_names)

for i in range(len(asset_names)):
self._asset_names[i] = asset_names[i]

@property
@override
Expand All @@ -31,38 +52,91 @@ def sim_pool():
return FakeSimPool()


def indices(sim_pool, *assets):
"""Returns the indices of all symbols/indices in assets"""
indices = []
symbols = sim_pool.asset_names
for ID in assets:
if isinstance(ID, str):
ID = symbols.index(ID)
indices.append(ID)

return indices


def duplicates(sim_pool, *assets):
"""Determines whether assets contains duplicate symbols/indices"""
asset_indices = indices(sim_pool, *assets)

return len(asset_indices) != len(set(asset_indices))


def test_asset_indices(sim_pool):
"""Test index conversion and getting"""
assert sim_pool.asset_indices == {"SYM_0": 0, "SYM_1": 1, "SYM_2": 2}


def test_asset_balances(sim_pool):
"""Test mapping symbols to balances"""
assert sim_pool.asset_balances == {"SYM_0": 100, "SYM_1": 200, "SYM_2": 300}


def test_get_asset_indices(sim_pool):
"""Test getting index from symbol or index itself"""
# Example calling by symbol
result = sim_pool.get_asset_indices("SYM_2", "SYM_0")
assert result == [2, 0]

names = sim_pool.asset_names
name_sets = [
list(itertools.permutations(names, r=i)) for i in range(1, len(names) + 1)
]
for lst in name_sets:
for name_set in lst:
name_set = list(name_set)
result = sim_pool.get_asset_indices(*name_set)
assert result == [names.index(symbol) for symbol in name_set]

# Example calling by index
result = sim_pool.get_asset_indices(1, 2)
assert result == [1, 2]

# Example calling by symbol and index
result = sim_pool.get_asset_indices("SYM_0", 2, "SYM_1")
assert result == [0, 2, 1]

# Examples calling by a symbol that doesn't exist
try:
assets = ["SYM_0", 1, "SYM_3"]
result = sim_pool.get_asset_indices(*assets)
except Exception as err:
assert isinstance(err, KeyError)

try:
assets = [2, 3, 0]
result = sim_pool.get_asset_indices(*assets)
except Exception as err:
assert isinstance(err, KeyError)

# Examples calling by symbol and index, with occasional duplicates
symbols = sim_pool.asset_names
indices = sim_pool.asset_indices.values()
index_sets = [
list(itertools.permutations(indices, r=i)) for i in range(1, len(indices) + 1)
symbols_and_indices = symbols + indices
assets = [
list(itertools.permutations(symbols_and_indices, r=i))
for i in range(1, len(symbols_and_indices) + 1)
]
for lst in index_sets:
for index_set in lst:
index_set = list(index_set)
result = sim_pool.get_asset_indices(*index_set)
assert result == index_set
for lst in assets:
for asset_set in lst:
try:
result = sim_pool.get_asset_indices(*asset_set)
assert result == indices(sim_pool, *asset_set)
except Exception as err:
assert duplicates(sim_pool, *asset_set)
assert isinstance(err, CurvesimValueError)

assert sim_pool.asset_indices == {"SYM_0": 0, "SYM_1": 1, "SYM_2": 2}

# how we would test the asset_names setter
# for breaking tests, check that the appropriate exception is raised (may need to use multiple functions in a
# specific order to get the right error)

def test_asset_balances(sim_pool):
assert sim_pool.asset_balances == {"SYM_0": 100, "SYM_1": 200, "SYM_2": 300}
# post-init set with duplicate symbols

# initial set that's longer/shorter than _asset_balances
# delete self.asset_names -> "initial" set immediately after

# change number of symbols after initial set (no matter if valid length or invalid length)
# assert isinstance(err, SimPoolError)

# regular set to different symbols
# check that get_asset_indices returns something different from before