Skip to content

Commit

Permalink
frank/resurrection (#175)
Browse files Browse the repository at this point in the history
* 0.7.57

* resurrect unpickle into working sdk

* resurrect unpickle into working sdk
  • Loading branch information
soundsonacid authored Jun 18, 2024
1 parent 76428e2 commit 6568bb9
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.7.57
current_version = 0.7.58
commit = True
tag = True
tag_name = {new_version}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "driftpy"
version = "0.7.57"
version = "0.7.58"
description = "A Python client for the Drift DEX"
authors = ["x19 <https://twitter.com/[email protected]>", "bigz <https://twitter.com/bigz_pubkey>", "frank <https://twitter.com/soundsonacid>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion src/driftpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.57"
__version__ = "0.7.58"
38 changes: 34 additions & 4 deletions src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

from anchorpy import Program

Expand All @@ -14,6 +14,8 @@
get_perp_market_account_and_slot,
)
from driftpy.constants.numeric_constants import QUOTE_SPOT_MARKET_INDEX

# from driftpy.market_map.market_map import MarketMap
from driftpy.types import (
OracleInfo,
PerpMarketAccount,
Expand All @@ -36,7 +38,7 @@ def __init__(
):
self.program = program
self.commitment = commitment
self.cache = None
self.cache = {"spot_markets": {}, "perp_markets": {}, "oracle_price_data": {}}
self.perp_market_indexes = perp_market_indexes
self.spot_market_indexes = spot_market_indexes
self.oracle_infos = oracle_infos
Expand All @@ -46,8 +48,13 @@ async def subscribe(self):
await self.update_cache()

async def update_cache(self):
if self.cache is None:
self.cache = {}
is_empty = all(not d for d in self.cache.values())
if is_empty:
self.cache = {
"spot_markets": {},
"perp_markets": {},
"oracle_price_data": {},
}

state_and_slot = await get_state_account_and_slot(self.program)
self.cache["state"] = state_and_slot
Expand Down Expand Up @@ -147,6 +154,29 @@ async def update_cache(self):
async def fetch(self):
await self.update_cache()

def resurrect(
self,
spot_markets, # MarketMap
perp_markets, # MarketMap
spot_oracles: dict[int, OraclePriceData],
perp_oracles: dict[int, OraclePriceData],
):
sort_markets = lambda markets: sorted(
markets.values(), key=lambda market: market.data.market_index
)
self.cache["spot_markets"] = sort_markets(spot_markets)
self.cache["perp_markets"] = sort_markets(perp_markets)

for market_index, oracle_price_data in spot_oracles.items():
corresponding_market = self.cache["spot_markets"][market_index]
oracle_pubkey = corresponding_market.oracle
self.cache["oracle_price_data"][str(oracle_pubkey)] = oracle_price_data

for market_index, oracle_price_data in perp_oracles.items():
corresponding_market = self.cache["perp_markets"][market_index]
oracle_pubkey = corresponding_market.amm.oracle
self.cache["oracle_price_data"][str(oracle_pubkey)] = oracle_price_data

def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.cache["state"]

Expand Down
11 changes: 11 additions & 0 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,17 @@ async def subscribe(self):
for sub_account_id in self.sub_account_ids:
await self.add_user(sub_account_id)

def resurrect(self, spot_markets, perp_markets, spot_oracles, perp_oracles):
from driftpy.accounts.cache import CachedDriftClientAccountSubscriber

if not isinstance(self.account_subscriber, CachedDriftClientAccountSubscriber):
raise ValueError(
'You can only resurrect a DriftClient that was initialized with AccountSubscriptionConfig("cached")'
)
self.account_subscriber.resurrect(
spot_markets, perp_markets, spot_oracles, perp_oracles
)

async def add_user(self, sub_account_id: int):
if sub_account_id in self.users:
return
Expand Down
19 changes: 11 additions & 8 deletions src/driftpy/pickle/vat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pickle
from driftpy.drift_client import DriftClient
from driftpy.market_map.market_map import MarketMap
from driftpy.types import OraclePriceData, PickledData
from driftpy.types import PickledData
from driftpy.user_map.user_map import UserMap
from driftpy.user_map.userstats_map import UserStatsMap

Expand All @@ -22,8 +22,8 @@ def __init__(
self.spot_markets = spot_markets
self.perp_markets = perp_markets
self.last_oracle_slot = 0
self.market_index_to_perp_price = {}
self.market_index_to_spot_price = {}
self.perp_oracles = {}
self.spot_oracles = {}

async def pickle(self):
await self.users.sync()
Expand All @@ -48,6 +48,9 @@ async def unpickle(self):
await self.spot_markets.load()
await self.perp_markets.load()

self.drift_client.resurrect(
self.spot_markets, self.perp_markets, self.spot_oracles, self.perp_oracles
)
self.load_oracles()

async def dump_oracles(self):
Expand All @@ -62,14 +65,14 @@ async def dump_oracles(self):

spot_oracles = []
for market in self.drift_client.get_spot_market_accounts():
oracle_price = self.drift_client.get_oracle_price_data_for_spot_market(
oracle_price_data = self.drift_client.get_oracle_price_data_for_spot_market(
market.market_index
)
spot_oracles.append(
PickledData(pubkey=market.market_index, data=oracle_price)
PickledData(pubkey=market.market_index, data=oracle_price_data)
)

self.last_oracle_slot = await self.drift_client.connection.get_slot()
self.last_oracle_slot = (await self.drift_client.connection.get_slot()).value

with open(f"perporacles_{self.last_oracle_slot}.pkl", "wb") as f:
pickle.dump(perp_oracles, f)
Expand All @@ -81,9 +84,9 @@ def load_oracles(self):
with open(f"perporacles_{self.last_oracle_slot}.pkl", "rb") as f:
perp_oracles: list[PickledData] = pickle.load(f)
for oracle in perp_oracles:
self.market_index_to_perp_price[oracle.pubkey] = oracle.data
self.perp_oracles[oracle.pubkey] = oracle.data

with open(f"spotoracles_{self.last_oracle_slot}.pkl", "rb") as f:
spot_oracles: list[PickledData] = pickle.load(f)
for oracle in spot_oracles:
self.market_index_to_spot_price[oracle.pubkey] = oracle.data
self.spot_oracles[oracle.pubkey] = oracle.data

0 comments on commit 6568bb9

Please sign in to comment.