diff --git a/src/network/Network.cpp b/src/network/Network.cpp index 04c8dbec..be3f8d33 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -164,15 +164,16 @@ void Network::update() { } ++socketiter; } - auto serveriter = servers.begin(); - while (serveriter != servers.end()) { - auto server = serveriter->second.get(); - if (!server->isOpen()) { - serveriter = servers.erase(serveriter); - continue; - } - ++serveriter; + } + auto serveriter = servers.begin(); + while (serveriter != servers.end()) { + auto server = serveriter->second.get(); + if (!server->isOpen()) { + serveriter = servers.erase(serveriter); + continue; } + server->update(); + ++serveriter; } } diff --git a/src/network/Network.hpp b/src/network/Network.hpp index a6c2fca2..047e07b4 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -37,6 +37,8 @@ namespace network { [[nodiscard]] TransportType getTransportType() const noexcept override { return TransportType::TCP; } + + virtual void setMaxClientsConnected(int count) = 0; }; class UdpServer : public Server { diff --git a/src/network/Sockets.cpp b/src/network/Sockets.cpp index 5152028b..9e145da0 100644 --- a/src/network/Sockets.cpp +++ b/src/network/Sockets.cpp @@ -304,6 +304,7 @@ class SocketTcpServer : public TcpServer { bool open = true; std::unique_ptr thread = nullptr; int port; + int maxConnected = -1; public: SocketTcpServer(u64id_t id, Network* network, SOCKET descriptor, int port) : id(id), network(network), descriptor(descriptor), port(port) {} @@ -312,6 +313,22 @@ public: closeSocket(); } + void setMaxClientsConnected(int count) override { + maxConnected = count; + } + + void update() override { + std::vector clients; + for (u64id_t cid : this->clients) { + if (auto client = network->getConnection(cid, true)) { + if (client->getState() != ConnectionState::CLOSED) { + clients.emplace_back(cid); + } + } + } + std::swap(clients, this->clients); + } + void startListen(ConnectCallback handler) override { thread = std::make_unique([this, handler]() { while (open) { @@ -328,6 +345,11 @@ public: close(); break; } + if (maxConnected >= 0 && clients.size() >= maxConnected) { + logger.info() << "refused connection attempt from " << to_string(address); + closesocket(clientDescriptor); + continue; + } logger.info() << "client connected: " << to_string(address); auto socket = std::make_shared( clientDescriptor, address @@ -575,6 +597,8 @@ public: SocketUdpServer::close(); } + void update() override {} + void startListen(ServerDatagramCallback handler) override { callback = std::move(handler); diff --git a/src/network/commons.hpp b/src/network/commons.hpp index 83881c52..ad3c0806 100644 --- a/src/network/commons.hpp +++ b/src/network/commons.hpp @@ -78,6 +78,8 @@ namespace network { class Server { public: virtual ~Server() = default; + + virtual void update() = 0; virtual void close() = 0; virtual bool isOpen() = 0; [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0;