diff --git a/.gitignore b/.gitignore index 95250565..951ab1dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ .tmp .DS_Store node.txt -accounts/ keypairs/ test-ledger/ @@ -157,4 +156,6 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ + +scratch diff --git a/examples/limit_order_grid.py b/examples/limit_order_grid.py index b98cc244..92bcfff5 100644 --- a/examples/limit_order_grid.py +++ b/examples/limit_order_grid.py @@ -13,7 +13,7 @@ from driftpy.types import * #MarketType, OrderType, OrderParams, PositionDirection, OrderTriggerCondition from driftpy.accounts import get_perp_market_account, get_spot_market_account -from driftpy.math.oracle import get_oracle_data +from driftpy.accounts.oracle import get_oracle_price_data_and_slot from driftpy.math.spot_market import get_signed_token_amount, get_token_amount from driftpy.drift_client import DriftClient from driftpy.drift_user import DriftUser @@ -118,7 +118,7 @@ async def main( drift_acct.program, market_index ) try: - oracle_data = await get_oracle_data(connection, market.amm.oracle) + oracle_data = await get_oracle_price_data_and_slot(connection, market.amm.oracle) current_price = oracle_data.price/PRICE_PRECISION except: current_price = market.amm.historical_oracle_data.last_oracle_price/PRICE_PRECISION @@ -132,7 +132,7 @@ async def main( else: market = await get_spot_market_account( drift_acct.program, market_index) try: - oracle_data = await get_oracle_data(connection, market.oracle) + oracle_data = await get_oracle_price_data_and_slot(connection, market.oracle) current_price = oracle_data.price/PRICE_PRECISION except: current_price = market.historical_oracle_data.last_oracle_price/PRICE_PRECISION diff --git a/src/driftpy/accounts.py b/src/driftpy/accounts.py deleted file mode 100644 index 71f85f4f..00000000 --- a/src/driftpy/accounts.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import cast -from solana.publickey import PublicKey -from anchorpy import Program, ProgramAccount - -from driftpy.types import * -from driftpy.addresses import * - - -async def get_state_account(program: Program) -> State: - state_public_key = get_state_public_key(program.program_id) - response = await program.account["State"].fetch(state_public_key) - return cast(State, response) - - -async def get_if_stake_account( - program: Program, authority: PublicKey, spot_market_index: int -) -> InsuranceFundStake: - if_stake_pk = get_insurance_fund_stake_public_key( - program.program_id, authority, spot_market_index - ) - response = await program.account["InsuranceFundStake"].fetch(if_stake_pk) - return cast(InsuranceFundStake, response) - - -async def get_user_stats_account( - program: Program, - authority: PublicKey, -) -> UserStats: - user_stats_public_key = get_user_stats_account_public_key( - program.program_id, - authority, - ) - response = await program.account["UserStats"].fetch(user_stats_public_key) - return cast(UserStats, response) - - -async def get_user_account( - program: Program, - authority: PublicKey, - subaccount_id: int = 0, -) -> User: - user_public_key = get_user_account_public_key( - program.program_id, authority, subaccount_id - ) - response = await program.account["User"].fetch(user_public_key) - return cast(User, response) - - -async def get_perp_market_account(program: Program, market_index: int) -> PerpMarket: - market_public_key = get_perp_market_public_key(program.program_id, market_index) - response = await program.account["PerpMarket"].fetch(market_public_key) - return cast(PerpMarket, response) - - -async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount]: - return await program.account["PerpMarket"].all() - - -async def get_spot_market_account( - program: Program, spot_market_index: int -) -> SpotMarket: - spot_market_public_key = get_spot_market_public_key( - program.program_id, spot_market_index - ) - response = await program.account["SpotMarket"].fetch(spot_market_public_key) - return cast(SpotMarket, response) - - -async def get_all_spot_market_accounts(program: Program) -> list[ProgramAccount]: - return await program.account["SpotMarket"].all() \ No newline at end of file diff --git a/src/driftpy/accounts/__init__.py b/src/driftpy/accounts/__init__.py new file mode 100644 index 00000000..4c8643a9 --- /dev/null +++ b/src/driftpy/accounts/__init__.py @@ -0,0 +1,2 @@ +from .get_accounts import * +from .types import * \ No newline at end of file diff --git a/src/driftpy/accounts/cache/__init__.py b/src/driftpy/accounts/cache/__init__.py new file mode 100644 index 00000000..f31516fc --- /dev/null +++ b/src/driftpy/accounts/cache/__init__.py @@ -0,0 +1,2 @@ +from .drift_client import * +from .user import * \ No newline at end of file diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py new file mode 100644 index 00000000..c2463922 --- /dev/null +++ b/src/driftpy/accounts/cache/drift_client.py @@ -0,0 +1,78 @@ +from anchorpy import Program +from solana.publickey import PublicKey +from solana.rpc.commitment import Commitment + +from driftpy.accounts import get_state_account_and_slot, get_spot_market_account_and_slot, \ + get_perp_market_account_and_slot +from driftpy.accounts.oracle import get_oracle_price_data_and_slot +from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot +from typing import Optional + +from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State + + +class CachedDriftClientAccountSubscriber(DriftClientAccountSubscriber): + def __init__(self, program: Program, commitment: Commitment = "confirmed"): + self.program = program + self.commitment = commitment + self.cache = None + + async def update_cache(self): + if self.cache is None: + self.cache = {} + + state_and_slot = await get_state_account_and_slot(self.program) + self.cache["state"] = state_and_slot + + oracle_data = {} + + spot_markets = [] + for i in range(state_and_slot.data.number_of_spot_markets): + spot_market_and_slot = await get_spot_market_account_and_slot(self.program, i) + spot_markets.append(spot_market_and_slot) + + oracle_price_data_and_slot = await get_oracle_price_data_and_slot( + self.program.provider.connection, + spot_market_and_slot.data.oracle, + spot_market_and_slot.data.oracle_source + + ) + oracle_data[str(spot_market_and_slot.data.oracle)] = oracle_price_data_and_slot + + self.cache["spot_markets"] = spot_markets + + perp_markets = [] + for i in range(state_and_slot.data.number_of_markets): + perp_market_and_slot = await get_perp_market_account_and_slot(self.program, i) + perp_markets.append(perp_market_and_slot) + + oracle_price_data_and_slot = await get_oracle_price_data_and_slot( + self.program.provider.connection, + perp_market_and_slot.data.amm.oracle, + perp_market_and_slot.data.amm.oracle_source + ) + oracle_data[str(perp_market_and_slot.data.amm.oracle)] = oracle_price_data_and_slot + + self.cache["perp_markets"] = perp_markets + + self.cache["oracle_price_data"] = oracle_data + + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + await self.cache_if_needed() + return self.cache["state"] + + async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + await self.cache_if_needed() + return self.cache["perp_markets"][market_index] + + async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]: + await self.cache_if_needed() + return self.cache["spot_markets"][market_index] + + async def get_oracle_data_and_slot(self, oracle: PublicKey) -> Optional[DataAndSlot[OraclePriceData]]: + await self.cache_if_needed() + return self.cache["oracle_price_data"][str(oracle)] + + async def cache_if_needed(self): + if self.cache is None: + await self.update_cache() diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py new file mode 100644 index 00000000..dd92a93a --- /dev/null +++ b/src/driftpy/accounts/cache/user.py @@ -0,0 +1,29 @@ +from typing import Optional + +from anchorpy import Program +from solana.publickey import PublicKey +from solana.rpc.commitment import Commitment + +from driftpy.accounts import get_user_account_and_slot +from driftpy.accounts import UserAccountSubscriber, DataAndSlot +from driftpy.types import User + + +class CachedUserAccountSubscriber(UserAccountSubscriber): + def __init__(self, user_pubkey: PublicKey, program: Program, commitment: Commitment = "confirmed"): + self.program = program + self.commitment = commitment + self.user_pubkey = user_pubkey + self.user_and_slot = None + + async def update_cache(self): + user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey) + self.user_and_slot = user_and_slot + + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + await self.cache_if_needed() + return self.user_and_slot + + async def cache_if_needed(self): + if self.user_and_slot is None: + await self.update_cache() diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py new file mode 100644 index 00000000..edc1e0a7 --- /dev/null +++ b/src/driftpy/accounts/get_accounts.py @@ -0,0 +1,103 @@ +import base64 +from typing import cast +from solana.publickey import PublicKey +from anchorpy import Program, ProgramAccount +from solana.rpc.commitment import Commitment + +from driftpy.types import * +from driftpy.addresses import * +from .types import DataAndSlot, T + + +async def get_account_data_and_slot(address: PublicKey, program: Program, commitment: Commitment = "processed") -> Optional[ + DataAndSlot[T]]: + account_info = await program.provider.connection.get_account_info( + address, + encoding="base64", + commitment=commitment, + ) + + if not account_info["result"]["value"]: + return None + + slot = account_info["result"]["context"]["slot"] + data = base64.b64decode(account_info["result"]["value"]["data"][0]) + + decoded_data = program.coder.accounts.decode(data) + + return DataAndSlot(slot, decoded_data) + + +async def get_state_account_and_slot(program: Program) -> DataAndSlot[State]: + state_public_key = get_state_public_key(program.program_id) + return await get_account_data_and_slot(state_public_key, program) + + +async def get_state_account(program: Program) -> State: + return (await get_state_account_and_slot(program)).data + + +async def get_if_stake_account( + program: Program, authority: PublicKey, spot_market_index: int +) -> InsuranceFundStake: + if_stake_pk = get_insurance_fund_stake_public_key( + program.program_id, authority, spot_market_index + ) + response = await program.account["InsuranceFundStake"].fetch(if_stake_pk) + return cast(InsuranceFundStake, response) + + +async def get_user_stats_account( + program: Program, + authority: PublicKey, +) -> UserStats: + user_stats_public_key = get_user_stats_account_public_key( + program.program_id, + authority, + ) + response = await program.account["UserStats"].fetch(user_stats_public_key) + return cast(UserStats, response) + +async def get_user_account_and_slot( + program: Program, + user_public_key: PublicKey, +) -> DataAndSlot[User]: + return await get_account_data_and_slot(user_public_key, program) + +async def get_user_account( + program: Program, + user_public_key: PublicKey, +) -> User: + return (await get_user_account_and_slot(program, user_public_key)).data + + +async def get_perp_market_account_and_slot(program: Program, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + perp_market_public_key = get_perp_market_public_key(program.program_id, market_index) + return await get_account_data_and_slot(perp_market_public_key, program) + + +async def get_perp_market_account(program: Program, market_index: int) -> PerpMarket: + return (await get_perp_market_account_and_slot(program, market_index)).data + + +async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount]: + return await program.account["PerpMarket"].all() + + +async def get_spot_market_account_and_slot( + program: Program, spot_market_index: int +) -> DataAndSlot[SpotMarket]: + spot_market_public_key = get_spot_market_public_key( + program.program_id, spot_market_index + ) + return await get_account_data_and_slot(spot_market_public_key, program) + + +async def get_spot_market_account( + program: Program, spot_market_index: int +) -> SpotMarket: + return (await get_spot_market_account_and_slot(program, spot_market_index)).data + + +async def get_all_spot_market_accounts(program: Program) -> list[ProgramAccount]: + return await program.account["SpotMarket"].all() diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py new file mode 100644 index 00000000..127f4f6f --- /dev/null +++ b/src/driftpy/accounts/oracle.py @@ -0,0 +1,66 @@ +from solana.rpc.types import RPCResponse + +from .types import DataAndSlot +from driftpy.constants.numeric_constants import * +from driftpy.types import OracleSource, OraclePriceData + +from solana.publickey import PublicKey +from pythclient.pythaccounts import PythPriceInfo, _ACCOUNT_HEADER_BYTES, EmaType +from solana.rpc.async_api import AsyncClient +import base64 +import struct + +def convert_pyth_price(price, scale=1): + return int(price * PRICE_PRECISION * scale) + +async def get_oracle_price_data_and_slot(connection: AsyncClient, address: PublicKey, oracle_source=OracleSource.PYTH()) -> DataAndSlot[ + OraclePriceData]: + if 'Pyth' in str(oracle_source): + rpc_reponse = await connection.get_account_info(address) + rpc_response_slot = rpc_reponse['result']['context']['slot'] + (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info(rpc_reponse) + + scale = 1 + if '1K' in str(oracle_source): + scale = 1e3 + elif '1M' in str(oracle_source): + scale = 1e6 + + oracle_data = OraclePriceData( + price=convert_pyth_price(pyth_price_info.price, scale), + slot=pyth_price_info.pub_slot, + confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale), + twap=convert_pyth_price(twap, scale), + twap_confidence=convert_pyth_price(twac, scale), + has_sufficient_number_of_datapoints=True, + ) + + return DataAndSlot(data=oracle_data, slot=rpc_response_slot) + elif 'Quote' in str(oracle_source): + return DataAndSlot(data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0) + else: + raise NotImplementedError('Unsupported Oracle Source', str(oracle_source)) + +async def _parse_pyth_price_info(resp: RPCResponse) -> (PythPriceInfo, int, int, int): + value = resp["result"].get("value") + data_base64, data_format = value["data"] + buffer = base64.b64decode(data_base64) + + offset = _ACCOUNT_HEADER_BYTES + _, exponent, _ = struct.unpack_from(" Optional[DataAndSlot[State]]: + pass + + @abstractmethod + async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + pass + + @abstractmethod + async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]: + pass + + @abstractmethod + async def get_oracle_data_and_slot(self, oracle: PublicKey) -> Optional[DataAndSlot[OraclePriceData]]: + pass + +class UserAccountSubscriber: + @abstractmethod + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + pass \ No newline at end of file diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 1e5f93cf..c18d8cd1 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -24,6 +24,9 @@ from typing import Union, Optional, List, Sequence from driftpy.math.positions import is_available, is_spot_position_available +from driftpy.accounts import DriftClientAccountSubscriber +from driftpy.accounts.cache import CachedDriftClientAccountSubscriber + DEFAULT_USER_NAME = "Main Account" DEFAULT_PUBKEY = PublicKey("11111111111111111111111111111111") @@ -33,7 +36,7 @@ class DriftClient: depositing, opening new positions, closing positions, placing orders, etc. """ - def __init__(self, program: Program, signer: Keypair = None, authority: PublicKey = None): + def __init__(self, program: Program, signer: Keypair = None, authority: PublicKey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None): """Initializes the drift client object -- likely want to use the .from_config method instead of this one Args: @@ -57,6 +60,11 @@ def __init__(self, program: Program, signer: Keypair = None, authority: PublicKe self.spot_market_atas = {} self.subaccounts = [0] + if account_subscriber is None: + account_subscriber = CachedDriftClientAccountSubscriber(self.program) + + self.account_subscriber = account_subscriber + @staticmethod def from_config(config: Config, provider: Provider, authority: Keypair = None): """Initializes the drift client object from a Config @@ -89,11 +97,12 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None): drift_client.idl = idl return drift_client + def get_user_account_public_key(self, user_id=0) -> PublicKey: return get_user_account_public_key(self.program_id, self.authority, user_id) async def get_user(self, user_id=0) -> User: - return await get_user_account(self.program, self.authority, user_id) + return await get_user_account(self.program, self.get_user_account_public_key(user_id)) def get_state_public_key(self): return get_state_public_key(self.program_id) @@ -101,6 +110,22 @@ def get_state_public_key(self): def get_user_stats_public_key(self): return get_user_stats_account_public_key(self.program_id, self.authority) + async def get_state(self) -> Optional[State]: + state_and_slot = await self.account_subscriber.get_state_account_and_slot() + return getattr(state_and_slot, 'data', None) + + async def get_perp_market(self, market_index: int) -> Optional[PerpMarket]: + perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot(market_index) + return getattr(perp_market_and_slot, 'data', None) + + async def get_spot_market(self, market_index: int) -> Optional[SpotMarket]: + spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot(market_index) + return getattr(spot_market_and_slot, 'data', None) + + async def get_oracle_price_data(self, oracle: PublicKey) -> Optional[OraclePriceData]: + oracle_price_data_and_slot = await self.account_subscriber.get_oracle_data_and_slot(oracle) + return getattr(oracle_price_data_and_slot, 'data', None) + async def send_ixs( self, ixs: Union[TransactionInstruction, list[TransactionInstruction]], @@ -213,7 +238,8 @@ async def get_remaining_accounts( accounts = [] for pk, id in zip(authority, user_id): - user_account = await get_user_account(self.program, pk, id) + user_public_key = get_user_account_public_key(self.program.program_id, pk, id) + user_account = await get_user_account(self.program, user_public_key) accounts.append(user_account) oracle_map = {} @@ -221,7 +247,7 @@ async def get_remaining_accounts( market_map = {} async def track_market(market_index, is_writable): - perp_market = await get_perp_market_account(self.program, market_index) + perp_market = await self.get_perp_market(market_index) market_map[market_index] = AccountMeta( pubkey=perp_market.pubkey, is_signer=False, @@ -229,8 +255,8 @@ async def track_market(market_index, is_writable): ) if include_oracles: - spot_market = await get_spot_market_account( - self.program, perp_market.quote_spot_market_index + spot_market = await self.get_spot_market( + perp_market.quote_spot_market_index ) if spot_market.oracle != DEFAULT_PUBKEY: oracle_map[str(spot_market.oracle)] = AccountMeta( @@ -241,7 +267,7 @@ async def track_market(market_index, is_writable): ) async def track_spot_market(spot_market_index, is_writable): - spot_market = await get_spot_market_account(self.program, spot_market_index) + spot_market = await self.get_spot_market(spot_market_index) spot_market_map[spot_market_index] = AccountMeta( pubkey=spot_market.pubkey, is_signer=False, @@ -331,7 +357,7 @@ async def get_withdraw_collateral_ix( reduce_only: bool = False, user_id: int = 0, ): - spot_market = await get_spot_market_account(self.program, spot_market_index) + spot_market = await self.get_spot_market(spot_market_index) remaining_accounts = await self.get_remaining_accounts( writable_spot_market_index=spot_market_index, readable_spot_market_index=QUOTE_ASSET_BANK_INDEX, @@ -896,7 +922,7 @@ async def get_user_position( market_index: int, subaccount_id: int = 0, ) -> Optional[PerpPosition]: - user = await get_user_account(self.program, self.authority, subaccount_id) + user = await self.get_user(subaccount_id) found = False for position in user.perp_positions: diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index 6a427ea6..cd062316 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -1,35 +1,22 @@ -from solana.publickey import PublicKey -from typing import Optional +from driftpy.accounts import UserAccountSubscriber +from driftpy.accounts.cache import CachedUserAccountSubscriber from driftpy.drift_client import DriftClient -from driftpy.constants.numeric_constants import * -from driftpy.types import * -from driftpy.accounts import * from driftpy.math.positions import * from driftpy.math.margin import * from driftpy.math.spot_market import * -from driftpy.math.oracle import * - - -def find(l: list, f): - valid_values = [v for v in l if f(v)] - if len(valid_values) == 0: - return None - else: - return valid_values[0] +from driftpy.accounts.oracle import * +from driftpy.types import OraclePriceData class DriftUser: - """This class is the main way to retrieve and inspect data on Drift Protocol.""" + """This class is the main way to retrieve and inspect drift user account data.""" def __init__( self, drift_client: DriftClient, authority: Optional[PublicKey] = None, subaccount_id: int = 0, - use_cache: bool = False, - - - + account_subscriber: Optional[UserAccountSubscriber] = None, ): """Initialize the user object @@ -37,7 +24,6 @@ def __init__( drift_client(DriftClient): required for program_id, idl, things (keypair doesnt matter) authority (Optional[PublicKey], optional): authority to investigate if None will use drift_client.authority subaccount_id (int, optional): subaccount of authority to investigate. Defaults to 0. - use_cache (bool, optional): sdk uses a lot of rpc calls rn - use this flag and .set_cache() to cache accounts and reduce rpc calls. Defaults to False. """ self.drift_client = drift_client self.authority = authority @@ -48,172 +34,32 @@ def __init__( self.oracle_program = drift_client self.connection = self.program.provider.connection self.subaccount_id = subaccount_id - self.use_cache = use_cache - self.cache_is_set = False - - - - # cache all state, perpmarket, oracle, etc. in single cache -- user calls reload - # when they want to update the data? - # get_spot_market - # get_perp_market - # get_user - # if state = cache => get cached_market else get new market - async def set_cache_last(self, CACHE=None): - """sets the cache of the accounts to use to inspect - Args: - CACHE (dict, optional): other existing cache object - if None will pull ƒresh accounts from RPC. Defaults to None. - """ - self.cache_is_set = True + self.user_public_key = get_user_account_public_key(self.program.program_id, self.authority, self.subaccount_id) - if CACHE is not None: - self.CACHE = CACHE - return + if account_subscriber is None: + account_subscriber = CachedUserAccountSubscriber(self.user_public_key, self.program) - self.CACHE = {} - state = await get_state_account(self.program) - self.CACHE["state"] = state + self.account_subscriber = account_subscriber - spot_markets = [] - spot_market_oracle_data = [] - for i in range(state.number_of_spot_markets): - spot_market = await get_spot_market_account(self.program, i) - spot_markets.append(spot_market) - if i == 0: - spot_market_oracle_data.append( - OracleData(PRICE_PRECISION, 0, 1, 1, 0, True) - ) - else: - oracle_data = OracleData( - spot_market.historical_oracle_data.last_oracle_price, - 0, - 1, - 1, - 0, - True, - ) - spot_market_oracle_data.append(oracle_data) - - self.CACHE["spot_markets"] = spot_markets - self.CACHE["spot_market_oracles"] = spot_market_oracle_data - - perp_markets = [] - perp_market_oracle_data = [] - for i in range(state.number_of_markets): - perp_market = await get_perp_market_account(self.program, i) - perp_markets.append(perp_market) - - oracle_data = OracleData( - perp_market.amm.historical_oracle_data.last_oracle_price, - 0, - 1, - 1, - 0, - True, - ) - perp_market_oracle_data.append(oracle_data) + async def get_spot_oracle_data(self, spot_market: SpotMarket) -> Optional[OraclePriceData]: + return await self.drift_client.get_oracle_price_data(spot_market.oracle) - self.CACHE["perp_markets"] = perp_markets - self.CACHE["perp_market_oracles"] = perp_market_oracle_data + async def get_perp_oracle_data(self, perp_market: PerpMarket) -> Optional[OraclePriceData]: + return await self.drift_client.get_oracle_price_data(perp_market.amm.oracle) - user = await get_user_account(self.program, self.authority, self.subaccount_id) - self.CACHE["user"] = user + async def get_state(self) -> State: + return await self.drift_client.get_state() - async def set_cache(self, CACHE=None): - """sets the cache of the accounts to use to inspect + async def get_spot_market(self, market_index: int) -> SpotMarket: + return await self.drift_client.get_spot_market(market_index) - Args: - CACHE (dict, optional): other existing cache object - if None will pull ƒresh accounts from RPC. Defaults to None. - """ - self.cache_is_set = True - - if CACHE is not None: - self.CACHE = CACHE - return - - self.CACHE = {} - state = await get_state_account(self.program) - self.CACHE["state"] = state + async def get_perp_market(self, market_index: int) -> PerpMarket: + return await self.drift_client.get_perp_market(market_index) - spot_markets = [] - spot_market_oracle_data = [] - for i in range(state.number_of_spot_markets): - spot_market = await get_spot_market_account(self.program, i) - spot_markets.append(spot_market) - - if i == 0: - spot_market_oracle_data.append( - OracleData(PRICE_PRECISION, 0, 1, 1, 0, True) - ) - else: - oracle_data = await get_oracle_data(self.connection, spot_market.oracle, spot_market.oracle_source) - spot_market_oracle_data.append(oracle_data) - - self.CACHE["spot_markets"] = spot_markets - self.CACHE["spot_market_oracles"] = spot_market_oracle_data - - perp_markets = [] - perp_market_oracle_data = [] - for i in range(state.number_of_markets): - perp_market = await get_perp_market_account(self.program, i) - perp_markets.append(perp_market) - - oracle_data = await get_oracle_data(self.connection, perp_market.amm.oracle, perp_market.amm.oracle_source) - perp_market_oracle_data.append(oracle_data) - - self.CACHE["perp_markets"] = perp_markets - self.CACHE["perp_market_oracles"] = perp_market_oracle_data - - user = await get_user_account(self.program, self.authority, self.subaccount_id) - self.CACHE["user"] = user - - async def get_spot_oracle_data(self, spot_market: SpotMarket): - if self.use_cache: - assert self.cache_is_set, "must call user.set_cache() first" - return self.CACHE["spot_market_oracles"][spot_market.market_index] - else: - oracle_data = await get_oracle_data(self.connection, spot_market.oracle, spot_market.oracle_source) - return oracle_data - - async def get_perp_oracle_data(self, perp_market: PerpMarket): - if self.use_cache: - assert self.cache_is_set, "must call user.set_cache() first" - return self.CACHE["perp_market_oracles"][perp_market.market_index] - else: - oracle_data = await get_oracle_data(self.connection, perp_market.amm.oracle, perp_market.amm.oracle_source) - return oracle_data - - async def get_state(self): - if self.use_cache: - assert self.cache_is_set, "must call user.set_cache() first" - return self.CACHE["state"] - else: - return await get_state_account(self.program) - - async def get_spot_market(self, i): - if self.use_cache: - assert self.cache_is_set, "must call user.set_cache() first" - return self.CACHE["spot_markets"][i] - else: - return await get_spot_market_account(self.program, i) - - async def get_perp_market(self, i): - if self.use_cache: - assert self.cache_is_set, "must call user.set_cache() first" - return self.CACHE["perp_markets"][i] - else: - return await get_perp_market_account(self.program, i) - - async def get_user(self): - if self.use_cache: - assert self.cache_is_set, "must call user.set_cache() first" - return self.CACHE["user"] - else: - return await get_user_account( - self.program, self.authority, self.subaccount_id - ) + async def get_user(self) -> User: + return (await self.account_subscriber.get_user_account_and_slot()).data async def get_open_orders(self, diff --git a/src/driftpy/math/margin.py b/src/driftpy/math/margin.py index 7354b305..6463b24c 100644 --- a/src/driftpy/math/margin.py +++ b/src/driftpy/math/margin.py @@ -1,11 +1,9 @@ -from driftpy.constants.numeric_constants import * -from driftpy.types import * -from driftpy.accounts import * -from driftpy.math.oracle import OracleData from driftpy.math.spot_market import * from enum import Enum +from driftpy.types import OraclePriceData + def calculate_size_discount_asset_weight( size, @@ -88,7 +86,7 @@ def calculate_size_premium_liability_weight( return max_liability_weight -def calculate_net_user_pnl(perp_market: PerpMarket, oracle_data: OracleData): +def calculate_net_user_pnl(perp_market: PerpMarket, oracle_data: OraclePriceData): net_user_position_value = ( perp_market.amm.base_asset_amount_with_amm * oracle_data.price @@ -103,7 +101,7 @@ def calculate_net_user_pnl(perp_market: PerpMarket, oracle_data: OracleData): def calculate_net_user_pnl_imbalance( - perp_market: PerpMarket, spot_market: SpotMarket, oracle_data: OracleData + perp_market: PerpMarket, spot_market: SpotMarket, oracle_data: OraclePriceData ): user_pnl = calculate_net_user_pnl(perp_market, oracle_data) @@ -120,7 +118,7 @@ def calculate_unrealized_asset_weight( spot_market: SpotMarket, unrealized_pnl: int, margin_category: MarginCategory, - oracle_data: OracleData, + oracle_data: OraclePriceData, ): match margin_category: case MarginCategory.INITIAL: @@ -213,7 +211,7 @@ def calculate_market_margin_ratio( def get_spot_liability_value( token_amount: int, - oracle_data: OracleData, + oracle_data: OraclePriceData, spot_market: SpotMarket, margin_category: MarginCategory, liquidation_buffer: int = None, diff --git a/src/driftpy/math/oracle.py b/src/driftpy/math/oracle.py deleted file mode 100644 index d61c4672..00000000 --- a/src/driftpy/math/oracle.py +++ /dev/null @@ -1,66 +0,0 @@ -from dataclasses import dataclass -from driftpy.constants.numeric_constants import * -from driftpy.types import OracleSource - -from solana.publickey import PublicKey -from pythclient.pythaccounts import PythPriceAccount -from pythclient.solana import ( - SolanaClient, - SolanaPublicKey, -) -from solana.rpc.async_api import AsyncClient - - -def convert_pyth_price(price, scale=1): - return int(price * PRICE_PRECISION * scale) - - -@dataclass -class OracleData: - price: int - slot: int - confidence: int - twap: int - twap_confidence: int - has_sufficient_number_of_datapoints: bool - - -async def get_oracle_data(connection: AsyncClient, address: PublicKey, oracle_source=OracleSource.PYTH()) -> OracleData: - address = str(address) - account_key = SolanaPublicKey(address) - - http_endpoint = connection._provider.endpoint_uri - if "https" in http_endpoint: - ws_endpoint = http_endpoint.replace("https", "wss") - elif "http" in http_endpoint: - ws_endpoint = http_endpoint.replace("http", "wss") - else: - print(http_endpoint) - raise - - solana_client = SolanaClient(endpoint=http_endpoint, ws_endpoint=ws_endpoint) - if 'Pyth' in str(oracle_source): - price: PythPriceAccount = PythPriceAccount(account_key, solana_client) - await price.update() - - # TODO: returns none rn - # (twap, twac) = (price.derivations.get('TWAPVALUE'), price.derivations.get('TWACVALUE')) - (twap, twac) = (0, 0) - scale = 1 - if '1K' in str(oracle_source): - scale = 1e3 - elif '1M' in str(oracle_source): - scale = 1e6 - - oracle_data = OracleData( - price=convert_pyth_price(price.aggregate_price_info.price, scale), - slot=price.last_slot, - confidence=convert_pyth_price(price.aggregate_price_info.confidence_interval, scale), - twap=convert_pyth_price(twap, scale), - twap_confidence=convert_pyth_price(twac, scale), - has_sufficient_number_of_datapoints=True, - ) - - await solana_client.close() - - return oracle_data diff --git a/src/driftpy/math/positions.py b/src/driftpy/math/positions.py index 9753d141..cbe63ea1 100644 --- a/src/driftpy/math/positions.py +++ b/src/driftpy/math/positions.py @@ -1,9 +1,5 @@ -from driftpy.types import PositionDirection, PerpMarket, PerpPosition, SpotPosition -from driftpy.constants.numeric_constants import * -from driftpy.types import * -from driftpy.accounts import * -from driftpy.math.oracle import * from driftpy.math.spot_market import * +from driftpy.types import OraclePriceData def get_worst_case_token_amounts( @@ -28,7 +24,7 @@ def get_worst_case_token_amounts( def calculate_base_asset_value_with_oracle( - perp_position: PerpPosition, oracle_data: OracleData + perp_position: PerpPosition, oracle_data: OraclePriceData ): return ( abs(perp_position.base_asset_amount) @@ -63,7 +59,7 @@ def calculate_position_funding_pnl(market: PerpMarket, perp_position: PerpPositi def calculate_position_pnl_with_oracle( market: PerpMarket, perp_position: PerpPosition, - oracle_data: OracleData, + oracle_data: OraclePriceData, with_funding=False, ): if perp_position.base_asset_amount == 0: diff --git a/src/driftpy/math/spot_market.py b/src/driftpy/math/spot_market.py index 12ee95da..7587454c 100644 --- a/src/driftpy/math/spot_market.py +++ b/src/driftpy/math/spot_market.py @@ -1,7 +1,5 @@ -from driftpy.constants.numeric_constants import * -from driftpy.types import * from driftpy.accounts import * -from driftpy.math.oracle import * +from driftpy.types import OraclePriceData def get_signed_token_amount(amount, balance_type): @@ -30,6 +28,6 @@ def get_token_amount( return balance * cumm_interest / percision_decrease -def get_token_value(amount, spot_decimals, oracle_data: OracleData): +def get_token_value(amount, spot_decimals, oracle_data: OraclePriceData): precision_decrease = 10**spot_decimals return amount * oracle_data.price / precision_decrease diff --git a/src/driftpy/types.py b/src/driftpy/types.py index 22716281..30e704fa 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -836,4 +836,12 @@ class ReferrerName: user: PublicKey user_stats: PublicKey name: list[int] - + +@dataclass +class OraclePriceData: + price: int + slot: int + confidence: int + twap: int + twap_confidence: int + has_sufficient_number_of_datapoints: bool \ No newline at end of file diff --git a/tests/test.py b/tests/test.py index 8d9d2ae2..83de6449 100644 --- a/tests/test.py +++ b/tests/test.py @@ -14,7 +14,7 @@ ) from math import sqrt -from driftpy.drift_user import DriftUser as DriftUser +from driftpy.drift_user import DriftUser from driftpy.drift_client import DriftClient from driftpy.setup.helpers import ( _create_mint, @@ -51,7 +51,7 @@ MARKET_INDEX = 0 workspace = workspace_fixture( - "protocol-v2", build_cmd="anchor build --skip-lint", scope="session" + "protocol-v2", build_cmd="anchor build --skip-build", scope="session" ) @@ -180,8 +180,9 @@ async def test_init_user( drift_client: Admin, ): await drift_client.intialize_user() + user_public_key = get_user_account_public_key(drift_client.program.program_id, drift_client.authority, 0) user: User = await get_user_account( - drift_client.program, drift_client.authority, subaccount_id=0 + drift_client.program, user_public_key ) assert user.authority == drift_client.authority @@ -197,9 +198,7 @@ async def test_usdc_deposit( await drift_client.deposit( USDC_AMOUNT, 0, user_usdc_account.public_key, user_initialized=True ) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) assert ( user_account.spot_positions[0].scaled_balance == USDC_AMOUNT / QUOTE_PRECISION * SPOT_BALANCE_PRECISION @@ -211,9 +210,7 @@ async def test_open_orders( ): drift_user = DriftUser(drift_client) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) assert(len(user_account.orders)==32) assert(user_account.orders[0].market_index == 0) @@ -228,12 +225,14 @@ async def test_open_orders( order_params.user_order_id = 169 ixs = await drift_client.get_place_perp_orders_ix([order_params]) await drift_client.send_ixs(ixs) + await drift_user.account_subscriber.update_cache() open_orders_after = await drift_user.get_open_orders() assert(open_orders_after[0].base_asset_amount == BASE_PRECISION) assert(open_orders_after[0].order_id == 1) assert(open_orders_after[0].user_order_id == 169) await drift_client.cancel_order(1, 0) + await drift_user.account_subscriber.update_cache() open_orders_after2 = await drift_user.get_open_orders() assert(open_orders_after2[0].base_asset_amount == 0) @@ -275,17 +274,13 @@ async def test_add_remove_liquidity( assert state.lp_cooldown_time == 0 await drift_client.add_liquidity(n_shares, 0) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) assert user_account.perp_positions[0].lp_shares == n_shares await drift_client.settle_lp(drift_client.authority, 0) await drift_client.remove_liquidity(n_shares, 0) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) assert user_account.perp_positions[0].lp_shares == 0 @@ -331,18 +326,14 @@ async def test_open_close_position( drift_client.program.provider.connection._commitment = Processed # print(tx) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) assert user_account.perp_positions[0].base_asset_amount == baa assert user_account.perp_positions[0].quote_asset_amount < 0 await drift_client.close_position(0) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) assert user_account.perp_positions[0].base_asset_amount == 0 assert user_account.perp_positions[0].quote_asset_amount < 0 @@ -382,9 +373,7 @@ async def test_liq_perp( drift_client: Admin, usdc_mint: Keypair, workspace: WorkspaceType ): market = await get_perp_market_account(drift_client.program, 0) - user_account = await get_user_account( - drift_client.program, drift_client.authority - ) + user_account = await drift_client.get_user(0) liq, _ = await _airdrop_user(drift_client.program.provider) liq_drift_client = DriftClient(drift_client.program, liq)