diff --git a/src/main.cpp b/src/main.cpp index 4fa1a7f..c4e8ca4 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,3 @@ -#include #include "common.h" #include "tcp.h" #include "epoll_event.h" @@ -57,9 +56,7 @@ int main() { signal_handle.add_signal(SIGINT, signal_handler); signal_handle.ignore_signal(SIGPIPE); - std::unordered_map client_ip_mapper; /* client ip 到 连接数量的映射 */ std::unordered_map client_process_mapper; /* client 进程到 ip 的映射 */ - std::set client_counter; /* 当前建立连接的 client 数量 */ bool terminate = false; while (!terminate) { @@ -80,26 +77,10 @@ int main() { utility::debug_info(std::string("Accept client: " + std::to_string(client_ip))); /* 限流 */ - if (configure::MAX_CONN_PER_IP > 0 && configure::MAX_CONN_PER_IP <= client_ip_mapper[client_ip]) { - /* 该 IP 建立的连接数过多 */ - char buf[1024] = {0}; - snprintf(buf, sizeof(buf), "%d There are too many connections from your address.\r\n", - ftp_response_code::kFTP_IP_LIMIT); - tcp::send_data(connect_fd, buf, strlen(buf)); - close(connect_fd); - continue; - } else if (configure::MAX_CLIENT_NUM > 0 && configure::MAX_CLIENT_NUM <= client_counter.size()) { - /* 建立连接的 client 过多 */ - char buf[1024] = {0}; - snprintf(buf, sizeof(buf), "%d There are too many connected users, please try later.\r\n", - ftp_response_code::kFTP_TOO_MANY_USERS); - tcp::send_data(connect_fd, buf, strlen(buf)); + if (tcp_server.limit_client_crowding(connect_fd, client_ip)) { close(connect_fd); continue; } - /* 更新计数信息*/ - client_counter.insert(client_ip); - ++client_ip_mapper[client_ip]; /* 创建服务进程,为 client 服务 */ pid_t pid = fork(); @@ -140,10 +121,7 @@ int main() { while ((pid = waitpid(-1, &stat, WNOHANG)) > 0) { /* 进程退出后的善后处理 */ utility::debug_info(std::string("Subprocess: ") + std::to_string(pid) + " exited"); - unsigned int client_ip = client_process_mapper[pid]; - if (client_ip_mapper[client_ip] > 0) { - --client_ip_mapper[client_ip] == 0 ? client_counter.erase(client_ip) : 1; - } + tcp_server.on_a_client_exit(client_process_mapper[pid]); } break; } diff --git a/src/tcp.cpp b/src/tcp.cpp index 62e238b..fd2d060 100644 --- a/src/tcp.cpp +++ b/src/tcp.cpp @@ -6,6 +6,8 @@ #include "tcp.h" #include "utility.h" #include "common.h" +#include "configure.h" +#include "ftp_codes.h" static std::string g_local_ip; @@ -60,6 +62,36 @@ int CLTCPServer::accept_client(struct sockaddr_in &client_address) { return accept(m_listen_fd, (struct sockaddr *) &client_address, &client_addr_len); } +bool CLTCPServer::limit_client_crowding(int connect_fd, unsigned int client_ip) { + /* 限流 */ + if (configure::MAX_CONN_PER_IP > 0 && configure::MAX_CONN_PER_IP <= m_client_ip_mapper[client_ip]) { + /* 该 IP 建立的连接数过多 */ + char buf[1024] = {0}; + snprintf(buf, sizeof(buf), "%d There are too many connections from your address.\r\n", + ftp_response_code::kFTP_IP_LIMIT); + tcp::send_data(connect_fd, buf, strlen(buf)); + return true; + } else if (configure::MAX_CLIENT_NUM > 0 && configure::MAX_CLIENT_NUM <= m_client_counter.size()) { + /* 建立连接的 client 过多 */ + char buf[1024] = {0}; + snprintf(buf, sizeof(buf), "%d There are too many connected users, please try later.\r\n", + ftp_response_code::kFTP_TOO_MANY_USERS); + tcp::send_data(connect_fd, buf, strlen(buf)); + return true; + } + /* 更新计数信息*/ + m_client_counter.insert(client_ip); + ++m_client_ip_mapper[client_ip]; + + return false; +} + +void CLTCPServer::on_a_client_exit(unsigned int client_ip) { + if (m_client_ip_mapper[client_ip] > 0) { + --m_client_ip_mapper[client_ip] == 0 ? m_client_counter.erase(client_ip) : 1; + } +} + namespace tcp { void close_fd(int fd) { if (close(fd) < 0) diff --git a/src/tcp.h b/src/tcp.h index ddc18d9..b1148cd 100644 --- a/src/tcp.h +++ b/src/tcp.h @@ -6,6 +6,8 @@ #define FTPD_TCP_H #include +#include +#include #include "common.h" #include "utility.h" @@ -17,14 +19,22 @@ class CLTCPServer { } + int start_listen(); int accept_client(struct sockaddr_in &client_address); + bool limit_client_crowding(int connect_fd, unsigned int client_ip); /* 限流*/ + void on_a_client_exit(unsigned int client_ip); + private: unsigned int m_port; std::string m_host; int m_listen_fd; + + /* 限流 */ + std::unordered_map m_client_ip_mapper; /* client ip 到 连接数量的映射 */ + std::set m_client_counter; /* 当前建立连接的 client 数量 */ }; namespace tcp {