Skip to content

Commit

Permalink
Consolidate getting quotes
Browse files Browse the repository at this point in the history
  • Loading branch information
hpeebles committed Sep 19, 2023
1 parent 16fd3dc commit 579c565
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 88 deletions.
44 changes: 13 additions & 31 deletions backend/canisters/exchange_bot/impl/src/commands/quote.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::commands::common_errors::CommonErrors;
use crate::commands::sub_tasks::get_quotes::get_quote;
use crate::commands::sub_tasks::get_quotes::get_quotes;
use crate::commands::{build_error_response, Command, CommandParser, CommandSubTaskResult, ParseMessageResult};
use crate::swap_client::SwapClient;
use crate::{mutate_state, RuntimeState};
Expand All @@ -10,7 +10,6 @@ use rand::Rng;
use regex::{Regex, RegexBuilder};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use types::{MessageContent, MessageId, TimestampMillis, TokenInfo, UserId};

lazy_static! {
Expand Down Expand Up @@ -108,23 +107,13 @@ impl QuoteCommand {

pub(crate) fn process(self, state: &mut RuntimeState) {
let amount = self.amount;
let output_token_decimals = self.output_token.decimals;
let clients: Vec<_> = self
.exchange_ids
.iter()
.filter_map(|e| state.get_swap_client(*e, self.input_token.clone(), self.output_token.clone()))
.collect();

let command = Arc::new(Mutex::new(self));

let futures: Vec<_> = clients
.into_iter()
.map(|c| quote_single(c, amount, output_token_decimals, command.clone()))
.collect();

ic_cdk::spawn(async {
futures::future::join_all(futures).await;
});
ic_cdk::spawn(self.get_quotes(clients, amount));
}

pub fn build_message_text(&self) -> String {
Expand All @@ -142,27 +131,20 @@ impl QuoteCommand {
text
}

async fn get_quotes(mut self, clients: Vec<Box<dyn SwapClient>>, amount: u128) {
get_quotes(clients, amount, self.output_token.decimals, |exchange_id, result| {
self.set_quote_result(exchange_id, result);
let message_text = self.build_message_text();
mutate_state(|state| {
state.enqueue_message_edit(self.user_id, self.message_id, message_text);
});
})
.await
}

fn set_quote_result(&mut self, exchange_id: ExchangeId, result: CommandSubTaskResult<u128>) {
if let Some(r) = self.results.iter_mut().find(|(e, _)| *e == exchange_id).map(|(_, s)| s) {
*r = result;
}
}
}

async fn quote_single(
client: Box<dyn SwapClient>,
amount: u128,
output_token_decimals: u8,
wrapped_command: Arc<Mutex<QuoteCommand>>,
) {
let result = get_quote(client.as_ref(), amount, output_token_decimals).await;

let mut command = wrapped_command.lock().unwrap();
command.set_quote_result(client.exchange_id(), result);

let message_text = command.build_message_text();

mutate_state(|state| {
state.enqueue_message_edit(command.user_id, command.message_id, message_text);
})
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
use crate::commands::CommandSubTaskResult;
use crate::swap_client::SwapClient;
use exchange_bot_canister::ExchangeId;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use ledger_utils::format_crypto_amount;
use std::future::ready;

pub(crate) async fn get_quote(client: &dyn SwapClient, amount: u128, output_token_decimals: u8) -> CommandSubTaskResult<u128> {
pub(crate) async fn get_quotes<C: FnMut(ExchangeId, CommandSubTaskResult<u128>)>(
clients: Vec<Box<dyn SwapClient>>,
amount: u128,
output_token_decimals: u8,
mut callback: C,
) {
let futures = FuturesUnordered::new();
for client in clients {
futures.push(get_quote(client, amount, output_token_decimals));
}

futures
.for_each(|(exchange_id, result)| {
callback(exchange_id, result);
ready(())
})
.await;
}

async fn get_quote(
client: Box<dyn SwapClient>,
amount: u128,
output_token_decimals: u8,
) -> (ExchangeId, CommandSubTaskResult<u128>) {
let exchange_id = client.exchange_id();
let response = client.quote(amount).await;

match response {
let result = match response {
Ok(amount_out) => {
CommandSubTaskResult::Complete(amount_out, Some(format_crypto_amount(amount_out, output_token_decimals)))
}
Err(error) => CommandSubTaskResult::Failed(format!("{error:?}")),
}
};

(exchange_id, result)
}
85 changes: 31 additions & 54 deletions backend/canisters/exchange_bot/impl/src/commands/swap.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::commands::common_errors::CommonErrors;
use crate::commands::sub_tasks::check_user_balance::check_user_balance;
use crate::commands::sub_tasks::get_quotes::get_quote;
use crate::commands::sub_tasks::get_quotes::get_quotes;
use crate::commands::{Command, CommandParser, CommandSubTaskResult, ParseMessageResult};
use crate::swap_client::SwapClient;
use crate::{mutate_state, Data, RuntimeState};
Expand All @@ -10,7 +10,6 @@ use rand::Rng;
use regex::{Regex, RegexBuilder};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use types::icrc1::BlockIndex;
use types::{CanisterId, MessageContent, MessageId, TimestampMillis, TokenInfo, UserId};

Expand Down Expand Up @@ -127,14 +126,17 @@ impl SwapCommand {
pub(crate) fn process(self, state: &mut RuntimeState) {
if self.sub_tasks.check_user_balance.is_pending() {
ic_cdk::spawn(self.check_user_balance(state.env.canister_id()));
} else if let Some(amount) = self.sub_tasks.quotes.is_pending().then_some(self.amount()).flatten() {
let clients: Vec<_> = self
.exchange_ids
.iter()
.filter_map(|e| state.get_swap_client(*e, self.input_token.clone(), self.output_token.clone()))
.collect();

ic_cdk::spawn(get_quotes(self, clients, amount));
} else if let Some(amount) = self.amount() {
if self.sub_tasks.quotes.is_pending() {
let clients: Vec<_> = self
.exchange_ids
.iter()
.filter_map(|e| state.get_swap_client(*e, self.input_token.clone(), self.output_token.clone()))
.collect();

ic_cdk::spawn(self.get_quotes(clients, amount));
} else if self.sub_tasks.transfer_to_dex.is_pending() {
}
}
}

Expand All @@ -154,6 +156,25 @@ impl SwapCommand {
mutate_state(|state| self.on_updated(state));
}

async fn get_quotes(mut self, clients: Vec<Box<dyn SwapClient>>, amount: u128) {
get_quotes(clients, amount, self.output_token.decimals, |exchange_id, result| {
self.set_quote_result(exchange_id, result);
let message_text = self.build_message_text();
mutate_state(|state| {
state.enqueue_message_edit(self.user_id, self.message_id, message_text);
});
})
.await;

if let Some((exchange_id, CommandSubTaskResult::Complete(..))) = self.quotes.iter().max_by_key(|(_, r)| r.value()) {
self.sub_tasks.quotes = CommandSubTaskResult::Complete(*exchange_id, Some(exchange_id.to_string()));
} else {
self.sub_tasks.quotes = CommandSubTaskResult::Failed("Failed to get any valid quotes".to_string());
}

mutate_state(|state| self.on_updated(state));
}

fn on_updated(self, state: &mut RuntimeState) {
let is_finished = self.is_finished();

Expand Down Expand Up @@ -186,50 +207,6 @@ impl SwapCommand {
}
}

async fn get_quotes(command: SwapCommand, clients: Vec<Box<dyn SwapClient>>, amount: u128) {
let output_token_decimals = command.output_token.decimals;
let wrapped_command = Arc::new(Mutex::new(command));

let futures: Vec<_> = clients
.into_iter()
.map(|c| quote_single(c, amount, output_token_decimals, wrapped_command.clone()))
.collect();

futures::future::join_all(futures).await;

let mut command = Arc::try_unwrap(wrapped_command)
.map_err(|_| ())
.unwrap()
.into_inner()
.unwrap();

if let Some((exchange_id, CommandSubTaskResult::Complete(..))) = command.quotes.iter().max_by_key(|(_, r)| r.value()) {
command.sub_tasks.quotes = CommandSubTaskResult::Complete(*exchange_id, Some(exchange_id.to_string()));
} else {
command.sub_tasks.quotes = CommandSubTaskResult::Failed("Failed to get any valid quotes".to_string());
}

mutate_state(|state| command.on_updated(state));
}

async fn quote_single(
client: Box<dyn SwapClient>,
amount: u128,
output_token_decimals: u8,
wrapped_command: Arc<Mutex<SwapCommand>>,
) {
let result = get_quote(client.as_ref(), amount, output_token_decimals).await;

let mut command = wrapped_command.lock().unwrap();
command.set_quote_result(client.exchange_id(), result);

let message_text = command.build_message_text();

mutate_state(|state| {
state.enqueue_message_edit(command.user_id, command.message_id, message_text);
})
}

fn build_error_response(error: CommonErrors, data: &Data) -> ParseMessageResult {
let response_message = error.build_response_message(data);
ParseMessageResult::Error(data.build_text_response(response_message, None))
Expand Down

0 comments on commit 579c565

Please sign in to comment.