Skip to content

Commit

Permalink
account subscription interface
Browse files Browse the repository at this point in the history
add account subscription interface
  • Loading branch information
crispheaney authored Nov 16, 2023
2 parents 56b91eb + 1da7e9a commit 474f542
Show file tree
Hide file tree
Showing 18 changed files with 413 additions and 370 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
.tmp
.DS_Store
node.txt
accounts/
keypairs/
test-ledger/

Expand Down Expand Up @@ -157,4 +156,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

scratch
6 changes: 3 additions & 3 deletions examples/limit_order_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from driftpy.types import *
#MarketType, OrderType, OrderParams, PositionDirection, OrderTriggerCondition
from driftpy.accounts import get_perp_market_account, get_spot_market_account
from driftpy.math.oracle import get_oracle_data
from driftpy.accounts.oracle import get_oracle_price_data_and_slot
from driftpy.math.spot_market import get_signed_token_amount, get_token_amount
from driftpy.drift_client import DriftClient
from driftpy.drift_user import DriftUser
Expand Down Expand Up @@ -118,7 +118,7 @@ async def main(
drift_acct.program, market_index
)
try:
oracle_data = await get_oracle_data(connection, market.amm.oracle)
oracle_data = await get_oracle_price_data_and_slot(connection, market.amm.oracle)
current_price = oracle_data.price/PRICE_PRECISION
except:
current_price = market.amm.historical_oracle_data.last_oracle_price/PRICE_PRECISION
Expand All @@ -132,7 +132,7 @@ async def main(
else:
market = await get_spot_market_account( drift_acct.program, market_index)
try:
oracle_data = await get_oracle_data(connection, market.oracle)
oracle_data = await get_oracle_price_data_and_slot(connection, market.oracle)
current_price = oracle_data.price/PRICE_PRECISION
except:
current_price = market.historical_oracle_data.last_oracle_price/PRICE_PRECISION
Expand Down
70 changes: 0 additions & 70 deletions src/driftpy/accounts.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/driftpy/accounts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .get_accounts import *
from .types import *
2 changes: 2 additions & 0 deletions src/driftpy/accounts/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .drift_client import *
from .user import *
78 changes: 78 additions & 0 deletions src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from anchorpy import Program
from solana.publickey import PublicKey
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_state_account_and_slot, get_spot_market_account_and_slot, \
get_perp_market_account_and_slot
from driftpy.accounts.oracle import get_oracle_price_data_and_slot
from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot
from typing import Optional

from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State


class CachedDriftClientAccountSubscriber(DriftClientAccountSubscriber):
def __init__(self, program: Program, commitment: Commitment = "confirmed"):
self.program = program
self.commitment = commitment
self.cache = None

async def update_cache(self):
if self.cache is None:
self.cache = {}

state_and_slot = await get_state_account_and_slot(self.program)
self.cache["state"] = state_and_slot

oracle_data = {}

spot_markets = []
for i in range(state_and_slot.data.number_of_spot_markets):
spot_market_and_slot = await get_spot_market_account_and_slot(self.program, i)
spot_markets.append(spot_market_and_slot)

oracle_price_data_and_slot = await get_oracle_price_data_and_slot(
self.program.provider.connection,
spot_market_and_slot.data.oracle,
spot_market_and_slot.data.oracle_source

)
oracle_data[str(spot_market_and_slot.data.oracle)] = oracle_price_data_and_slot

self.cache["spot_markets"] = spot_markets

perp_markets = []
for i in range(state_and_slot.data.number_of_markets):
perp_market_and_slot = await get_perp_market_account_and_slot(self.program, i)
perp_markets.append(perp_market_and_slot)

oracle_price_data_and_slot = await get_oracle_price_data_and_slot(
self.program.provider.connection,
perp_market_and_slot.data.amm.oracle,
perp_market_and_slot.data.amm.oracle_source
)
oracle_data[str(perp_market_and_slot.data.amm.oracle)] = oracle_price_data_and_slot

self.cache["perp_markets"] = perp_markets

self.cache["oracle_price_data"] = oracle_data

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
await self.cache_if_needed()
return self.cache["state"]

async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]:
await self.cache_if_needed()
return self.cache["perp_markets"][market_index]

