From 18bdce52df15074903667d010a8362b362bf9b51 Mon Sep 17 00:00:00 2001 From: MihailRis Date: Tue, 26 Nov 2024 17:12:24 +0300 Subject: [PATCH] make socket non-blocking --- src/network/Network.cpp | 36 +++++++++++++++++++++++++++++++----- src/network/Network.hpp | 7 +++++-- test/network/curltest.cpp | 2 +- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/network/Network.cpp b/src/network/Network.cpp index aec84372..8d258d43 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -182,6 +182,7 @@ public: #include #include #include +#include #endif #ifndef _WIN32 @@ -309,6 +310,22 @@ public: freeaddrinfo(addrinfo); throw std::runtime_error("Could not create socket"); } +#ifdef _WIN32 + u_long mode = 1; + auto result = ioctlsocket(descriptor, FIONBIO, &mode); + if (result != NO_ERROR) { + throw std::runtime_error( + "Could not set to non-blocking mode [errno=" + std::to_string(err) + + "]: " + std::string(strerror(err)) + ); + } +#else + if (fcntl(descriptor, F_SETFL, O_NONBLOCK) < 0) { + freeaddrinfo(addrinfo); + closesocket(descriptor); + throw std::runtime_error("Failed to make socket non-blocking"); + } +#endif int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen); if (res == -1) { closesocket(descriptor); @@ -341,15 +358,24 @@ void Network::get( requests->get(url, onResponse, onReject, maxSize); } -std::shared_ptr Network::connect(const std::string& address, int port) { +Socket* Network::getConnection(u64id_t id) const { + const auto& found = connections.find(id); + if (found == connections.end()) { + return nullptr; + } + return found->second.get(); +} + +u64id_t Network::connect(const std::string& address, int port) { auto socket = SocketImpl::connect(address, port); - connections.push_back(socket); - return socket; + u64id_t id = nextConnection++; + connections[id] = std::move(socket); + return id; } size_t Network::getTotalUpload() const { size_t totalUpload = 0; - for (const auto& socket : connections) { + for (const auto& [_, socket] : connections) { totalUpload += socket->getTotalUpload(); } return requests->getTotalUpload() + totalUpload; @@ -357,7 +383,7 @@ size_t Network::getTotalUpload() const { size_t Network::getTotalDownload() const { size_t totalDownload = 0; - for (const auto& socket : connections) { + for (const auto& [_, socket] : connections) { totalDownload += socket->getTotalDownload(); } return requests->getTotalDownload() + totalDownload; diff --git a/src/network/Network.hpp b/src/network/Network.hpp index 62785ad2..450b8640 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -40,7 +40,8 @@ namespace network { class Network { std::unique_ptr requests; - std::vector> connections; + std::unordered_map> connections; + u64id_t nextConnection = 1; public: Network(std::unique_ptr requests); ~Network(); @@ -52,7 +53,9 @@ namespace network { long maxSize=0 ); - std::shared_ptr connect(const std::string& address, int port); + Socket* getConnection(u64id_t id) const; + + u64id_t connect(const std::string& address, int port); size_t getTotalUpload() const; size_t getTotalDownload() const; diff --git a/test/network/curltest.cpp b/test/network/curltest.cpp index 0713376e..5ce86b7b 100644 --- a/test/network/curltest.cpp +++ b/test/network/curltest.cpp @@ -19,7 +19,7 @@ TEST(curltest, curltest) { }, [](auto){} ); if (false) { - auto socket = network->connect("localhost", 8000); + auto socket = network->getConnection(network->connect("localhost", 8000)); const char* string = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"; socket->send(string, strlen(string)); char data[1024];