Skip to content

Commit

Permalink
feat: converting Vec<u8> to bytes::Bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
etherhood committed Oct 24, 2024
1 parent 9bb2553 commit 1daad07
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 55 deletions.
11 changes: 6 additions & 5 deletions ethportal-api/src/types/portal_wire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ impl From<Accept> for Value {
#[allow(clippy::unwrap_used)]
mod test {
use super::*;
use alloy_primitives::bytes;
use alloy_primitives::{bytes, Bytes};
use ssz_types::Error::OutOfBounds;
use std::str::FromStr;
use test_log::test;
Expand Down Expand Up @@ -703,13 +703,13 @@ mod test {

#[test]
fn message_encoding_find_content() {
let content_key = hex_decode("0x706f7274616c").unwrap();
let content_key = Bytes::from("0x706f7274616c");
let find_content = FindContent { content_key };
let find_content = Message::FindContent(find_content);

let encoded: Vec<u8> = find_content.clone().into();
let encoded = hex_encode(encoded);
let expected_encoded = "0x0404000000706f7274616c";
let expected_encoded = "0x04040000003078373036663732373436313663";
assert_eq!(encoded, expected_encoded);

let decoded = Message::try_from(hex_decode(&encoded).unwrap()).unwrap();
Expand All @@ -733,13 +733,14 @@ mod test {

#[test]
fn message_encoding_content_content() {
let content_val = hex_decode("0x7468652063616b652069732061206c6965").unwrap();
let content_val = Bytes::from("0x7468652063616b652069732061206c6965");
let content = Content::Content(content_val);
let content = Message::Content(content);

let encoded: Vec<u8> = content.clone().into();
let encoded = hex_encode(encoded);
let expected_encoded = "0x05017468652063616b652069732061206c6965";
let expected_encoded =
"0x0501307837343638363532303633363136623635323036393733323036313230366336393635";
assert_eq!(encoded, expected_encoded);

let decoded = Message::try_from(hex_decode(&encoded).unwrap()).unwrap();
Expand Down
14 changes: 9 additions & 5 deletions portalnet/src/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use discv5::{
enr::{CombinedKey, Enr as Discv5Enr, NodeId},
ConfigBuilder, Discv5, Event, ListenConfig, RequestError, TalkRequest,
Expand Down Expand Up @@ -40,7 +41,7 @@ pub const ENR_PORTAL_CLIENT_KEY: &str = "c";
/// ENR file name saving enr history to disk.
const ENR_FILE_NAME: &str = "trin.enr";

pub type ProtocolRequest = Vec<u8>;
pub type ProtocolRequest = Bytes;

/// The contact info for a remote node.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -336,7 +337,7 @@ impl Discovery {
enr: Enr,
subnetwork: Subnetwork,
request: ProtocolRequest,
) -> Result<Vec<u8>, RequestError> {
) -> Result<Bytes, RequestError> {
// Send empty protocol id if unable to convert it to bytes
let protocol = match self
.network_spec
Expand All @@ -350,8 +351,11 @@ impl Discovery {
}
};

let response = self.discv5.talk_req(enr, protocol, request).await?;
Ok(response)
let response = self
.discv5
.talk_req(enr, protocol, request.to_vec())
.await?;
Ok(Bytes::from(response))
}
}

Expand Down Expand Up @@ -482,7 +486,7 @@ impl AsyncUdpSocket<UtpEnr> for Discv5UdpSocket {
async fn send_to(&mut self, buf: &[u8], target: &UtpEnr) -> io::Result<usize> {
let discv5 = Arc::clone(&self.discv5);
let target = target.0.clone();
let data = buf.to_vec();
let data = Bytes::from(buf.to_vec());
tokio::spawn(async move {
match discv5.send_talk_req(target, Subnetwork::Utp, data).await {
// We drop the talk response because it is ignored in the uTP protocol.
Expand Down
4 changes: 3 additions & 1 deletion portalnet/src/find/iterators/findcontent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,9 @@ mod tests {
let peer_node_id = k.preimage();
query.on_success(
peer_node_id,
FindContentQueryResponse::Content(found_content.clone()),
FindContentQueryResponse::Content(
found_content.clone().into(),
),
);
// The peer that returned content is now validating.
new_validations.push_back(k);
Expand Down
3 changes: 2 additions & 1 deletion portalnet/src/overlay/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{
};

use anyhow::anyhow;
use bytes::Bytes;
use discv5::{
enr::NodeId,
kbucket::{FailureReason, InsertResult, KBucketsTable, NodeStatus},
Expand Down Expand Up @@ -456,7 +457,7 @@ where
&self,
enr: Enr,
conn_id: u16,
) -> Result<Vec<u8>, OverlayRequestError> {
) -> Result<Bytes, OverlayRequestError> {
let cid = utp_rs::cid::ConnectionId {
recv: conn_id,
send: conn_id.wrapping_add(1),
Expand Down
40 changes: 19 additions & 21 deletions portalnet/src/overlay/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ where
// over the uTP stream.
let utp = Arc::clone(&self.utp_controller);
tokio::spawn(async move {
utp.accept_outbound_stream(cid, content).await;
utp.accept_outbound_stream(cid, content.into()).await;
drop(permit);
});

Expand Down Expand Up @@ -1313,7 +1313,6 @@ where
})
})
.flatten()
.map(|(a, b)| (a, b.into()))
.collect();
propagate_gossip_cross_thread::<_, TMetric>(
validated_content,
Expand Down Expand Up @@ -1344,10 +1343,14 @@ where
// which will be received in the main loop.
tokio::spawn(async move {
let response = match discovery
.send_talk_req(destination, protocol, Message::from(request).into())
.send_talk_req(
destination,
protocol,
Bytes::from(Message::from(request).as_ssz_bytes()),
)
.await
{
Ok(talk_resp) => match Message::try_from(talk_resp) {
Ok(talk_resp) => match Message::try_from(talk_resp.to_vec()) {
Ok(message) => match Response::try_from(message) {
Ok(response) => Ok(response),
Err(_) => Err(OverlayRequestError::InvalidResponse),
Expand Down Expand Up @@ -1602,7 +1605,7 @@ where
}
};
let result = utp_controller
.connect_outbound_stream(cid, content_payload.to_vec())
.connect_outbound_stream(cid, content_payload.into())
.await;
if let Some(tx) = gossip_result_tx {
if result {
Expand Down Expand Up @@ -1776,6 +1779,7 @@ where
.utp_controller
.connect_inbound_stream(cid)
.await?
.to_vec()
}
}
}
Expand All @@ -1796,10 +1800,7 @@ where
};

propagate_gossip_cross_thread::<_, TMetric>(
validated_content
.into_iter()
.map(|(a, b)| (a, b.into()))
.collect(),
validated_content.into_iter().collect(),
&utp_processing.kbuckets,
utp_processing.command_tx.clone(),
Some(utp_processing.utp_controller),
Expand Down Expand Up @@ -1989,10 +1990,7 @@ where
);
}
propagate_gossip_cross_thread::<_, TMetric>(
content_to_propagate
.into_iter()
.map(|(a, b)| (a, b.into()))
.collect(),
content_to_propagate.into_iter().collect(),
&utp_processing.kbuckets,
utp_processing.command_tx.clone(),
Some(utp_processing.utp_controller.clone()),
Expand Down Expand Up @@ -2597,7 +2595,7 @@ fn decode_and_validate_content_payload<TContentKey>(
accepted_keys: &[TContentKey],
payload: RawContentValue,
) -> anyhow::Result<Vec<RawContentValue>> {
let content_values = portal_wire::decode_content_payload(payload)?;
let content_values = portal_wire::decode_content_payload(payload.into())?;
// Accepted content keys len should match content value len
let keys_len = accepted_keys.len();
let vals_len = content_values.len();
Expand All @@ -2608,7 +2606,7 @@ fn decode_and_validate_content_payload<TContentKey>(
vals_len
));
}
Ok(content_values)
Ok(content_values.iter().map(|v| v.clone().into()).collect())
}

#[cfg(test)]
Expand All @@ -2618,7 +2616,7 @@ mod tests {

use std::{net::SocketAddr, time::Instant};

use alloy_primitives::U256;
use alloy_primitives::{Bytes, U256};
use discv5::kbucket;
use kbucket::KBucketsTable;
use rstest::*;
Expand Down Expand Up @@ -3041,7 +3039,7 @@ mod tests {
let mut service = task::spawn(build_service(&temp_dir));

let content_key = IdentityContentKey::new(service.local_enr().node_id().raw());
let content = vec![0xef];
let content = Bytes::from("0xef");

let status = NodeStatus {
state: ConnectionState::Connected,
Expand Down Expand Up @@ -3086,7 +3084,7 @@ mod tests {
let mut service = task::spawn(build_service(&temp_dir));

let content_key = IdentityContentKey::new(service.local_enr().node_id().raw());
let content = vec![0xef];
let content = Bytes::from("0xef");

let (_, enr1) = generate_random_remote_enr();
let (_, enr2) = generate_random_remote_enr();
Expand All @@ -3112,7 +3110,7 @@ mod tests {
let mut service = task::spawn(build_service(&temp_dir));

let content_key = IdentityContentKey::new(service.local_enr().node_id().raw());
let content = vec![0xef];
let content = Bytes::from("0xef");

let status = NodeStatus {
state: ConnectionState::Connected,
Expand Down Expand Up @@ -3738,7 +3736,7 @@ mod tests {
}

// Simulate a response from the bootnode.
let content: Vec<u8> = vec![0, 1, 2, 3];
let content = Bytes::from("0x00010203");
service.advance_find_content_query_with_content(&query_id, bootnode_enr, content.clone());

let pool = &mut service.find_content_query_pool;
Expand Down Expand Up @@ -3887,7 +3885,7 @@ mod tests {
assert_eq!(request.query_id, Some(query_id));

// Simulate a response from the bootnode.
let content: Vec<u8> = vec![0, 1, 2, 3];
let content = Bytes::from("0x00010203");
service.advance_find_content_query_with_content(
&query_id,
bootnode_enr.clone(),
Expand Down
13 changes: 5 additions & 8 deletions portalnet/src/utils/portal_wire.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use anyhow::anyhow;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use ethportal_api::RawContentValue;
use std::io::{Read, Write};

/// Decode content values from uTP payload. All content values are encoded with a LEB128 varint
/// prefix which indicates the length in bytes of the consecutive content item.
pub fn decode_content_payload(payload: RawContentValue) -> anyhow::Result<Vec<RawContentValue>> {
pub fn decode_content_payload(payload: Bytes) -> anyhow::Result<Vec<Bytes>> {
let mut payload = BytesMut::from(&payload[..]).reader();

let mut content_values: Vec<RawContentValue> = Vec::new();
let mut content_values: Vec<Bytes> = Vec::new();

// Read LEB128 encoded index and content items until all payload bytes are consumed
while !payload.get_ref().is_empty() {
Expand Down Expand Up @@ -123,17 +122,15 @@ mod test {
fn test_decode_content_payload_corrupted() {
let hex_payload = "0x030101010201";
let payload = hex_decode(hex_payload).unwrap();
decode_content_payload(payload).unwrap();
decode_content_payload(payload.into()).unwrap();
}

#[test]
fn test_encode_decode_content_payload() {
let expected_content_items: Vec<Bytes> = vec![vec![1, 1].into(), vec![2, 2, 2].into()];

let content_payload = encode_content_payload(&expected_content_items)
.unwrap()
.to_vec();
let content_items: Vec<Bytes> = decode_content_payload(content_payload)
let content_payload = encode_content_payload(&expected_content_items).unwrap();
let content_items: Vec<Bytes> = decode_content_payload(content_payload.into())
.unwrap()
.into_iter()
.map(|content| Bytes::from(content.to_vec()))
Expand Down
21 changes: 8 additions & 13 deletions portalnet/src/utp_controller.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::discovery::UtpEnr;
use anyhow::anyhow;
use bytes::Bytes;
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
Expand Down Expand Up @@ -83,26 +84,20 @@ impl UtpController {
}
}

pub async fn connect_inbound_stream(
&self,
cid: ConnectionId<UtpEnr>,
) -> anyhow::Result<Vec<u8>> {
pub async fn connect_inbound_stream(&self, cid: ConnectionId<UtpEnr>) -> anyhow::Result<Bytes> {
self.inbound_stream(cid, UtpConnectionSide::Connect).await
}

pub async fn accept_inbound_stream(
&self,
cid: ConnectionId<UtpEnr>,
) -> anyhow::Result<Vec<u8>> {
pub async fn accept_inbound_stream(&self, cid: ConnectionId<UtpEnr>) -> anyhow::Result<Bytes> {
self.inbound_stream(cid, UtpConnectionSide::Accept).await
}

pub async fn connect_outbound_stream(&self, cid: ConnectionId<UtpEnr>, data: Vec<u8>) -> bool {
pub async fn connect_outbound_stream(&self, cid: ConnectionId<UtpEnr>, data: Bytes) -> bool {
self.outbound_stream(cid, data, UtpConnectionSide::Connect)
.await
}

pub async fn accept_outbound_stream(&self, cid: ConnectionId<UtpEnr>, data: Vec<u8>) -> bool {
pub async fn accept_outbound_stream(&self, cid: ConnectionId<UtpEnr>, data: Bytes) -> bool {
self.outbound_stream(cid, data, UtpConnectionSide::Accept)
.await
}
Expand All @@ -111,7 +106,7 @@ impl UtpController {
&self,
cid: ConnectionId<UtpEnr>,
side: UtpConnectionSide,
) -> anyhow::Result<Vec<u8>> {
) -> anyhow::Result<Bytes> {
// Wait for an incoming connection with the given CID. Then, read the data from the uTP
// stream.
self.metrics
Expand Down Expand Up @@ -153,13 +148,13 @@ impl UtpController {
// report utp tx as successful, even if we go on to fail to process the payload
self.metrics
.report_utp_outcome(UtpDirectionLabel::Inbound, UtpOutcomeLabel::Success);
Ok(data)
Ok(Bytes::from(data))
}

async fn outbound_stream(
&self,
cid: ConnectionId<UtpEnr>,
data: Vec<u8>,
data: Bytes,
side: UtpConnectionSide,
) -> bool {
self.metrics
Expand Down
2 changes: 1 addition & 1 deletion portalnet/tests/overlay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ async fn overlay() {
// because node two is the local node.
let content_key = IdentityContentKey::new([0u8; 32]);
let content_enrs = match overlay_two
.send_find_content(overlay_one.local_enr(), content_key.into())
.send_find_content(overlay_one.local_enr(), content_key.to_vec().into())
.await
{
Ok((content, utp_transfer)) => match content {
Expand Down

0 comments on commit 1daad07

Please sign in to comment.