add Network.connect (WIP)

This commit is contained in:
MihailRis 2024-11-11 19:55:23 +03:00
parent 4328c83c79
commit dc84fe1f07
3 changed files with 218 additions and 6 deletions

View File

@ -77,7 +77,171 @@ public:
};
Network::Network(std::unique_ptr<Http> http) : http(std::move(http)) {
#ifdef _WIN32
/// ...
#else
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif
#ifndef _WIN32
static inline int closesocket(int descriptor) noexcept {
return close(descriptor);
}
#endif
static inline int connectsocket(
int descriptor, const sockaddr* addr, socklen_t len
) noexcept {
return connect(descriptor, addr, len);
}
static inline int recvsocket(
int descriptor, void* buf, size_t len, int flags
) noexcept {
return recv(descriptor, buf, len, flags);
}
static inline int sendsocket(
int descriptor, const void* buf, size_t len, int flags
) noexcept {
return send(descriptor, buf, len, flags);
}
static std::string to_string(const addrinfo* addr) {
if (addr->ai_family == AF_INET) {
auto psai = reinterpret_cast<sockaddr_in*>(addr->ai_addr);
char ip[INET_ADDRSTRLEN];
if (inet_ntop(addr->ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN)) {
return std::string(ip);
}
} else if (addr->ai_family == AF_INET6) {
auto psai = reinterpret_cast<sockaddr_in6*>(addr->ai_addr);
char ip[INET6_ADDRSTRLEN];
if (inet_ntop(addr->ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN)) {
return std::string(ip);
}
}
return "";
}
class SocketImpl : public Socket {
int descriptor;
bool open = true;
addrinfo* addr;
size_t totalUpload = 0;
size_t totalDownload = 0;
public:
SocketImpl(int descriptor, addrinfo* addr)
: descriptor(descriptor), addr(addr) {
}
~SocketImpl() {
closesocket(descriptor);
freeaddrinfo(addr);
}
int recv(void* buffer, size_t length, bool blocking) override {
int len = recvsocket(descriptor, buffer, length, blocking ? MSG_WAITALL : MSG_DONTWAIT);
if (len == 0) {
int err = errno;
close();
throw std::runtime_error(
"Read failed [errno=" + std::to_string(err) +
"]: " + std::string(strerror(err))
);
} else if (len == -1) {
return 0;
}
totalDownload += len;
return len;
}
int send(const void* buffer, size_t length) override {
int len = sendsocket(descriptor, buffer, length, 0);
if (len == -1) {
int err = errno;
close();
throw std::runtime_error(
"Send failed [errno=" + std::to_string(err) +
"]: " + std::string(strerror(err))
);
}
totalUpload += len;
return len;
}
void close() override {
closesocket(descriptor);
open = false;
}
bool isOpen() const override {
return open;
}
size_t getTotalUpload() const override {
return totalUpload;
}
size_t getTotalDownload() const override {
return totalDownload;
}
static std::shared_ptr<SocketImpl> connect(
const std::string& address, int port
) {
addrinfo hints {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
addrinfo* addrinfo;
if (int res = getaddrinfo(
address.c_str(), std::to_string(port).c_str(), &hints, &addrinfo
)) {
throw std::runtime_error(gai_strerror(res));
}
int descriptor = socket(
addrinfo->ai_family, addrinfo->ai_socktype, addrinfo->ai_protocol
);
if (descriptor == -1) {
freeaddrinfo(addrinfo);
throw std::runtime_error("Could not create socket");
}
int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen);
if (res == -1) {
closesocket(descriptor);
freeaddrinfo(addrinfo);
int err = errno;
throw std::runtime_error(
"Connect failed [errno=" + std::to_string(err) +
"]: " + std::string(strerror(err))
);
}
logger.info() << "connected to " << address << " ["
<< to_string(addrinfo) << ":" << port << "]";
return std::make_shared<SocketImpl>(descriptor, addrinfo);
}
};
class SocketTcp : public Tcp {
public:
SocketTcp() {};
std::shared_ptr<Socket> connect(const std::string& address, int port) override {
return SocketImpl::connect(address, port);
}
};
Network::Network(std::unique_ptr<Http> http, std::unique_ptr<Tcp> tcp)
: http(std::move(http)), tcp(std::move(tcp)) {
}
Network::~Network() = default;
@ -88,15 +252,30 @@ void Network::httpGet(
http->get(url, onResponse, onReject);
}
std::shared_ptr<Socket> Network::connect(const std::string& address, int port) {
auto socket = tcp->connect(address, port);
connections.push_back(socket);
return socket;
}
size_t Network::getTotalUpload() const {
return http->getTotalUpload();
size_t totalUpload = 0;
for (const auto& socket : connections) {
totalUpload += socket->getTotalUpload();
}
return http->getTotalUpload() + totalUpload;
}
size_t Network::getTotalDownload() const {
return http->getTotalDownload();
size_t totalDownload = 0;
for (const auto& socket : connections) {
totalDownload += socket->getTotalDownload();
}
return http->getTotalDownload() + totalDownload;
}
std::unique_ptr<Network> Network::create(const NetworkSettings& settings) {
auto http = CurlHttp::create();
return std::make_unique<Network>(std::move(http));
auto tcp = std::make_unique<SocketTcp>();
return std::make_unique<Network>(std::move(http), std::move(tcp));
}

View File

@ -2,7 +2,6 @@
#include <memory>
#include <vector>
#include <functional>
#include "typedefs.hpp"
#include "settings.hpp"
@ -25,10 +24,32 @@ namespace network {
virtual size_t getTotalDownload() const = 0;
};
class Socket {
public:
virtual int recv(void* buffer, size_t length, bool blocking) = 0;
virtual int send(const void* buffer, size_t length) = 0;
virtual void close() = 0;
virtual bool isOpen() const = 0;
virtual size_t getTotalUpload() const = 0;
virtual size_t getTotalDownload() const = 0;
};
class Tcp {
public:
virtual ~Tcp() {}
virtual std::shared_ptr<Socket> connect(
const std::string& address, int port
) = 0;
};
class Network {
std::unique_ptr<Http> http;
std::unique_ptr<Tcp> tcp;
std::vector<std::shared_ptr<Socket>> connections;
public:
Network(std::unique_ptr<Http> http);
Network(std::unique_ptr<Http> http, std::unique_ptr<Tcp> tcp);
~Network();
void httpGet(
@ -37,6 +58,8 @@ namespace network {
OnReject onReject = nullptr
);
std::shared_ptr<Socket> connect(const std::string& address, int port);
size_t getTotalUpload() const;
size_t getTotalDownload() const;

View File

@ -18,6 +18,16 @@ TEST(curltest, curltest) {
std::cout << value << std::endl;
}
);
if (false) {
auto socket = network->connect("localhost", 8000);
const char* string = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
socket->send(string, strlen(string));
char data[1024];
int len = socket->recv(data, 1024, true);
std::cout << len << " " << std::string(data, len) << std::endl;
}
std::cout << "upload: " << network->getTotalUpload() << " B" << std::endl;
std::cout << "download: " << network->getTotalDownload() << " B" << std::endl;
}