make socket non-blocking

This commit is contained in:
MihailRis 2024-11-26 17:12:24 +03:00
parent e8b674ca65
commit 18bdce52df
3 changed files with 37 additions and 8 deletions

View File

@ -182,6 +182,7 @@ public:
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <fcntl.h>
#endif
#ifndef _WIN32
@ -309,6 +310,22 @@ public:
freeaddrinfo(addrinfo);
throw std::runtime_error("Could not create socket");
}
#ifdef _WIN32
u_long mode = 1;
auto result = ioctlsocket(descriptor, FIONBIO, &mode);
if (result != NO_ERROR) {
throw std::runtime_error(
"Could not set to non-blocking mode [errno=" + std::to_string(err) +
"]: " + std::string(strerror(err))
);
}
#else
if (fcntl(descriptor, F_SETFL, O_NONBLOCK) < 0) {
freeaddrinfo(addrinfo);
closesocket(descriptor);
throw std::runtime_error("Failed to make socket non-blocking");
}
#endif
int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen);
if (res == -1) {
closesocket(descriptor);
@ -341,15 +358,24 @@ void Network::get(
requests->get(url, onResponse, onReject, maxSize);
}
std::shared_ptr<Socket> Network::connect(const std::string& address, int port) {
Socket* Network::getConnection(u64id_t id) const {
const auto& found = connections.find(id);
if (found == connections.end()) {
return nullptr;
}
return found->second.get();
}
u64id_t Network::connect(const std::string& address, int port) {
auto socket = SocketImpl::connect(address, port);
connections.push_back(socket);
return socket;
u64id_t id = nextConnection++;
connections[id] = std::move(socket);
return id;
}
size_t Network::getTotalUpload() const {
size_t totalUpload = 0;
for (const auto& socket : connections) {
for (const auto& [_, socket] : connections) {
totalUpload += socket->getTotalUpload();
}
return requests->getTotalUpload() + totalUpload;
@ -357,7 +383,7 @@ size_t Network::getTotalUpload() const {
size_t Network::getTotalDownload() const {
size_t totalDownload = 0;
for (const auto& socket : connections) {
for (const auto& [_, socket] : connections) {
totalDownload += socket->getTotalDownload();
}
return requests->getTotalDownload() + totalDownload;

View File

@ -40,7 +40,8 @@ namespace network {
class Network {
std::unique_ptr<Requests> requests;
std::vector<std::shared_ptr<Socket>> connections;
std::unordered_map<u64id_t, std::shared_ptr<Socket>> connections;
u64id_t nextConnection = 1;
public:
Network(std::unique_ptr<Requests> requests);
~Network();
@ -52,7 +53,9 @@ namespace network {
long maxSize=0
);
std::shared_ptr<Socket> connect(const std::string& address, int port);
Socket* getConnection(u64id_t id) const;
u64id_t connect(const std::string& address, int port);
size_t getTotalUpload() const;
size_t getTotalDownload() const;

View File

@ -19,7 +19,7 @@ TEST(curltest, curltest) {
}, [](auto){}
);
if (false) {
auto socket = network->connect("localhost", 8000);
auto socket = network->getConnection(network->connect("localhost", 8000));
const char* string = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
socket->send(string, strlen(string));
char data[1024];