async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]:
await self.cache_if_needed()
return self.cache["spot_markets"][market_index]

async def get_oracle_data_and_slot(self, oracle: PublicKey) -> Optional[DataAndSlot[OraclePriceData]]:
await self.cache_if_needed()
return self.cache["oracle_price_data"][str(oracle)]

async def cache_if_needed(self):
if self.cache is None:
await self.update_cache()
29 changes: 29 additions & 0 deletions src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Optional

from anchorpy import Program
from solana.publickey import PublicKey
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_user_account_and_slot
from driftpy.accounts import UserAccountSubscriber, DataAndSlot
from driftpy.types import User


class CachedUserAccountSubscriber(UserAccountSubscriber):
def __init__(self, user_pubkey: PublicKey, program: Program, commitment: Commitment = "confirmed"):
self.program = program
self.commitment = commitment
self.user_pubkey = user_pubkey
self.user_and_slot = None

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
await self.cache_if_needed()
return self.user_and_slot

async def cache_if_needed(self):
if self.user_and_slot is None:
await self.update_cache()
103 changes: 103 additions & 0 deletions src/driftpy/accounts/get_accounts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import base64
from typing import cast
from solana.publickey import PublicKey
from anchorpy import Program, ProgramAccount
from solana.rpc.commitment import Commitment

from driftpy.types import *
from driftpy.addresses import *
from .types import DataAndSlot, T


async def get_account_data_and_slot(address: PublicKey, program: Program, commitment: Commitment = "processed") -> Optional[
DataAndSlot[T]]:
account_info = await program.provider.connection.get_account_info(
address,
encoding="base64",
commitment=commitment,
)

if not account_info["result"]["value"]:
return None

slot = account_info["result"]["context"]["slot"]
data = base64.b64decode(account_info["result"]["value"]["data"][0])

decoded_data = program.coder.accounts.decode(data)

return DataAndSlot(slot, decoded_data)


async def get_state_account_and_slot(program: Program) -> DataAndSlot[State]:
state_public_key = get_state_public_key(program.program_id)
return await get_account_data_and_slot(state_public_key, program)


async def get_state_account(program: Program) -> State:
return (await get_state_account_and_slot(program)).data


async def get_if_stake_account(
program: Program, authority: PublicKey, spot_market_index: int
) -> InsuranceFundStake:
if_stake_pk = get_insurance_fund_stake_public_key(
program.program_id, authority, spot_market_index
)
response = await program.account["InsuranceFundStake"].fetch(if_stake_pk)
return cast(InsuranceFundStake, response)


async def get_user_stats_account(
program: Program,
authority: PublicKey,
) -> UserStats:
user_stats_public_key = get_user_stats_account_public_key(
program.program_id,
authority,
)
response = await program.account["UserStats"].fetch(user_stats_public_key)
return cast(UserStats, response)

async def get_user_account_and_slot(
program: Program,
user_public_key: PublicKey,
) -> DataAndSlot[User]:
return await get_account_data_and_slot(user_public_key, program)

async def get_user_account(
program: Program,
user_public_key: PublicKey,
) -> User:
return (await get_user_account_and_slot(program, user_public_key)).data


async def get_perp_market_account_and_slot(program: Program, market_index: int) -> Optional[DataAndSlot[PerpMarket]]:
perp_market_public_key = get_perp_market_public_key(program.program_id, market_index)
return await get_account_data_and_slot(perp_market_public_key, program)


async def get_perp_market_account(program: Program, market_index: int) -> PerpMarket:
return (await get_perp_market_account_and_slot(program, market_index)).data


async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount]:
return await program.account["PerpMarket"].all()


async def get_spot_market_account_and_slot(
program: Program, spot_market_index: int
) -> DataAndSlot[SpotMarket]:
spot_market_public_key = get_spot_market_public_key(
program.program_id, spot_market_index
)
return await get_account_data_and_slot(spot_market_public_key, program)


async def get_spot_market_account(
program: Program, spot_market_index: int
) -> SpotMarket:
return (await get_spot_market_account_and_slot(program, spot_market_index)).data


async def get_all_spot_market_accounts(program: Program) -> list[ProgramAccount]:
return await program.account["SpotMarket"].all()
Loading

0 comments on commit 474f542

Please sign in to comment.