diff --git a/src/network/Network.cpp b/src/network/Network.cpp index a2b98680..8766ff28 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -77,7 +77,171 @@ public: }; -Network::Network(std::unique_ptr http) : http(std::move(http)) { +#ifdef _WIN32 +/// ... +#else +#include +#include +#include +#include +#include +#include +#include +#endif + +#ifndef _WIN32 +static inline int closesocket(int descriptor) noexcept { + return close(descriptor); +} +#endif + +static inline int connectsocket( + int descriptor, const sockaddr* addr, socklen_t len +) noexcept { + return connect(descriptor, addr, len); +} + +static inline int recvsocket( + int descriptor, void* buf, size_t len, int flags +) noexcept { + return recv(descriptor, buf, len, flags); +} + +static inline int sendsocket( + int descriptor, const void* buf, size_t len, int flags +) noexcept { + return send(descriptor, buf, len, flags); +} + +static std::string to_string(const addrinfo* addr) { + if (addr->ai_family == AF_INET) { + auto psai = reinterpret_cast(addr->ai_addr); + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN)) { + return std::string(ip); + } + } else if (addr->ai_family == AF_INET6) { + auto psai = reinterpret_cast(addr->ai_addr); + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop(addr->ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN)) { + return std::string(ip); + } + } + return ""; +} + +class SocketImpl : public Socket { + int descriptor; + bool open = true; + addrinfo* addr; + size_t totalUpload = 0; + size_t totalDownload = 0; +public: + SocketImpl(int descriptor, addrinfo* addr) + : descriptor(descriptor), addr(addr) { + } + + ~SocketImpl() { + closesocket(descriptor); + freeaddrinfo(addr); + } + + int recv(void* buffer, size_t length, bool blocking) override { + int len = recvsocket(descriptor, buffer, length, blocking ? MSG_WAITALL : MSG_DONTWAIT); + if (len == 0) { + int err = errno; + close(); + throw std::runtime_error( + "Read failed [errno=" + std::to_string(err) + + "]: " + std::string(strerror(err)) + ); + } else if (len == -1) { + return 0; + } + totalDownload += len; + return len; + } + + int send(const void* buffer, size_t length) override { + int len = sendsocket(descriptor, buffer, length, 0); + if (len == -1) { + int err = errno; + close(); + throw std::runtime_error( + "Send failed [errno=" + std::to_string(err) + + "]: " + std::string(strerror(err)) + ); + } + totalUpload += len; + return len; + } + + void close() override { + closesocket(descriptor); + open = false; + } + + bool isOpen() const override { + return open; + } + + size_t getTotalUpload() const override { + return totalUpload; + } + + size_t getTotalDownload() const override { + return totalDownload; + } + + static std::shared_ptr connect( + const std::string& address, int port + ) { + addrinfo hints {}; + + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + addrinfo* addrinfo; + if (int res = getaddrinfo( + address.c_str(), std::to_string(port).c_str(), &hints, &addrinfo + )) { + throw std::runtime_error(gai_strerror(res)); + } + int descriptor = socket( + addrinfo->ai_family, addrinfo->ai_socktype, addrinfo->ai_protocol + ); + if (descriptor == -1) { + freeaddrinfo(addrinfo); + throw std::runtime_error("Could not create socket"); + } + int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen); + if (res == -1) { + closesocket(descriptor); + freeaddrinfo(addrinfo); + + int err = errno; + throw std::runtime_error( + "Connect failed [errno=" + std::to_string(err) + + "]: " + std::string(strerror(err)) + ); + } + logger.info() << "connected to " << address << " [" + << to_string(addrinfo) << ":" << port << "]"; + return std::make_shared(descriptor, addrinfo); + } +}; + +class SocketTcp : public Tcp { +public: + SocketTcp() {}; + + std::shared_ptr connect(const std::string& address, int port) override { + return SocketImpl::connect(address, port); + } +}; + +Network::Network(std::unique_ptr http, std::unique_ptr tcp) + : http(std::move(http)), tcp(std::move(tcp)) { } Network::~Network() = default; @@ -88,15 +252,30 @@ void Network::httpGet( http->get(url, onResponse, onReject); } +std::shared_ptr Network::connect(const std::string& address, int port) { + auto socket = tcp->connect(address, port); + connections.push_back(socket); + return socket; +} + size_t Network::getTotalUpload() const { - return http->getTotalUpload(); + size_t totalUpload = 0; + for (const auto& socket : connections) { + totalUpload += socket->getTotalUpload(); + } + return http->getTotalUpload() + totalUpload; } size_t Network::getTotalDownload() const { - return http->getTotalDownload(); + size_t totalDownload = 0; + for (const auto& socket : connections) { + totalDownload += socket->getTotalDownload(); + } + return http->getTotalDownload() + totalDownload; } std::unique_ptr Network::create(const NetworkSettings& settings) { auto http = CurlHttp::create(); - return std::make_unique(std::move(http)); + auto tcp = std::make_unique(); + return std::make_unique(std::move(http), std::move(tcp)); } diff --git a/src/network/Network.hpp b/src/network/Network.hpp index fbf77926..099d8df7 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -2,7 +2,6 @@ #include #include -#include #include "typedefs.hpp" #include "settings.hpp" @@ -25,10 +24,32 @@ namespace network { virtual size_t getTotalDownload() const = 0; }; + class Socket { + public: + virtual int recv(void* buffer, size_t length, bool blocking) = 0; + virtual int send(const void* buffer, size_t length) = 0; + virtual void close() = 0; + virtual bool isOpen() const = 0; + + virtual size_t getTotalUpload() const = 0; + virtual size_t getTotalDownload() const = 0; + }; + + class Tcp { + public: + virtual ~Tcp() {} + + virtual std::shared_ptr connect( + const std::string& address, int port + ) = 0; + }; + class Network { std::unique_ptr http; + std::unique_ptr tcp; + std::vector> connections; public: - Network(std::unique_ptr http); + Network(std::unique_ptr http, std::unique_ptr tcp); ~Network(); void httpGet( @@ -37,6 +58,8 @@ namespace network { OnReject onReject = nullptr ); + std::shared_ptr connect(const std::string& address, int port); + size_t getTotalUpload() const; size_t getTotalDownload() const; diff --git a/test/network/curltest.cpp b/test/network/curltest.cpp index ac4d04dd..79ab3332 100644 --- a/test/network/curltest.cpp +++ b/test/network/curltest.cpp @@ -18,6 +18,16 @@ TEST(curltest, curltest) { std::cout << value << std::endl; } ); + if (false) { + auto socket = network->connect("localhost", 8000); + const char* string = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"; + socket->send(string, strlen(string)); + char data[1024]; + + int len = socket->recv(data, 1024, true); + std::cout << len << " " << std::string(data, len) << std::endl; + } + std::cout << "upload: " << network->getTotalUpload() << " B" << std::endl; std::cout << "download: " << network->getTotalDownload() << " B" << std::endl; }