Skip to content

Commit

Permalink
feat: unsubscribe from swap ids in WebSocket (#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
michael1011 authored Nov 29, 2024
1 parent 3e41c1c commit d273c04
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 56 deletions.
9 changes: 5 additions & 4 deletions boltzr/src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use crate::api::sse::sse_handler;
use crate::ws::status::SwapInfos;
use crate::ws::types::SwapStatus;
use axum::routing::get;
use axum::{Extension, Router};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use ws::status::SwapInfos;
use ws::types::SwapStatus;

#[cfg(feature = "metrics")]
use crate::metrics::server::MetricsLayer;

mod sse;
pub mod ws;

#[derive(Deserialize, Serialize, PartialEq, Clone, Debug)]
pub struct Config {
Expand Down Expand Up @@ -100,9 +101,9 @@ where

#[cfg(test)]
mod test {
use crate::api::ws::status::SwapInfos;
use crate::api::ws::types::SwapStatus;
use crate::api::{Config, Server};
use crate::ws::status::SwapInfos;
use crate::ws::types::SwapStatus;
use async_trait::async_trait;
use reqwest::StatusCode;
use std::time::Duration;
Expand Down
4 changes: 2 additions & 2 deletions boltzr/src/api/sse.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::api::ws::status::SwapInfos;
use crate::api::ServerState;
use crate::ws::status::SwapInfos;
use async_stream::try_stream;
use axum::response::sse::{Event, Sse};
use axum::{extract::Query, Extension};
Expand Down Expand Up @@ -75,7 +75,7 @@ where
#[cfg(test)]
mod test {
use crate::api::test::start;
use crate::ws::types::SwapStatus;
use crate::api::ws::types::SwapStatus;
use eventsource_client::{Client, SSE};
use futures_util::StreamExt;

Expand Down
File renamed without changes.
203 changes: 166 additions & 37 deletions boltzr/src/ws/status.rs → boltzr/src/api/ws/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn};

use crate::ws::types::SwapStatus;
use crate::ws::Config;
use crate::api::ws::types::SwapStatus;
use crate::api::ws::Config;

const PING_INTERVAL_MS: u64 = 15_000;

Expand Down Expand Up @@ -43,35 +43,59 @@ enum SubscriptionChannel {
}

#[derive(Deserialize, Serialize, Debug, PartialEq)]
struct SubscribeMessage {
struct SubscribeRequest {
channel: SubscriptionChannel,
args: Vec<String>,
}

#[serde(skip_serializing_if = "Option::is_none")]
timestamp: Option<String>,
#[derive(Deserialize, Serialize, Debug, PartialEq)]
struct UnsubscribeRequest {
channel: SubscriptionChannel,
args: Vec<String>,
}

#[derive(Deserialize, Debug)]
#[serde(tag = "op")]
enum WsRequest {
#[serde(rename = "subscribe")]
Subscribe(SubscribeMessage),
Subscribe(SubscribeRequest),
#[serde(rename = "unsubscribe")]
Unsubscribe(UnsubscribeRequest),
}

#[derive(Deserialize, Serialize, Debug, PartialEq)]
struct SubscribeResponse {
channel: SubscriptionChannel,
args: Vec<String>,

timestamp: String,
}

#[derive(Deserialize, Serialize, Debug, PartialEq)]
struct UnsubscribeResponse {
channel: SubscriptionChannel,
args: Vec<String>,

timestamp: String,
}

#[derive(Deserialize, Serialize, Debug, PartialEq)]
struct UpdateMessage {
struct UpdateResponse {
channel: SubscriptionChannel,
args: Vec<SwapStatus>,

timestamp: String,
}

#[derive(Deserialize, Serialize, Debug, PartialEq)]
#[serde(tag = "event")]
enum WsResponse {
#[serde(rename = "subscribe")]
Subscribe(SubscribeMessage),
Subscribe(SubscribeResponse),
#[serde(rename = "unsubscribe")]
Unsubscribe(UnsubscribeResponse),
#[serde(rename = "update")]
Update(UpdateMessage),
Update(UpdateResponse),
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -224,7 +248,7 @@ where
}
};

let msg = match serde_json::to_string(&WsResponse::Update(UpdateMessage {
let msg = match serde_json::to_string(&WsResponse::Update(UpdateResponse {
timestamp,
channel: SubscriptionChannel::SwapUpdate,
args: relevant_updates,
Expand Down Expand Up @@ -276,38 +300,48 @@ where
};
trace!("Got message: {:?}", msg);

let get_timestamp = || match Self::get_timestamp() {
Ok(res) => Some(res),
Err(err) => {
error!("Could not get UNIX time: {}", err);
None
}
};

match msg {
WsRequest::Subscribe(sub) => match sub.channel {
SubscriptionChannel::SwapUpdate => {
self.subscribe_swap_updates(subscribed_ids, &sub.args).await;
for id in &sub.args {
subscribed_ids.insert(id.clone());
}

let timestamp = match Self::get_timestamp() {
Ok(res) => res,
Err(err) => {
error!("Could not get UNIX time: {}", err);
return Ok(None);
}
};
Ok(Some(WsResponse::Subscribe(SubscribeMessage {
timestamp: Some(timestamp),
channel: SubscriptionChannel::SwapUpdate,
self.swap_infos.fetch_status_info(&sub.args).await;

Ok(Some(WsResponse::Subscribe(SubscribeResponse {
timestamp: match get_timestamp() {
Some(time) => time,
None => return Ok(None),
},
args: sub.args,
channel: SubscriptionChannel::SwapUpdate,
})))
}
},
}
}
WsRequest::Unsubscribe(unsub) => {
for id in &unsub.args {
subscribed_ids.remove(id);
}

async fn subscribe_swap_updates(
&self,
subscribed_ids: &mut HashSet<String>,
ids: &Vec<String>,
) {
for id in ids {
subscribed_ids.insert(id.clone());
Ok(Some(WsResponse::Unsubscribe(UnsubscribeResponse {
timestamp: match get_timestamp() {
Some(time) => time,
None => return Ok(None),
},
channel: SubscriptionChannel::SwapUpdate,
args: subscribed_ids.iter().cloned().collect(),
})))
}
}

self.swap_infos.fetch_status_info(ids).await;
}

fn get_timestamp() -> Result<String, SystemTimeError> {
Expand All @@ -320,9 +354,11 @@ where

#[cfg(test)]
mod status_test {
use crate::ws::status::{ErrorResponse, Status, SubscriptionChannel, SwapInfos, WsResponse};
use crate::ws::types::SwapStatus;
use crate::ws::Config;
use crate::api::ws::status::{
ErrorResponse, Status, SubscriptionChannel, SwapInfos, WsResponse,
};
use crate::api::ws::types::SwapStatus;
use crate::api::ws::Config;
use async_trait::async_trait;
use async_tungstenite::tungstenite::Message;
use futures::{SinkExt, StreamExt};
Expand Down Expand Up @@ -479,7 +515,7 @@ mod status_test {
assert_eq!(res.channel, SubscriptionChannel::SwapUpdate);
assert_eq!(res.args, vec!["some".to_string(), "ids".to_string(),]);
assert!(
res.timestamp.unwrap().parse::<u128>().unwrap()
res.timestamp.parse::<u128>().unwrap()
<= SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
Expand Down Expand Up @@ -550,7 +586,7 @@ mod status_test {
"ids".into(),
"invoice.set".into(),
)])
.unwrap()
.unwrap();
});

let mut count = 0;
Expand Down Expand Up @@ -592,6 +628,99 @@ mod status_test {
cancel.cancel();
}

#[tokio::test]
async fn test_unsubscribe() {
let port = 12_006;
let (cancel, update_tx) = create_server(port).await;

let (client, _) =
async_tungstenite::tokio::connect_async(format!("ws://127.0.0.1:{}", port))
.await
.unwrap();

let (mut tx, mut rx) = client.split();

tokio::spawn(async move {
tx.send(Message::Text(
json!({
"op": "subscribe",
"channel": "swap.update",
"args": vec!["some", "ids"],
})
.to_string(),
))
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;

update_tx
.send(vec![SwapStatus::default(
"ids".into(),
"invoice.set".into(),
)])
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;

tx.send(Message::Text(
json!({
"op": "unsubscribe",
"channel": "swap.update",
"args": vec!["ids"],
})
.to_string(),
))
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;

update_tx
.send(vec![SwapStatus::default(
"ids".into(),
"transaction.mempool".into(),
)])
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;

cancel.cancel();
});

let mut update_count = 0;
let mut unsubscribe_sent = false;

loop {
let msg = match rx.next().await {
Some(msg) => match msg {
Ok(msg) => msg,
Err(_) => break,
},
None => continue,
};
if !msg.is_text() {
continue;
}

let res = serde_json::from_str::<WsResponse>(msg.to_text().unwrap()).unwrap();

match res {
WsResponse::Update(_) => {
update_count += 1;
}
WsResponse::Unsubscribe(msg) => {
assert_eq!(msg.channel, SubscriptionChannel::SwapUpdate);
assert_eq!(msg.args, vec!["some".to_string()]);

unsubscribe_sent = true;
}
_ => {}
}
}

assert!(unsubscribe_sent);

// One for the initial update and one for the update that was sent before the unsubscribe
assert_eq!(update_count, 2);
}

async fn create_server(port: u16) -> (CancellationToken, Sender<Vec<SwapStatus>>) {
let cancel = CancellationToken::new();
let (status_tx, _status_rx) = tokio::sync::broadcast::channel::<Vec<SwapStatus>>(1);
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions boltzr/src/grpc/server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::api::ws::types::SwapStatus;
use crate::db::helpers::web_hook::WebHookHelper;
use crate::evm::refund_signer::RefundSigner;
use crate::grpc::service::boltzr::boltz_r_server::BoltzRServer;
Expand All @@ -8,7 +9,6 @@ use crate::notifications::NotificationClient;
use crate::swap::manager::SwapManager;
use crate::tracing_setup::ReloadHandler;
use crate::webhook::caller::Caller;
use crate::ws::types::SwapStatus;
use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::error::Error;
Expand Down Expand Up @@ -150,6 +150,7 @@ where

#[cfg(test)]
mod server_test {
use crate::api::ws;
use crate::chain::utils::Transaction;
use crate::currencies::Currency;
use crate::db::helpers::web_hook::WebHookHelper;
Expand All @@ -163,7 +164,6 @@ mod server_test {
use crate::swap::manager::SwapManager;
use crate::tracing_setup::ReloadHandler;
use crate::webhook::caller;
use crate::ws;
use alloy::primitives::{Address, FixedBytes, Signature, U256};
use async_trait::async_trait;
use mockall::{mock, predicate::*};
Expand Down
4 changes: 2 additions & 2 deletions boltzr/src/grpc/service.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::api::ws::types::SwapStatus;
use crate::db::helpers::web_hook::WebHookHelper;
use crate::db::models::{WebHook, WebHookState};
use crate::evm::refund_signer::RefundSigner;
Expand All @@ -19,7 +20,6 @@ use crate::notifications::NotificationClient;
use crate::swap::manager::SwapManager;
use crate::tracing_setup::ReloadHandler;
use crate::webhook::caller::Caller;
use crate::ws::types::SwapStatus;
use alloy::primitives::{Address, FixedBytes};
use futures::StreamExt;
use lightning::blinded_path::IntroductionNode;
Expand Down Expand Up @@ -619,6 +619,7 @@ fn extract_parent_context<T>(request: &Request<T>) {

#[cfg(test)]
mod test {
use crate::api::ws;
use crate::chain::utils::Transaction;
use crate::currencies::Currency;
use crate::db::helpers::web_hook::WebHookHelper;
Expand All @@ -637,7 +638,6 @@ mod test {
use crate::swap::manager::SwapManager;
use crate::tracing_setup::ReloadHandler;
use crate::webhook::caller::{Caller, Config};
use crate::ws;
use alloy::primitives::{Address, FixedBytes, Signature, U256};
use alloy::signers::k256;
use async_trait::async_trait;
Expand Down
Loading

0 comments on commit d273c04

Please sign in to comment.