diff --git a/src/google_pubsub_client.rs b/src/google_pubsub_client.rs index 8fabc50..aa0d802 100644 --- a/src/google_pubsub_client.rs +++ b/src/google_pubsub_client.rs @@ -7,7 +7,7 @@ use gcloud_sdk::{ }; use tracing::debug; -const ALLOWED_PUBKEYS: &[&str] = &[ +const _ALLOWED_PUBKEYS: &[&str] = &[ "07ecf9838136fe430fac43fa0860dbc62a0aac0729c5a33df1192ce75e330c9f", // Bryan "89ef92b9ebe6dc1e4ea398f6477f227e95429627b0a33dc89b640e137b256be5", // Daniel "e8ad7c13ba55ba0a04c23fc09edce74ad7a8dddc059dc2e274ff63bc2e047782", // Daphne @@ -50,9 +50,9 @@ impl GooglePubSubClient { impl PublishEvents for GooglePubSubClient { async fn publish_events( &mut self, - follow_changes: Vec, + notification_messages: Vec, ) -> Result<(), PublisherError> { - let pubsub_messages: Result, PublisherError> = follow_changes + let pubsub_messages: Result, PublisherError> = notification_messages .iter() // .filter(|message| { // // TODO: Temporary filter while developing this service diff --git a/src/main.rs b/src/main.rs index ba5f0fb..baa8d10 100644 --- a/src/main.rs +++ b/src/main.rs @@ -172,6 +172,7 @@ async fn start(settings: Settings) -> Result<()> { settings.tcp_importer_port, event_sender, cancellation_token.clone(), + 5, ) .await?; diff --git a/src/tcp_importer.rs b/src/tcp_importer.rs index 317fb34..91a2b20 100644 --- a/src/tcp_importer.rs +++ b/src/tcp_importer.rs @@ -1,10 +1,13 @@ use anyhow::{Context, Result}; use nostr_sdk::prelude::*; +use std::net::IpAddr; use std::net::SocketAddr; +use std::sync::Arc; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, TcpStream}; use tokio::sync::broadcast::Sender; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::{error, info}; @@ -13,13 +16,17 @@ pub async fn start_tcp_importer( tcp_port: u16, event_tx: Sender>, cancellation_token: CancellationToken, + max_connections: usize, ) -> Result<()> { let address = SocketAddr::from(([0, 0, 0, 0], tcp_port)); let listener = TcpListener::bind(&address) .await .context(format!("Error opening TCP listener on port {tcp_port}"))?; - info!("Listening for tcp connections on {}", address); + info!("Listening for TCP connections on {}", address); + + // Semaphore with the maximum number of allowed concurrent connections + let semaphore = Arc::new(Semaphore::new(max_connections)); task_tracker.spawn(async move { loop { @@ -29,11 +36,31 @@ pub async fn start_tcp_importer( break; } - Ok((stream, _)) = listener.accept() => { - let tx = event_tx.clone(); - let cancel_token = cancellation_token.clone(); + result = listener.accept() => { + match result { + Ok((stream, addr)) => { + if is_local_address(&addr) { + let tx = event_tx.clone(); + let cancel_token = cancellation_token.clone(); + let semaphore_clone = semaphore.clone(); + let permit = semaphore_clone.acquire_owned().await; - tokio::spawn(handle_connection(stream, tx, cancel_token)); + match permit { + Ok(permit) => { + tokio::spawn(handle_connection_with_permit(stream, tx, cancel_token, permit)); + } + Err(e) => { + error!("Failed to acquire semaphore permit: {}", e); + } + } + } else { + info!("Ignoring connection from non-local address: {}", addr); + } + } + Err(e) => { + error!("Failed to accept connection: {}", e); + } + } } } } @@ -44,9 +71,26 @@ pub async fn start_tcp_importer( Ok(()) } +fn is_local_address(addr: &SocketAddr) -> bool { + match addr.ip() { + IpAddr::V4(ipv4) => ipv4.is_loopback(), // Checks if in 127.0.0.0/8 + IpAddr::V6(ipv6) => ipv6.is_loopback(), // Checks if ::1 + } +} + +async fn handle_connection_with_permit( + stream: TcpStream, + tx: Sender>, + cancellation_token: CancellationToken, + _permit: OwnedSemaphorePermit, +) { + handle_connection(stream, tx, cancellation_token).await; + // Here the _permit is dropped +} + // Handle the incoming connection and read jsonl contact events async fn handle_connection( - stream: tokio::net::TcpStream, + stream: TcpStream, event_tx: Sender>, cancellation_token: CancellationToken, ) {