Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
seemenkina committed Aug 9, 2024
1 parent ebe7085 commit c4e469f
Show file tree
Hide file tree
Showing 13 changed files with 445 additions and 838 deletions.
99 changes: 52 additions & 47 deletions ds/src/chat_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,62 +8,82 @@ use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::protocol::Message;

use crate::chat_server::ServerMessage;
use crate::ChatServiceError;
use crate::DeliveryServiceError;

// pub const REQUEST: &str = "You are joining the group with smart contract: ";

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub enum ChatMessages {
Request(RequestMLSPayload),
Response(ResponseMLSPayload),
Welcome(String),
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub enum ReqMessageType {
InviteToGroup,
RemoveFromGroup,
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct RequestMLSPayload {
pub msg: String,
sc_address: String,
group_name: String,
pub msg_type: ReqMessageType,
}

impl RequestMLSPayload {
pub fn new(sc_address: String, msg_type: ReqMessageType) -> Self {
pub fn new(sc_address: String, group_name: String, msg_type: ReqMessageType) -> Self {
RequestMLSPayload {
msg: sc_address,
sc_address,
group_name,
msg_type,
}
}

pub fn msg_to_sign(&self) -> String {
self.sc_address.to_owned() + &self.group_name
}

pub fn group_name(&self) -> String {
self.group_name.clone()
}
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct ResponseMLSPayload {
signature: String,
user_address: String,
pub group_name: String,
key_package: Vec<u8>,
}

impl ResponseMLSPayload {
pub fn new(signature: String, user_address: String, key_package: Vec<u8>) -> Self {
pub fn new(
signature: String,
user_address: String,
group_name: String,
key_package: Vec<u8>,
) -> Self {
Self {
signature,
user_address,
group_name,
key_package,
}
}

pub fn validate(&self, sc_address: String) -> Result<(String, Vec<u8>), ChatServiceError> {
pub fn validate(
&self,
sc_address: String,
group_name: String,
) -> Result<(String, Vec<u8>), DeliveryServiceError> {
let recover_sig: Signature = serde_json::from_str(&self.signature)?;
let addr = Address::from_str(&self.user_address)?;
// Recover the signer from the message.
let recovered = recover_sig.recover_address_from_msg(sc_address)?;
let recovered =
recover_sig.recover_address_from_msg(sc_address.to_owned() + &group_name)?;

if recovered.ne(&addr) {
return Err(ChatServiceError::ValidationError);
return Err(DeliveryServiceError::ValidationError(recovered.to_string()));
}
Ok((self.user_address.clone(), self.key_package.clone()))
}
Expand All @@ -75,10 +95,10 @@ pub struct ChatClient {

impl ChatClient {
pub async fn connect(
addr: &str,
username: &str,
) -> Result<(Self, mpsc::UnboundedReceiver<Message>), ChatServiceError> {
let (ws_stream, _) = tokio_tungstenite::connect_async(addr).await?;
server_addr: &str,
username: String,
) -> Result<(Self, mpsc::UnboundedReceiver<Message>), DeliveryServiceError> {
let (ws_stream, _) = tokio_tungstenite::connect_async(server_addr).await?;
let (mut write, read) = ws_stream.split();
let (sender, receiver) = mpsc::unbounded_channel();
let (msg_sender, msg_receiver) = mpsc::unbounded_channel();
Expand All @@ -88,23 +108,22 @@ impl ChatClient {
// Spawn a task to handle outgoing messages
tokio::spawn(async move {
while let Some(message) = receiver.lock().await.recv().await {
// println!("Message from reciever: {}", message);
if let Err(e) = write.send(message).await {
eprintln!("Error sending message: {}", e);
if let Err(err) = write.send(message).await {
return Err(DeliveryServiceError::SenderError(err.to_string()));
}
}
Ok(())
});

// Spawn a task to handle incoming messages
tokio::spawn(async move {
let mut read = read;
while let Some(message) = read.next().await {
if let Ok(msg) = message {
if let Err(e) = msg_sender.send(msg) {
eprintln!("Failed to send message to channel: {}", e);
}
while let Some(Ok(message)) = read.next().await {
if let Err(err) = msg_sender.send(message) {
return Err(DeliveryServiceError::SenderError(err.to_string()));
}
}
Ok(())
});

// Send a SystemJoin message when registering
Expand All @@ -114,26 +133,16 @@ impl ChatClient {
let join_json = serde_json::to_string(&join_msg).unwrap();
sender
.send(Message::Text(join_json))
.map_err(|_| ChatServiceError::SendError)?;
.map_err(|err| DeliveryServiceError::SenderError(err.to_string()))?;

Ok((ChatClient { sender }, msg_receiver))
}

pub async fn send_request(&self, msg: ServerMessage) -> Result<(), ChatServiceError> {
self.send_message_to_server(msg)?;
Ok(())
}

pub async fn handle_response(&self) -> Result<(), ChatServiceError> {
Ok(())
}

pub fn send_message_to_server(&self, msg: ServerMessage) -> Result<(), ChatServiceError> {
pub fn send_message(&self, msg: ServerMessage) -> Result<(), DeliveryServiceError> {
let msg_json = serde_json::to_string(&msg).unwrap();
// println!("Message to sender: {}", msg_json);
self.sender
.send(Message::Text(msg_json))
.map_err(|_| ChatServiceError::SendError)?;
.map_err(|err| DeliveryServiceError::SenderError(err.to_string()))?;
Ok(())
}
}
Expand Down Expand Up @@ -163,6 +172,7 @@ fn test_sign() {
fn json_test() {
let inner_msg = ChatMessages::Request(RequestMLSPayload::new(
"sc_address".to_string(),
"group_name".to_string(),
ReqMessageType::InviteToGroup,
));

Expand All @@ -183,19 +193,14 @@ fn json_test() {
////

if let Ok(chat_message) = serde_json::from_str::<ServerMessage>(&json_server_msg) {
println!("Server: {:?}", chat_message);
assert_eq!(chat_message, server_msg);
match chat_message {
ServerMessage::InMessage { from, to, msg } => {
println!("Chat: {:?}", msg);
if let Ok(chat_msg) = serde_json::from_str::<ChatMessages>(&msg) {
match chat_msg {
ChatMessages::Request(req) => println!("Request: {:?}", req),
ChatMessages::Response(_) => println!("Response"),
ChatMessages::Welcome(_) => println!("Welcome"),
}
assert_eq!(chat_msg, inner_msg);
}
}
ServerMessage::SystemJoin { username } => println!("SystemJoin"),
ServerMessage::SystemJoin { username } => {}
}
}
}
24 changes: 10 additions & 14 deletions ds/src/chat_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use tokio::{
};
use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};

use crate::ChatServiceError;
use crate::DeliveryServiceError;

type Tx = mpsc::UnboundedSender<Message>;
type PeerMap = Arc<Mutex<HashMap<String, Tx>>>;

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "type")]
pub enum ServerMessage {
InMessage {
Expand All @@ -26,7 +26,7 @@ pub enum ServerMessage {
},
}

pub async fn start_server(addr: &str) -> Result<(), ChatServiceError> {
pub async fn start_server(addr: &str) -> Result<(), DeliveryServiceError> {
let listener = TcpListener::bind(addr).await?;
let peers = PeerMap::new(Mutex::new(HashMap::new()));

Expand All @@ -46,7 +46,7 @@ pub async fn start_server(addr: &str) -> Result<(), ChatServiceError> {
async fn handle_connection(
peers: PeerMap,
stream: tokio::net::TcpStream,
) -> Result<(), ChatServiceError> {
) -> Result<(), DeliveryServiceError> {
let ws_stream = accept_async(stream).await?;
let (mut write, mut read) = ws_stream.split();
let (sender, receiver) = mpsc::unbounded_channel();
Expand All @@ -57,16 +57,15 @@ async fn handle_connection(
// Spawn a task to handle outgoing messages
tokio::spawn(async move {
while let Some(message) = receiver.lock().await.recv().await {
println!("raw message out: {}", message);
if let Err(e) = write.send(message).await {
eprintln!("Error sending message: {}", e);
if let Err(err) = write.send(message).await {
return Err(DeliveryServiceError::SenderError(err.to_string()));
}
}
Ok(())
});

// Handle incoming messages
while let Some(Ok(Message::Text(text))) = read.next().await {
println!("raw message in {}", text);
if let Ok(chat_message) = serde_json::from_str::<ServerMessage>(&text) {
match chat_message {
ServerMessage::SystemJoin {
Expand All @@ -81,12 +80,7 @@ async fn handle_connection(
}
ServerMessage::InMessage { from, to, msg } => {
println!("Received message from {} to {:?}: {}", from, to, msg);
println!(
"\t got contact list {:?}",
peers.lock().await.keys().collect::<Vec<&String>>()
);
for recipient in to {
println!("\t rcpt {}", recipient);
if let Some(recipient_sender) = peers.lock().await.get(&recipient) {
let message = ServerMessage::InMessage {
from: from.clone(),
Expand All @@ -96,7 +90,9 @@ async fn handle_connection(
let message_json = serde_json::to_string(&message).unwrap();
recipient_sender
.send(Message::Text(message_json))
.map_err(|_| ChatServiceError::SendError)?;
.map_err(|err| {
DeliveryServiceError::SenderError(err.to_string())
})?;
}
}
}
Expand Down
Loading

0 comments on commit c4e469f

Please sign in to comment.