diff --git a/backend/canisters/exchange_bot/impl/src/commands/quote.rs b/backend/canisters/exchange_bot/impl/src/commands/quote.rs index a77e12d087..98d23c54e6 100644 --- a/backend/canisters/exchange_bot/impl/src/commands/quote.rs +++ b/backend/canisters/exchange_bot/impl/src/commands/quote.rs @@ -1,5 +1,5 @@ use crate::commands::common_errors::CommonErrors; -use crate::commands::sub_tasks::get_quotes::get_quote; +use crate::commands::sub_tasks::get_quotes::get_quotes; use crate::commands::{build_error_response, Command, CommandParser, CommandSubTaskResult, ParseMessageResult}; use crate::swap_client::SwapClient; use crate::{mutate_state, RuntimeState}; @@ -10,7 +10,6 @@ use rand::Rng; use regex::{Regex, RegexBuilder}; use serde::{Deserialize, Serialize}; use std::str::FromStr; -use std::sync::{Arc, Mutex}; use types::{MessageContent, MessageId, TimestampMillis, TokenInfo, UserId}; lazy_static! { @@ -108,23 +107,13 @@ impl QuoteCommand { pub(crate) fn process(self, state: &mut RuntimeState) { let amount = self.amount; - let output_token_decimals = self.output_token.decimals; let clients: Vec<_> = self .exchange_ids .iter() .filter_map(|e| state.get_swap_client(*e, self.input_token.clone(), self.output_token.clone())) .collect(); - let command = Arc::new(Mutex::new(self)); - - let futures: Vec<_> = clients - .into_iter() - .map(|c| quote_single(c, amount, output_token_decimals, command.clone())) - .collect(); - - ic_cdk::spawn(async { - futures::future::join_all(futures).await; - }); + ic_cdk::spawn(self.get_quotes(clients, amount)); } pub fn build_message_text(&self) -> String { @@ -142,27 +131,20 @@ impl QuoteCommand { text } + async fn get_quotes(mut self, clients: Vec>, amount: u128) { + get_quotes(clients, amount, self.output_token.decimals, |exchange_id, result| { + self.set_quote_result(exchange_id, result); + let message_text = self.build_message_text(); + mutate_state(|state| { + state.enqueue_message_edit(self.user_id, self.message_id, message_text); + }); + }) + .await + } + fn set_quote_result(&mut self, exchange_id: ExchangeId, result: CommandSubTaskResult) { if let Some(r) = self.results.iter_mut().find(|(e, _)| *e == exchange_id).map(|(_, s)| s) { *r = result; } } } - -async fn quote_single( - client: Box, - amount: u128, - output_token_decimals: u8, - wrapped_command: Arc>, -) { - let result = get_quote(client.as_ref(), amount, output_token_decimals).await; - - let mut command = wrapped_command.lock().unwrap(); - command.set_quote_result(client.exchange_id(), result); - - let message_text = command.build_message_text(); - - mutate_state(|state| { - state.enqueue_message_edit(command.user_id, command.message_id, message_text); - }) -} diff --git a/backend/canisters/exchange_bot/impl/src/commands/sub_tasks/get_quotes.rs b/backend/canisters/exchange_bot/impl/src/commands/sub_tasks/get_quotes.rs index 2b510a5204..f377f64cd4 100644 --- a/backend/canisters/exchange_bot/impl/src/commands/sub_tasks/get_quotes.rs +++ b/backend/canisters/exchange_bot/impl/src/commands/sub_tasks/get_quotes.rs @@ -1,14 +1,44 @@ use crate::commands::CommandSubTaskResult; use crate::swap_client::SwapClient; +use exchange_bot_canister::ExchangeId; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use ledger_utils::format_crypto_amount; +use std::future::ready; -pub(crate) async fn get_quote(client: &dyn SwapClient, amount: u128, output_token_decimals: u8) -> CommandSubTaskResult { +pub(crate) async fn get_quotes)>( + clients: Vec>, + amount: u128, + output_token_decimals: u8, + mut callback: C, +) { + let futures = FuturesUnordered::new(); + for client in clients { + futures.push(get_quote(client, amount, output_token_decimals)); + } + + futures + .for_each(|(exchange_id, result)| { + callback(exchange_id, result); + ready(()) + }) + .await; +} + +async fn get_quote( + client: Box, + amount: u128, + output_token_decimals: u8, +) -> (ExchangeId, CommandSubTaskResult) { + let exchange_id = client.exchange_id(); let response = client.quote(amount).await; - match response { + let result = match response { Ok(amount_out) => { CommandSubTaskResult::Complete(amount_out, Some(format_crypto_amount(amount_out, output_token_decimals))) } Err(error) => CommandSubTaskResult::Failed(format!("{error:?}")), - } + }; + + (exchange_id, result) } diff --git a/backend/canisters/exchange_bot/impl/src/commands/swap.rs b/backend/canisters/exchange_bot/impl/src/commands/swap.rs index 92674bad37..61029acbdd 100644 --- a/backend/canisters/exchange_bot/impl/src/commands/swap.rs +++ b/backend/canisters/exchange_bot/impl/src/commands/swap.rs @@ -1,6 +1,6 @@ use crate::commands::common_errors::CommonErrors; use crate::commands::sub_tasks::check_user_balance::check_user_balance; -use crate::commands::sub_tasks::get_quotes::get_quote; +use crate::commands::sub_tasks::get_quotes::get_quotes; use crate::commands::{Command, CommandParser, CommandSubTaskResult, ParseMessageResult}; use crate::swap_client::SwapClient; use crate::{mutate_state, Data, RuntimeState}; @@ -10,7 +10,6 @@ use rand::Rng; use regex::{Regex, RegexBuilder}; use serde::{Deserialize, Serialize}; use std::str::FromStr; -use std::sync::{Arc, Mutex}; use types::icrc1::BlockIndex; use types::{CanisterId, MessageContent, MessageId, TimestampMillis, TokenInfo, UserId}; @@ -127,14 +126,17 @@ impl SwapCommand { pub(crate) fn process(self, state: &mut RuntimeState) { if self.sub_tasks.check_user_balance.is_pending() { ic_cdk::spawn(self.check_user_balance(state.env.canister_id())); - } else if let Some(amount) = self.sub_tasks.quotes.is_pending().then_some(self.amount()).flatten() { - let clients: Vec<_> = self - .exchange_ids - .iter() - .filter_map(|e| state.get_swap_client(*e, self.input_token.clone(), self.output_token.clone())) - .collect(); - - ic_cdk::spawn(get_quotes(self, clients, amount)); + } else if let Some(amount) = self.amount() { + if self.sub_tasks.quotes.is_pending() { + let clients: Vec<_> = self + .exchange_ids + .iter() + .filter_map(|e| state.get_swap_client(*e, self.input_token.clone(), self.output_token.clone())) + .collect(); + + ic_cdk::spawn(self.get_quotes(clients, amount)); + } else if self.sub_tasks.transfer_to_dex.is_pending() { + } } } @@ -154,6 +156,25 @@ impl SwapCommand { mutate_state(|state| self.on_updated(state)); } + async fn get_quotes(mut self, clients: Vec>, amount: u128) { + get_quotes(clients, amount, self.output_token.decimals, |exchange_id, result| { + self.set_quote_result(exchange_id, result); + let message_text = self.build_message_text(); + mutate_state(|state| { + state.enqueue_message_edit(self.user_id, self.message_id, message_text); + }); + }) + .await; + + if let Some((exchange_id, CommandSubTaskResult::Complete(..))) = self.quotes.iter().max_by_key(|(_, r)| r.value()) { + self.sub_tasks.quotes = CommandSubTaskResult::Complete(*exchange_id, Some(exchange_id.to_string())); + } else { + self.sub_tasks.quotes = CommandSubTaskResult::Failed("Failed to get any valid quotes".to_string()); + } + + mutate_state(|state| self.on_updated(state)); + } + fn on_updated(self, state: &mut RuntimeState) { let is_finished = self.is_finished(); @@ -186,50 +207,6 @@ impl SwapCommand { } } -async fn get_quotes(command: SwapCommand, clients: Vec>, amount: u128) { - let output_token_decimals = command.output_token.decimals; - let wrapped_command = Arc::new(Mutex::new(command)); - - let futures: Vec<_> = clients - .into_iter() - .map(|c| quote_single(c, amount, output_token_decimals, wrapped_command.clone())) - .collect(); - - futures::future::join_all(futures).await; - - let mut command = Arc::try_unwrap(wrapped_command) - .map_err(|_| ()) - .unwrap() - .into_inner() - .unwrap(); - - if let Some((exchange_id, CommandSubTaskResult::Complete(..))) = command.quotes.iter().max_by_key(|(_, r)| r.value()) { - command.sub_tasks.quotes = CommandSubTaskResult::Complete(*exchange_id, Some(exchange_id.to_string())); - } else { - command.sub_tasks.quotes = CommandSubTaskResult::Failed("Failed to get any valid quotes".to_string()); - } - - mutate_state(|state| command.on_updated(state)); -} - -async fn quote_single( - client: Box, - amount: u128, - output_token_decimals: u8, - wrapped_command: Arc>, -) { - let result = get_quote(client.as_ref(), amount, output_token_decimals).await; - - let mut command = wrapped_command.lock().unwrap(); - command.set_quote_result(client.exchange_id(), result); - - let message_text = command.build_message_text(); - - mutate_state(|state| { - state.enqueue_message_edit(command.user_id, command.message_id, message_text); - }) -} - fn build_error_response(error: CommonErrors, data: &Data) -> ParseMessageResult { let response_message = error.build_response_message(data); ParseMessageResult::Error(data.build_text_response(response_message, None))