diff --git a/src/logic/scripting/lua/libs/libnetwork.cpp b/src/logic/scripting/lua/libs/libnetwork.cpp index a78f17db..4201348a 100644 --- a/src/logic/scripting/lua/libs/libnetwork.cpp +++ b/src/logic/scripting/lua/libs/libnetwork.cpp @@ -155,7 +155,7 @@ static int l_post(lua::State* L, network::Network& network) { static int l_close(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { connection->close(true); } return 0; @@ -163,7 +163,7 @@ static int l_close(lua::State* L, network::Network& network) { static int l_closeserver(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto server = network.getServer(id)) { + if (auto server = network.getServer(id, false)) { server->close(); } return 0; @@ -171,7 +171,7 @@ static int l_closeserver(lua::State* L, network::Network& network) { static int l_send(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - auto connection = network.getConnection(id); + auto connection = network.getConnection(id, false); if (connection == nullptr || connection->getState() == network::ConnectionState::CLOSED) { return 0; @@ -201,7 +201,7 @@ static int l_send(lua::State* L, network::Network& network) { static int l_udp_server_send_to(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto server = network.getServer(id)) { + if (auto server = network.getServer(id, false)) { if (server->getTransportType() != network::TransportType::UDP) throw std::runtime_error("the server must work on UDP transport"); @@ -238,7 +238,7 @@ static int l_recv(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); int length = lua::tointeger(L, 2); - auto connection = engine->getNetwork().getConnection(id); + auto connection = engine->getNetwork().getConnection(id, false); if (connection == nullptr || connection->getTransportType() != network::TransportType::TCP) { return 0; @@ -268,7 +268,7 @@ static int l_recv(lua::State* L, network::Network& network) { static int l_available(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { return lua::pushinteger(L, dynamic_cast(connection)->available()); } @@ -345,7 +345,7 @@ static int l_open_udp(lua::State* L, network::Network& network) { static int l_is_alive(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { return lua::pushboolean( L, connection->getState() != network::ConnectionState::CLOSED || @@ -360,7 +360,7 @@ static int l_is_alive(lua::State* L, network::Network& network) { static int l_is_connected(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { return lua::pushboolean( L, connection->getState() == network::ConnectionState::CONNECTED ); @@ -370,7 +370,7 @@ static int l_is_connected(lua::State* L, network::Network& network) { static int l_get_address(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { lua::pushstring(L, connection->getAddress()); lua::pushinteger(L, connection->getPort()); return 2; @@ -380,7 +380,7 @@ static int l_get_address(lua::State* L, network::Network& network) { static int l_is_serveropen(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto server = network.getServer(id)) { + if (auto server = network.getServer(id, false)) { return lua::pushboolean(L, server->isOpen()); } return lua::pushboolean(L, false); @@ -388,7 +388,7 @@ static int l_is_serveropen(lua::State* L, network::Network& network) { static int l_get_serverport(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto server = network.getServer(id)) { + if (auto server = network.getServer(id, false)) { return lua::pushinteger(L, server->getPort()); } return 0; @@ -405,7 +405,7 @@ static int l_get_total_download(lua::State* L, network::Network& network) { static int l_set_nodelay(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); bool noDelay = lua::toboolean(L, 2); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { if (connection->getTransportType() == network::TransportType::TCP) { dynamic_cast(connection)->setNoDelay(noDelay); } @@ -415,7 +415,7 @@ static int l_set_nodelay(lua::State* L, network::Network& network) { static int l_is_nodelay(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); - if (auto connection = network.getConnection(id)) { + if (auto connection = network.getConnection(id, false)) { if (connection->getTransportType() == network::TransportType::TCP) { bool noDelay = dynamic_cast(connection)->isNoDelay(); return lua::pushboolean(L, noDelay); diff --git a/src/network/Curl.cpp b/src/network/Curl.cpp new file mode 100644 index 00000000..bd40a478 --- /dev/null +++ b/src/network/Curl.cpp @@ -0,0 +1,239 @@ +#include "commons.hpp" + +#include "debug/Logger.hpp" + +#define NOMINMAX +#include +#include + +using namespace network; + +static debug::Logger logger("curl"); + +inline constexpr int HTTP_OK = 200; +inline constexpr int HTTP_BAD_GATEWAY = 502; + +static size_t write_callback( + char* ptr, size_t size, size_t nmemb, void* userdata +) { + auto& buffer = *reinterpret_cast*>(userdata); + size_t psize = buffer.size(); + buffer.resize(psize + size * nmemb); + std::memcpy(buffer.data() + psize, ptr, size * nmemb); + return size * nmemb; +} + +enum class RequestType { + GET, POST +}; + +struct Request { + RequestType type; + std::string url; + OnResponse onResponse; + OnReject onReject; + long maxSize; + bool followLocation = false; + std::string data; + std::vector headers; +}; + +class CurlRequests : public Requests { + CURLM* multiHandle; + CURL* curl; + + size_t totalUpload = 0; + size_t totalDownload = 0; + + OnResponse onResponse; + OnReject onReject; + std::vector buffer; + std::string url; + + std::queue requests; +public: + CurlRequests(CURLM* multiHandle, CURL* curl) + : multiHandle(multiHandle), curl(curl) { + } + + virtual ~CurlRequests() { + curl_multi_remove_handle(multiHandle, curl); + curl_easy_cleanup(curl); + curl_multi_cleanup(multiHandle); + } + void get( + const std::string& url, + OnResponse onResponse, + OnReject onReject, + std::vector headers, + long maxSize + ) override { + Request request { + RequestType::GET, + url, + onResponse, + onReject, + maxSize, + true, + "", + std::move(headers)}; + processRequest(std::move(request)); + } + + void post( + const std::string& url, + const std::string& data, + OnResponse onResponse, + OnReject onReject=nullptr, + std::vector headers = {}, + long maxSize=0 + ) override { + Request request { + RequestType::POST, + url, + onResponse, + onReject, + maxSize, + false, + "", + std::move(headers)}; + request.data = data; + processRequest(std::move(request)); + } + + void processRequest(Request request) { + if (!url.empty()) { + requests.push(request); + return; + } + onResponse = request.onResponse; + onReject = request.onReject; + url = request.url; + + buffer.clear(); + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_POST, request.type == RequestType::POST); + + curl_slist* hs = NULL; + + for (const auto& header : request.headers) { + hs = curl_slist_append(hs, header.c_str()); + } + + switch (request.type) { + case RequestType::GET: + break; + case RequestType::POST: + hs = curl_slist_append(hs, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request.data.length()); + curl_easy_setopt(curl, CURLOPT_COPYPOSTFIELDS, request.data.c_str()); + break; + default: + throw std::runtime_error("not implemented"); + } + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, request.followLocation); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buffer); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "curl/7.81.0"); + if (request.maxSize == 0) { + curl_easy_setopt( + curl, CURLOPT_MAXFILESIZE, std::numeric_limits::max() + ); + } else { + curl_easy_setopt(curl, CURLOPT_MAXFILESIZE, request.maxSize); + } + curl_multi_add_handle(multiHandle, curl); + int running; + CURLMcode res = curl_multi_perform(multiHandle, &running); + if (res != CURLM_OK) { + auto message = curl_multi_strerror(res); + logger.error() << message << " (" << url << ")"; + if (onReject) { + onReject(HTTP_BAD_GATEWAY, {}); + } + url = ""; + } + } + + void update() override { + int messagesLeft; + int running; + CURLMsg* msg; + CURLMcode res = curl_multi_perform(multiHandle, &running); + if (res != CURLM_OK) { + auto message = curl_multi_strerror(res); + logger.error() << message << " (" << url << ")"; + if (onReject) { + onReject(HTTP_BAD_GATEWAY, {}); + } + curl_multi_remove_handle(multiHandle, curl); + url = ""; + return; + } + if ((msg = curl_multi_info_read(multiHandle, &messagesLeft)) != NULL) { + if(msg->msg == CURLMSG_DONE) { + curl_multi_remove_handle(multiHandle, curl); + } + int response; + curl_easy_getinfo(msg->easy_handle, CURLINFO_RESPONSE_CODE, &response); + if (response == HTTP_OK) { + long size; + if (!curl_easy_getinfo(curl, CURLINFO_REQUEST_SIZE, &size)) { + totalUpload += size; + } + if (!curl_easy_getinfo(curl, CURLINFO_HEADER_SIZE, &size)) { + totalDownload += size; + } + totalDownload += buffer.size(); + if (onResponse) { + onResponse(std::move(buffer)); + } + } else { + logger.error() + << "response code " << response << " (" << url << ")" + << (buffer.empty() + ? "" + : std::to_string(buffer.size()) + " byte(s)"); + totalDownload += buffer.size(); + if (onReject) { + onReject(response, std::move(buffer)); + } + } + url = ""; + } + if (url.empty() && !requests.empty()) { + auto request = std::move(requests.front()); + requests.pop(); + processRequest(std::move(request)); + } + } + + size_t getTotalUpload() const override { + return totalUpload; + } + + size_t getTotalDownload() const override { + return totalDownload; + } + + static std::unique_ptr create() { + auto curl = curl_easy_init(); + if (curl == nullptr) { + throw std::runtime_error("could not initialzie cURL"); + } + auto multiHandle = curl_multi_init(); + if (multiHandle == nullptr) { + curl_easy_cleanup(curl); + throw std::runtime_error("could not initialzie cURL-multi"); + } + return std::make_unique(multiHandle, curl); + } +}; + +namespace network { + std::unique_ptr create_curl_requests() { + return CurlRequests::create(); + } +} diff --git a/src/network/Network.cpp b/src/network/Network.cpp index f4e9db9b..04c8dbec 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -1,869 +1,44 @@ #include "Network.hpp" -#pragma comment(lib, "Ws2_32.lib") - -#define NOMINMAX -#include #include #include #include #include -#ifdef _WIN32 -/// included in curl.h -#else -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using SOCKET = int; -#endif // _WIN32 - #include "debug/Logger.hpp" #include "util/stringutil.hpp" using namespace network; -inline constexpr int HTTP_OK = 200; -inline constexpr int HTTP_BAD_GATEWAY = 502; - static debug::Logger logger("network"); -static size_t write_callback( - char* ptr, size_t size, size_t nmemb, void* userdata -) { - auto& buffer = *reinterpret_cast*>(userdata); - size_t psize = buffer.size(); - buffer.resize(psize + size * nmemb); - std::memcpy(buffer.data() + psize, ptr, size * nmemb); - return size * nmemb; -} +namespace network { + std::unique_ptr create_curl_requests(); -enum class RequestType { - GET, POST -}; - -struct Request { - RequestType type; - std::string url; - OnResponse onResponse; - OnReject onReject; - long maxSize; - bool followLocation = false; - std::string data; - std::vector headers; -}; - -class CurlRequests : public Requests { - CURLM* multiHandle; - CURL* curl; - - size_t totalUpload = 0; - size_t totalDownload = 0; - - OnResponse onResponse; - OnReject onReject; - std::vector buffer; - std::string url; - - std::queue requests; -public: - CurlRequests(CURLM* multiHandle, CURL* curl) - : multiHandle(multiHandle), curl(curl) { - } - - virtual ~CurlRequests() { - curl_multi_remove_handle(multiHandle, curl); - curl_easy_cleanup(curl); - curl_multi_cleanup(multiHandle); - } - void get( - const std::string& url, - OnResponse onResponse, - OnReject onReject, - std::vector headers, - long maxSize - ) override { - Request request { - RequestType::GET, - url, - onResponse, - onReject, - maxSize, - true, - "", - std::move(headers)}; - processRequest(std::move(request)); - } - - void post( - const std::string& url, - const std::string& data, - OnResponse onResponse, - OnReject onReject=nullptr, - std::vector headers = {}, - long maxSize=0 - ) override { - Request request { - RequestType::POST, - url, - onResponse, - onReject, - maxSize, - false, - "", - std::move(headers)}; - request.data = data; - processRequest(std::move(request)); - } - - void processRequest(Request request) { - if (!url.empty()) { - requests.push(request); - return; - } - onResponse = request.onResponse; - onReject = request.onReject; - url = request.url; - - buffer.clear(); - - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_POST, request.type == RequestType::POST); - - curl_slist* hs = NULL; - - for (const auto& header : request.headers) { - hs = curl_slist_append(hs, header.c_str()); - } - - switch (request.type) { - case RequestType::GET: - break; - case RequestType::POST: - hs = curl_slist_append(hs, "Content-Type: application/json"); - curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request.data.length()); - curl_easy_setopt(curl, CURLOPT_COPYPOSTFIELDS, request.data.c_str()); - break; - default: - throw std::runtime_error("not implemented"); - } - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, request.followLocation); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buffer); - curl_easy_setopt(curl, CURLOPT_USERAGENT, "curl/7.81.0"); - if (request.maxSize == 0) { - curl_easy_setopt( - curl, CURLOPT_MAXFILESIZE, std::numeric_limits::max() - ); - } else { - curl_easy_setopt(curl, CURLOPT_MAXFILESIZE, request.maxSize); - } - curl_multi_add_handle(multiHandle, curl); - int running; - CURLMcode res = curl_multi_perform(multiHandle, &running); - if (res != CURLM_OK) { - auto message = curl_multi_strerror(res); - logger.error() << message << " (" << url << ")"; - if (onReject) { - onReject(HTTP_BAD_GATEWAY, {}); - } - url = ""; - } - } - - void update() override { - int messagesLeft; - int running; - CURLMsg* msg; - CURLMcode res = curl_multi_perform(multiHandle, &running); - if (res != CURLM_OK) { - auto message = curl_multi_strerror(res); - logger.error() << message << " (" << url << ")"; - if (onReject) { - onReject(HTTP_BAD_GATEWAY, {}); - } - curl_multi_remove_handle(multiHandle, curl); - url = ""; - return; - } - if ((msg = curl_multi_info_read(multiHandle, &messagesLeft)) != NULL) { - if(msg->msg == CURLMSG_DONE) { - curl_multi_remove_handle(multiHandle, curl); - } - int response; - curl_easy_getinfo(msg->easy_handle, CURLINFO_RESPONSE_CODE, &response); - if (response == HTTP_OK) { - long size; - if (!curl_easy_getinfo(curl, CURLINFO_REQUEST_SIZE, &size)) { - totalUpload += size; - } - if (!curl_easy_getinfo(curl, CURLINFO_HEADER_SIZE, &size)) { - totalDownload += size; - } - totalDownload += buffer.size(); - if (onResponse) { - onResponse(std::move(buffer)); - } - } else { - logger.error() - << "response code " << response << " (" << url << ")" - << (buffer.empty() - ? "" - : std::to_string(buffer.size()) + " byte(s)"); - totalDownload += buffer.size(); - if (onReject) { - onReject(response, std::move(buffer)); - } - } - url = ""; - } - if (url.empty() && !requests.empty()) { - auto request = std::move(requests.front()); - requests.pop(); - processRequest(std::move(request)); - } - } - - size_t getTotalUpload() const override { - return totalUpload; - } - - size_t getTotalDownload() const override { - return totalDownload; - } - - static std::unique_ptr create() { - auto curl = curl_easy_init(); - if (curl == nullptr) { - throw std::runtime_error("could not initialzie cURL"); - } - auto multiHandle = curl_multi_init(); - if (multiHandle == nullptr) { - curl_easy_cleanup(curl); - throw std::runtime_error("could not initialzie cURL-multi"); - } - return std::make_unique(multiHandle, curl); - } -}; - -#ifndef _WIN32 -static inline int closesocket(int descriptor) noexcept { - return close(descriptor); -} -static inline std::runtime_error handle_socket_error(const std::string& message) { - int err = errno; - return std::runtime_error( - message+" [errno=" + std::to_string(err) + "]: " + - std::string(strerror(err)) - ); -} -#else -static inline std::runtime_error handle_socket_error(const std::string& message) { - int errorCode = WSAGetLastError(); - wchar_t* s = nullptr; - size_t size = FormatMessageW( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS, - nullptr, - errorCode, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)&s, - 0, - nullptr - ); - assert(s != nullptr); - while (size && isspace(s[size-1])) { - s[--size] = 0; - } - auto errorString = util::wstr2str_utf8(std::wstring(s)); - LocalFree(s); - return std::runtime_error(message+" [WSA error=" + - std::to_string(errorCode) + "]: "+errorString); -} -#endif - -static inline int connectsocket( - int descriptor, const sockaddr* addr, socklen_t len -) noexcept { - return connect(descriptor, addr, len); -} - -static inline int recvsocket( - int descriptor, char* buf, size_t len -) noexcept { - return recv(descriptor, buf, len, 0); -} - -static inline int sendsocket( - int descriptor, const char* buf, size_t len, int flags -) noexcept { - return send(descriptor, buf, len, flags); -} - -static std::string to_string(const sockaddr_in& addr, bool port=true) { - char ip[INET_ADDRSTRLEN]; - if (inet_ntop(AF_INET, &(addr.sin_addr), ip, INET_ADDRSTRLEN)) { - return std::string(ip) + - (port ? (":" + std::to_string(htons(addr.sin_port))) : ""); - } - return ""; -} - -class SocketTcpConnection : public TcpConnection { - SOCKET descriptor; - sockaddr_in addr; - size_t totalUpload = 0; - size_t totalDownload = 0; - ConnectionState state = ConnectionState::INITIAL; - std::unique_ptr thread = nullptr; - std::vector readBatch; - util::Buffer buffer; - std::mutex mutex; - - void connectSocket() { - state = ConnectionState::CONNECTING; - logger.info() << "connecting to " << to_string(addr); - int res = connectsocket(descriptor, (const sockaddr*)&addr, sizeof(sockaddr_in)); - if (res < 0) { - auto error = handle_socket_error("Connect failed"); - closesocket(descriptor); - state = ConnectionState::CLOSED; - logger.error() << error.what(); - return; - } - logger.info() << "connected to " << to_string(addr); - state = ConnectionState::CONNECTED; - } -public: - SocketTcpConnection(SOCKET descriptor, sockaddr_in addr) - : descriptor(descriptor), addr(std::move(addr)), buffer(16'384) {} - - ~SocketTcpConnection() { - if (state != ConnectionState::CLOSED) { - shutdown(descriptor, 2); - } - if (thread) { - thread->join(); - } - } - - void setNoDelay(bool noDelay) override { - int opt = noDelay ? 1 : 0; - if (setsockopt(descriptor, IPPROTO_TCP, TCP_NODELAY, (char*)&opt, sizeof(opt)) < 0) { - throw handle_socket_error("setsockopt(TCP_NODELAY) failed"); - } - } - - bool isNoDelay() const override { - int opt = 0; - socklen_t len = sizeof(opt); - if (getsockopt(descriptor, IPPROTO_TCP, TCP_NODELAY, (char*)&opt, &len) < 0) { - throw handle_socket_error("getsockopt(TCP_NODELAY) failed"); - } - return opt != 0; - } - - void startListen() { - while (state == ConnectionState::CONNECTED) { - int size = recvsocket(descriptor, buffer.data(), buffer.size()); - if (size == 0) { - logger.info() << "closed connection with " << to_string(addr); - closesocket(descriptor); - state = ConnectionState::CLOSED; - break; - } else if (size < 0) { - logger.warning() << "an error ocurred while receiving from " - << to_string(addr); - auto error = handle_socket_error("recv(...) error"); - closesocket(descriptor); - state = ConnectionState::CLOSED; - logger.error() << error.what(); - break; - } - { - std::lock_guard lock(mutex); - for (size_t i = 0; i < size; i++) { - readBatch.emplace_back(buffer[i]); - } - totalDownload += size; - } - } - } - - void startClient() { - state = ConnectionState::CONNECTED; - thread = std::make_unique([this]() { startListen();}); - } - - void connect(runnable callback) override { - thread = std::make_unique([this, callback]() { - connectSocket(); - if (state == ConnectionState::CONNECTED) { - callback(); - } - startListen(); - }); - } - - int recv(char* buffer, size_t length) override { - std::lock_guard lock(mutex); - - if (state != ConnectionState::CONNECTED && readBatch.empty()) { - return -1; - } - int size = std::min(readBatch.size(), length); - std::memcpy(buffer, readBatch.data(), size); - readBatch.erase(readBatch.begin(), readBatch.begin() + size); - return size; - } - - int send(const char* buffer, size_t length) override { - if (state == ConnectionState::CLOSED) { - return 0; - } - int len = sendsocket(descriptor, buffer, length, 0); - if (len == -1) { - int err = errno; - close(); - throw std::runtime_error( - "Send failed [errno=" + std::to_string(err) + "]: " - + std::string(strerror(err)) - ); - } - totalUpload += len; - return len; - } - - int available() override { - std::lock_guard lock(mutex); - return readBatch.size(); - } - - void close(bool discardAll=false) override { - { - std::lock_guard lock(mutex); - readBatch.clear(); - - if (state != ConnectionState::CLOSED) { - shutdown(descriptor, 2); - closesocket(descriptor); - } - } - if (thread) { - thread->join(); - thread = nullptr; - } - } - - size_t pullUpload() override { - size_t size = totalUpload; - totalUpload = 0; - return size; - } - - size_t pullDownload() override { - size_t size = totalDownload; - totalDownload = 0; - return size; - } - - int getPort() const override { - return htons(addr.sin_port); - } - - std::string getAddress() const override { - return to_string(addr, false); - } - - static std::shared_ptr connect( + std::shared_ptr connect_tcp( const std::string& address, int port, runnable callback - ) { - addrinfo hints {}; + ); - hints.ai_family = AF_INET; - hints.ai_socktype = SOCK_STREAM; - - addrinfo* addrinfo = nullptr; - if (int res = getaddrinfo( - address.c_str(), nullptr, &hints, &addrinfo - )) { - throw std::runtime_error(gai_strerror(res)); - } - - sockaddr_in serverAddress; - std::memcpy(&serverAddress, addrinfo->ai_addr, sizeof(sockaddr_in)); - serverAddress.sin_port = htons(port); - freeaddrinfo(addrinfo); - - SOCKET descriptor = socket(AF_INET, SOCK_STREAM, 0); - if (descriptor == -1) { - throw std::runtime_error("Could not create socket"); - } - auto socket = std::make_shared(descriptor, std::move(serverAddress)); - socket->connect(std::move(callback)); - return socket; - } - - ConnectionState getState() const override { - return state; - } -}; - -class SocketTcpServer : public TcpServer { - u64id_t id; - Network* network; - SOCKET descriptor; - std::vector clients; - std::mutex clientsMutex; - bool open = true; - std::unique_ptr thread = nullptr; - int port; -public: - SocketTcpServer(u64id_t id, Network* network, SOCKET descriptor, int port) - : id(id), network(network), descriptor(descriptor), port(port) {} - - ~SocketTcpServer() { - closeSocket(); - } - - void startListen(ConnectCallback handler) override { - thread = std::make_unique([this, handler]() { - while (open) { - logger.info() << "listening for connections"; - if (listen(descriptor, 2) < 0) { - close(); - break; - } - socklen_t addrlen = sizeof(sockaddr_in); - SOCKET clientDescriptor; - sockaddr_in address; - logger.info() << "accepting clients"; - if ((clientDescriptor = accept(descriptor, (sockaddr*)&address, &addrlen)) == -1) { - close(); - break; - } - logger.info() << "client connected: " << to_string(address); - auto socket = std::make_shared( - clientDescriptor, address - ); - socket->startClient(); - u64id_t id = network->addConnection(socket); - { - std::lock_guard lock(clientsMutex); - clients.push_back(id); - } - handler(this->id, id); - } - }); - } - - void closeSocket() { - if (!open) { - return; - } - logger.info() << "closing server"; - open = false; - - { - std::lock_guard lock(clientsMutex); - for (u64id_t clientid : clients) { - if (auto client = network->getConnection(clientid)) { - client->close(); - } - } - } - clients.clear(); - - shutdown(descriptor, 2); - closesocket(descriptor); - thread->join(); - } - - void close() override { - closeSocket(); - } - - bool isOpen() override { - return open; - } - - int getPort() const override { - return port; - } - - static std::shared_ptr openServer( + std::shared_ptr open_tcp_server( u64id_t id, Network* network, int port, ConnectCallback handler - ) { - SOCKET descriptor = socket( - AF_INET, SOCK_STREAM, 0 - ); - if (descriptor == -1) { - throw std::runtime_error("Could not create server socket"); - } - int opt = 1; - int flags = SO_REUSEADDR; -# if !defined(_WIN32) && !defined(__APPLE__) - flags |= SO_REUSEPORT; -# endif - if (setsockopt(descriptor, SOL_SOCKET, flags, (const char*)&opt, sizeof(opt))) { - logger.error() << "setsockopt(SO_REUSEADDR) failed with errno: " - << errno << "(" << std::strerror(errno) << ")"; - closesocket(descriptor); - throw std::runtime_error("setsockopt"); - } - sockaddr_in address; - address.sin_family = AF_INET; - address.sin_addr.s_addr = INADDR_ANY; - address.sin_port = htons(port); - if (bind(descriptor, (sockaddr*)&address, sizeof(address)) < 0) { - closesocket(descriptor); - throw std::runtime_error("could not bind port "+std::to_string(port)); - } - logger.info() << "opened server at port " << port; - auto server = - std::make_shared(id, network, descriptor, port); - server->startListen(std::move(handler)); - return server; - } -}; + ); -static sockaddr_in resolve_address_dgram(const std::string& address, int port) { - sockaddr_in serverAddr{}; - addrinfo hints {}; - - hints.ai_family = AF_INET; - hints.ai_socktype = SOCK_DGRAM; - - addrinfo* addrinfo = nullptr; - if (int res = getaddrinfo( - address.c_str(), nullptr, &hints, &addrinfo - )) { - throw std::runtime_error(gai_strerror(res)); - } - - std::memcpy(&serverAddr, addrinfo->ai_addr, sizeof(sockaddr_in)); - serverAddr.sin_port = htons(port); - freeaddrinfo(addrinfo); - return serverAddr; -} - -class SocketUdpConnection : public UdpConnection { - u64id_t id; - SOCKET descriptor; - sockaddr_in addr{}; - bool open = true; - std::unique_ptr thread; - ClientDatagramCallback callback; - - size_t totalUpload = 0; - size_t totalDownload = 0; - ConnectionState state = ConnectionState::INITIAL; - -public: - SocketUdpConnection(u64id_t id, SOCKET descriptor, sockaddr_in addr) - : id(id), descriptor(descriptor), addr(std::move(addr)) {} - - ~SocketUdpConnection() override { - SocketUdpConnection::close(); - } - - static std::shared_ptr connect( + std::shared_ptr connect_udp( u64id_t id, const std::string& address, int port, ClientDatagramCallback handler, runnable callback - ) { - SOCKET descriptor = socket(AF_INET, SOCK_DGRAM, 0); - if (descriptor == -1) { - throw std::runtime_error("could not create udp socket"); - } + ); - sockaddr_in serverAddr = resolve_address_dgram(address, port); + std::shared_ptr open_udp_server( + u64id_t id, + Network* network, + int port, + const ServerDatagramCallback& handler + ); +} - if (::connect(descriptor, (sockaddr*)&serverAddr, sizeof(serverAddr)) < 0) { - auto err = handle_socket_error("udp connect failed"); - closesocket(descriptor); - throw err; - } - - auto socket = std::make_shared(id, descriptor, serverAddr); - socket->connect(std::move(handler)); - - callback(); - - return socket; - } - - void connect(ClientDatagramCallback handler) override { - callback = std::move(handler); - state = ConnectionState::CONNECTED; - - thread = std::make_unique([this]() { - util::Buffer buffer(16'384); - while (open) { - int size = recv(descriptor, buffer.data(), buffer.size(), 0); - if (size <= 0) { - logger.error() << "udp connection " << id - << handle_socket_error(" recv error").what(); - if (!open) break; - closesocket(descriptor); - state = ConnectionState::CLOSED; - break; - } - totalDownload += size; - if (callback) { - callback(id, buffer.data(), size); - } - } - }); - } - - int send(const char* buffer, size_t length) override { - int len = ::send(descriptor, buffer, length, 0); - if (len < 0) { - auto err = handle_socket_error(" send failed"); - closesocket(descriptor); - state = ConnectionState::CLOSED; - logger.error() << "udp connection " << id << err.what(); - } else totalUpload += len; - - return len; - } - - void close(bool discardAll=false) override { - if (!open) return; - open = false; - logger.info() << "closing udp connection "<< id; - - if (state != ConnectionState::CLOSED) { - shutdown(descriptor, 2); - closesocket(descriptor); - } - - if (thread) { - thread->join(); - thread.reset(); - } - state = ConnectionState::CLOSED; - } - - size_t pullUpload() override { - size_t s = totalUpload; - totalUpload = 0; - return s; - } - - size_t pullDownload() override { - size_t s = totalDownload; - totalDownload = 0; - return s; - } - - [[nodiscard]] int getPort() const override { - return ntohs(addr.sin_port); - } - - [[nodiscard]] std::string getAddress() const override { - return to_string(addr, false); - } - - [[nodiscard]] ConnectionState getState() const override { - return state; - } -}; - -class SocketUdpServer : public UdpServer { - u64id_t id; - SOCKET descriptor; - bool open = true; - std::unique_ptr thread = nullptr; - int port; - ServerDatagramCallback callback; - -public: - SocketUdpServer(u64id_t id, Network* network, SOCKET descriptor, int port) - : id(id), descriptor(descriptor), port(port) {} - - ~SocketUdpServer() override { - SocketUdpServer::close(); - } - - void startListen(ServerDatagramCallback handler) override { - callback = std::move(handler); - - thread = std::make_unique([this]() { - util::Buffer buffer(16384); - sockaddr_in clientAddr{}; - socklen_t addrlen = sizeof(clientAddr); - - while (open) { - int size = recvfrom(descriptor, buffer.data(), buffer.size(), 0, - reinterpret_cast(&clientAddr), &addrlen); - if (size <= 0) { - if (!open) break; - continue; - } - - std::string addrStr = to_string(clientAddr, false); - int port = ntohs(clientAddr.sin_port); - - callback(id, addrStr, port, buffer.data(), size); - } - }); - } - - void sendTo(const std::string& addr, int port, const char* buffer, size_t length) override { - sockaddr_in client = resolve_address_dgram(addr, port); - if (sendto(descriptor, buffer, length, 0, - reinterpret_cast(&client), sizeof(client)) < 0) { - logger.error() << handle_socket_error("sendto").what(); - } - } - - void close() override { - if (!open) return; - open = false; - shutdown(descriptor, 2); - closesocket(descriptor); - if (thread) { - thread->join(); - thread = nullptr; - } - } - - bool isOpen() override { return open; } - int getPort() const override { return port; } - - static std::shared_ptr openServer( - u64id_t id, Network* network, int port, const ServerDatagramCallback& handler - ) { - SOCKET descriptor = socket(AF_INET, SOCK_DGRAM, 0); - if (descriptor == -1) throw std::runtime_error("could not create udp socket"); - - sockaddr_in address{}; - address.sin_family = AF_INET; - address.sin_addr.s_addr = INADDR_ANY; - address.sin_port = htons(port); - - if (bind(descriptor, (sockaddr*)&address, sizeof(address)) < 0) { - closesocket(descriptor); - throw std::runtime_error("could not bind udp port " + std::to_string(port)); - } - - auto server = std::make_shared(id, network, descriptor, port); - server->startListen(std::move(handler)); - return server; - } -}; Network::Network(std::unique_ptr requests) : requests(std::move(requests)) { @@ -894,19 +69,19 @@ void Network::post( ); } -Connection* Network::getConnection(u64id_t id) { +Connection* Network::getConnection(u64id_t id, bool includePrivate) { std::lock_guard lock(connectionsMutex); const auto& found = connections.find(id); - if (found == connections.end()) { + if (found == connections.end() || (!includePrivate && found->second->isPrivate())) { return nullptr; } return found->second.get(); } -Server* Network::getServer(u64id_t id) const { +Server* Network::getServer(u64id_t id, bool includePrivate) const { const auto& found = servers.find(id); - if (found == servers.end()) { + if (found == servers.end() || (!includePrivate && found->second->isPrivate())) { return nullptr; } return found->second.get(); @@ -916,7 +91,7 @@ u64id_t Network::connectTcp(const std::string& address, int port, consumer& callback, ClientDatagramCallback handler) { +u64id_t Network::connectUdp( + const std::string& address, + int port, + const consumer& callback, + ClientDatagramCallback handler +) { std::lock_guard lock(connectionsMutex); u64id_t id = nextConnection++; - auto socket = SocketUdpConnection::connect(id, address, port, std::move(handler), [id, callback]() { + auto socket = connect_udp(id, address, port, std::move(handler), [id, callback]() { callback(id); }); connections[id] = std::move(socket); @@ -943,7 +123,7 @@ u64id_t Network::connectUdp(const std::string& address, int port, const consumer u64id_t Network::openUdpServer(int port, const ServerDatagramCallback& handler) { u64id_t id = nextServer++; - auto server = SocketUdpServer::openServer(id, this, port, handler); + auto server = open_udp_server(id, this, port, handler); servers[id] = std::move(server); return id; } @@ -997,6 +177,5 @@ void Network::update() { } std::unique_ptr Network::create(const NetworkSettings& settings) { - auto requests = CurlRequests::create(); - return std::make_unique(std::move(requests)); + return std::make_unique(network::create_curl_requests()); } diff --git a/src/network/Network.hpp b/src/network/Network.hpp index 3fec46e0..a6c2fca2 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -1,75 +1,8 @@ #pragma once -#include -#include -#include - -#include "typedefs.hpp" -#include "settings.hpp" -#include "util/Buffer.hpp" -#include "delegates.hpp" +#include "commons.hpp" namespace network { - using OnResponse = std::function)>; - using OnReject = std::function)>; - using ConnectCallback = std::function; - using ServerDatagramCallback = std::function; - using ClientDatagramCallback = std::function; - - class Requests { - public: - virtual ~Requests() {} - - virtual void get( - const std::string& url, - OnResponse onResponse, - OnReject onReject=nullptr, - std::vector headers = {}, - long maxSize=0 - ) = 0; - - virtual void post( - const std::string& url, - const std::string& data, - OnResponse onResponse, - OnReject onReject=nullptr, - std::vector headers = {}, - long maxSize=0 - ) = 0; - - [[nodiscard]] virtual size_t getTotalUpload() const = 0; - [[nodiscard]] virtual size_t getTotalDownload() const = 0; - - virtual void update() = 0; - }; - - enum class ConnectionState { - INITIAL, CONNECTING, CONNECTED, CLOSED - }; - - enum class TransportType { - TCP, UDP - }; - - class Connection { - public: - virtual ~Connection() = default; - - virtual void close(bool discardAll=false) = 0; - - virtual int send(const char* buffer, size_t length) = 0; - - virtual size_t pullUpload() = 0; - virtual size_t pullDownload() = 0; - - [[nodiscard]] virtual int getPort() const = 0; - [[nodiscard]] virtual std::string getAddress() const = 0; - - [[nodiscard]] virtual ConnectionState getState() const = 0; - - [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0; - }; - class TcpConnection : public Connection { public: ~TcpConnection() override = default; @@ -96,15 +29,6 @@ namespace network { } }; - class Server { - public: - virtual ~Server() = default; - virtual void close() = 0; - virtual bool isOpen() = 0; - [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0; - [[nodiscard]] virtual int getPort() const = 0; - }; - class TcpServer : public Server { public: ~TcpServer() override {} @@ -160,8 +84,8 @@ namespace network { long maxSize=0 ); - [[nodiscard]] Connection* getConnection(u64id_t id); - [[nodiscard]] Server* getServer(u64id_t id) const; + [[nodiscard]] Connection* getConnection(u64id_t id, bool includePrivate); + [[nodiscard]] Server* getServer(u64id_t id, bool includePrivate) const; u64id_t connectTcp(const std::string& address, int port, consumer callback); u64id_t connectUdp(const std::string& address, int port, const consumer& callback, ClientDatagramCallback handler); diff --git a/src/network/Sockets.cpp b/src/network/Sockets.cpp new file mode 100644 index 00000000..5152028b --- /dev/null +++ b/src/network/Sockets.cpp @@ -0,0 +1,679 @@ +#include "commons.hpp" + +#pragma comment(lib, "Ws2_32.lib") + +#define NOMINMAX +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using SOCKET = int; +#endif // _WIN32 + +#include "Network.hpp" +#include "util/stringutil.hpp" +#include "debug/Logger.hpp" + +using namespace network; + +static debug::Logger logger("sockets"); + +#ifndef _WIN32 +static inline int closesocket(int descriptor) noexcept { + return close(descriptor); +} +static inline std::runtime_error handle_socket_error(const std::string& message) { + int err = errno; + return std::runtime_error( + message+" [errno=" + std::to_string(err) + "]: " + + std::string(strerror(err)) + ); +} +#else +static inline std::runtime_error handle_socket_error(const std::string& message) { + int errorCode = WSAGetLastError(); + wchar_t* s = nullptr; + size_t size = FormatMessageW( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, + errorCode, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR)&s, + 0, + nullptr + ); + assert(s != nullptr); + while (size && isspace(s[size-1])) { + s[--size] = 0; + } + auto errorString = util::wstr2str_utf8(std::wstring(s)); + LocalFree(s); + return std::runtime_error(message+" [WSA error=" + + std::to_string(errorCode) + "]: "+errorString); +} +#endif + +static inline int connectsocket( + int descriptor, const sockaddr* addr, socklen_t len +) noexcept { + return connect(descriptor, addr, len); +} + +static inline int recvsocket( + int descriptor, char* buf, size_t len +) noexcept { + return recv(descriptor, buf, len, 0); +} + +static inline int sendsocket( + int descriptor, const char* buf, size_t len, int flags +) noexcept { + return send(descriptor, buf, len, flags); +} + +static std::string to_string(const sockaddr_in& addr, bool port=true) { + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &(addr.sin_addr), ip, INET_ADDRSTRLEN)) { + return std::string(ip) + + (port ? (":" + std::to_string(htons(addr.sin_port))) : ""); + } + return ""; +} + +class SocketTcpConnection : public TcpConnection { + SOCKET descriptor; + sockaddr_in addr; + size_t totalUpload = 0; + size_t totalDownload = 0; + ConnectionState state = ConnectionState::INITIAL; + std::unique_ptr thread = nullptr; + std::vector readBatch; + util::Buffer buffer; + std::mutex mutex; + + void connectSocket() { + state = ConnectionState::CONNECTING; + logger.info() << "connecting to " << to_string(addr); + int res = connectsocket(descriptor, (const sockaddr*)&addr, sizeof(sockaddr_in)); + if (res < 0) { + auto error = handle_socket_error("Connect failed"); + closesocket(descriptor); + state = ConnectionState::CLOSED; + logger.error() << error.what(); + return; + } + logger.info() << "connected to " << to_string(addr); + state = ConnectionState::CONNECTED; + } +public: + SocketTcpConnection(SOCKET descriptor, sockaddr_in addr) + : descriptor(descriptor), addr(std::move(addr)), buffer(16'384) {} + + ~SocketTcpConnection() { + if (state != ConnectionState::CLOSED) { + shutdown(descriptor, 2); + } + if (thread) { + thread->join(); + } + } + + void setNoDelay(bool noDelay) override { + int opt = noDelay ? 1 : 0; + if (setsockopt(descriptor, IPPROTO_TCP, TCP_NODELAY, (char*)&opt, sizeof(opt)) < 0) { + throw handle_socket_error("setsockopt(TCP_NODELAY) failed"); + } + } + + bool isNoDelay() const override { + int opt = 0; + socklen_t len = sizeof(opt); + if (getsockopt(descriptor, IPPROTO_TCP, TCP_NODELAY, (char*)&opt, &len) < 0) { + throw handle_socket_error("getsockopt(TCP_NODELAY) failed"); + } + return opt != 0; + } + + void startListen() { + while (state == ConnectionState::CONNECTED) { + int size = recvsocket(descriptor, buffer.data(), buffer.size()); + if (size == 0) { + logger.info() << "closed connection with " << to_string(addr); + closesocket(descriptor); + state = ConnectionState::CLOSED; + break; + } else if (size < 0) { + logger.warning() << "an error ocurred while receiving from " + << to_string(addr); + auto error = handle_socket_error("recv(...) error"); + closesocket(descriptor); + state = ConnectionState::CLOSED; + logger.error() << error.what(); + break; + } + { + std::lock_guard lock(mutex); + for (size_t i = 0; i < size; i++) { + readBatch.emplace_back(buffer[i]); + } + totalDownload += size; + } + } + } + + void startClient() { + state = ConnectionState::CONNECTED; + thread = std::make_unique([this]() { startListen();}); + } + + void connect(runnable callback) override { + thread = std::make_unique([this, callback]() { + connectSocket(); + if (state == ConnectionState::CONNECTED) { + callback(); + } + startListen(); + }); + } + + int recv(char* buffer, size_t length) override { + std::lock_guard lock(mutex); + + if (state != ConnectionState::CONNECTED && readBatch.empty()) { + return -1; + } + int size = std::min(readBatch.size(), length); + std::memcpy(buffer, readBatch.data(), size); + readBatch.erase(readBatch.begin(), readBatch.begin() + size); + return size; + } + + int send(const char* buffer, size_t length) override { + if (state == ConnectionState::CLOSED) { + return 0; + } + int len = sendsocket(descriptor, buffer, length, 0); + if (len == -1) { + int err = errno; + close(); + throw std::runtime_error( + "Send failed [errno=" + std::to_string(err) + "]: " + + std::string(strerror(err)) + ); + } + totalUpload += len; + return len; + } + + int available() override { + std::lock_guard lock(mutex); + return readBatch.size(); + } + + void close(bool discardAll=false) override { + { + std::lock_guard lock(mutex); + readBatch.clear(); + + if (state != ConnectionState::CLOSED) { + shutdown(descriptor, 2); + closesocket(descriptor); + } + } + if (thread) { + thread->join(); + thread = nullptr; + } + } + + size_t pullUpload() override { + size_t size = totalUpload; + totalUpload = 0; + return size; + } + + size_t pullDownload() override { + size_t size = totalDownload; + totalDownload = 0; + return size; + } + + int getPort() const override { + return htons(addr.sin_port); + } + + std::string getAddress() const override { + return to_string(addr, false); + } + + static std::shared_ptr connect( + const std::string& address, int port, runnable callback + ) { + addrinfo hints {}; + + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + addrinfo* addrinfo = nullptr; + if (int res = getaddrinfo( + address.c_str(), nullptr, &hints, &addrinfo + )) { + throw std::runtime_error(gai_strerror(res)); + } + + sockaddr_in serverAddress; + std::memcpy(&serverAddress, addrinfo->ai_addr, sizeof(sockaddr_in)); + serverAddress.sin_port = htons(port); + freeaddrinfo(addrinfo); + + SOCKET descriptor = socket(AF_INET, SOCK_STREAM, 0); + if (descriptor == -1) { + throw std::runtime_error("Could not create socket"); + } + auto socket = std::make_shared(descriptor, std::move(serverAddress)); + socket->connect(std::move(callback)); + return socket; + } + + ConnectionState getState() const override { + return state; + } +}; + +class SocketTcpServer : public TcpServer { + u64id_t id; + Network* network; + SOCKET descriptor; + std::vector clients; + std::mutex clientsMutex; + bool open = true; + std::unique_ptr thread = nullptr; + int port; +public: + SocketTcpServer(u64id_t id, Network* network, SOCKET descriptor, int port) + : id(id), network(network), descriptor(descriptor), port(port) {} + + ~SocketTcpServer() { + closeSocket(); + } + + void startListen(ConnectCallback handler) override { + thread = std::make_unique([this, handler]() { + while (open) { + logger.info() << "listening for connections"; + if (listen(descriptor, 2) < 0) { + close(); + break; + } + socklen_t addrlen = sizeof(sockaddr_in); + SOCKET clientDescriptor; + sockaddr_in address; + logger.info() << "accepting clients"; + if ((clientDescriptor = accept(descriptor, (sockaddr*)&address, &addrlen)) == -1) { + close(); + break; + } + logger.info() << "client connected: " << to_string(address); + auto socket = std::make_shared( + clientDescriptor, address + ); + socket->startClient(); + u64id_t id = network->addConnection(socket); + { + std::lock_guard lock(clientsMutex); + clients.push_back(id); + } + handler(this->id, id); + } + }); + } + + void closeSocket() { + if (!open) { + return; + } + logger.info() << "closing server"; + open = false; + + { + std::lock_guard lock(clientsMutex); + for (u64id_t clientid : clients) { + if (auto client = network->getConnection(clientid, true)) { + client->close(); + } + } + } + clients.clear(); + + shutdown(descriptor, 2); + closesocket(descriptor); + thread->join(); + } + + void close() override { + closeSocket(); + } + + bool isOpen() override { + return open; + } + + int getPort() const override { + return port; + } + + static std::shared_ptr openServer( + u64id_t id, Network* network, int port, ConnectCallback handler + ) { + SOCKET descriptor = socket( + AF_INET, SOCK_STREAM, 0 + ); + if (descriptor == -1) { + throw std::runtime_error("Could not create server socket"); + } + int opt = 1; + int flags = SO_REUSEADDR; +# if !defined(_WIN32) && !defined(__APPLE__) + flags |= SO_REUSEPORT; +# endif + if (setsockopt(descriptor, SOL_SOCKET, flags, (const char*)&opt, sizeof(opt))) { + logger.error() << "setsockopt(SO_REUSEADDR) failed with errno: " + << errno << "(" << std::strerror(errno) << ")"; + closesocket(descriptor); + throw std::runtime_error("setsockopt"); + } + sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + if (bind(descriptor, (sockaddr*)&address, sizeof(address)) < 0) { + closesocket(descriptor); + throw std::runtime_error("could not bind port "+std::to_string(port)); + } + logger.info() << "opened server at port " << port; + auto server = + std::make_shared(id, network, descriptor, port); + server->startListen(std::move(handler)); + return server; + } +}; + +static sockaddr_in resolve_address_dgram(const std::string& address, int port) { + sockaddr_in serverAddr{}; + addrinfo hints {}; + + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_DGRAM; + + addrinfo* addrinfo = nullptr; + if (int res = getaddrinfo( + address.c_str(), nullptr, &hints, &addrinfo + )) { + throw std::runtime_error(gai_strerror(res)); + } + + std::memcpy(&serverAddr, addrinfo->ai_addr, sizeof(sockaddr_in)); + serverAddr.sin_port = htons(port); + freeaddrinfo(addrinfo); + return serverAddr; +} + +class SocketUdpConnection : public UdpConnection { + u64id_t id; + SOCKET descriptor; + sockaddr_in addr{}; + bool open = true; + std::unique_ptr thread; + ClientDatagramCallback callback; + + size_t totalUpload = 0; + size_t totalDownload = 0; + ConnectionState state = ConnectionState::INITIAL; + +public: + SocketUdpConnection(u64id_t id, SOCKET descriptor, sockaddr_in addr) + : id(id), descriptor(descriptor), addr(std::move(addr)) {} + + ~SocketUdpConnection() override { + SocketUdpConnection::close(); + } + + static std::shared_ptr connect( + u64id_t id, + const std::string& address, + int port, + ClientDatagramCallback handler, + runnable callback + ) { + SOCKET descriptor = socket(AF_INET, SOCK_DGRAM, 0); + if (descriptor == -1) { + throw std::runtime_error("could not create udp socket"); + } + + sockaddr_in serverAddr = resolve_address_dgram(address, port); + + if (::connect(descriptor, (sockaddr*)&serverAddr, sizeof(serverAddr)) < 0) { + auto err = handle_socket_error("udp connect failed"); + closesocket(descriptor); + throw err; + } + + auto socket = std::make_shared(id, descriptor, serverAddr); + socket->connect(std::move(handler)); + + callback(); + + return socket; + } + + void connect(ClientDatagramCallback handler) override { + callback = std::move(handler); + state = ConnectionState::CONNECTED; + + thread = std::make_unique([this]() { + util::Buffer buffer(16'384); + while (open) { + int size = recv(descriptor, buffer.data(), buffer.size(), 0); + if (size <= 0) { + logger.error() << "udp connection " << id + << handle_socket_error(" recv error").what(); + if (!open) break; + closesocket(descriptor); + state = ConnectionState::CLOSED; + break; + } + totalDownload += size; + if (callback) { + callback(id, buffer.data(), size); + } + } + }); + } + + int send(const char* buffer, size_t length) override { + int len = ::send(descriptor, buffer, length, 0); + if (len < 0) { + auto err = handle_socket_error(" send failed"); + closesocket(descriptor); + state = ConnectionState::CLOSED; + logger.error() << "udp connection " << id << err.what(); + } else totalUpload += len; + + return len; + } + + void close(bool discardAll=false) override { + if (!open) return; + open = false; + logger.info() << "closing udp connection "<< id; + + if (state != ConnectionState::CLOSED) { + shutdown(descriptor, 2); + closesocket(descriptor); + } + + if (thread) { + thread->join(); + thread.reset(); + } + state = ConnectionState::CLOSED; + } + + size_t pullUpload() override { + size_t s = totalUpload; + totalUpload = 0; + return s; + } + + size_t pullDownload() override { + size_t s = totalDownload; + totalDownload = 0; + return s; + } + + [[nodiscard]] int getPort() const override { + return ntohs(addr.sin_port); + } + + [[nodiscard]] std::string getAddress() const override { + return to_string(addr, false); + } + + [[nodiscard]] ConnectionState getState() const override { + return state; + } +}; + +class SocketUdpServer : public UdpServer { + u64id_t id; + SOCKET descriptor; + bool open = true; + std::unique_ptr thread = nullptr; + int port; + ServerDatagramCallback callback; + +public: + SocketUdpServer(u64id_t id, Network* network, SOCKET descriptor, int port) + : id(id), descriptor(descriptor), port(port) {} + + ~SocketUdpServer() override { + SocketUdpServer::close(); + } + + void startListen(ServerDatagramCallback handler) override { + callback = std::move(handler); + + thread = std::make_unique([this]() { + util::Buffer buffer(16384); + sockaddr_in clientAddr{}; + socklen_t addrlen = sizeof(clientAddr); + + while (open) { + int size = recvfrom(descriptor, buffer.data(), buffer.size(), 0, + reinterpret_cast(&clientAddr), &addrlen); + if (size <= 0) { + if (!open) break; + continue; + } + + std::string addrStr = to_string(clientAddr, false); + int port = ntohs(clientAddr.sin_port); + + callback(id, addrStr, port, buffer.data(), size); + } + }); + } + + void sendTo(const std::string& addr, int port, const char* buffer, size_t length) override { + sockaddr_in client = resolve_address_dgram(addr, port); + if (sendto(descriptor, buffer, length, 0, + reinterpret_cast(&client), sizeof(client)) < 0) { + logger.error() << handle_socket_error("sendto").what(); + } + } + + void close() override { + if (!open) return; + open = false; + shutdown(descriptor, 2); + closesocket(descriptor); + if (thread) { + thread->join(); + thread = nullptr; + } + } + + bool isOpen() override { return open; } + int getPort() const override { return port; } + + static std::shared_ptr openServer( + u64id_t id, Network* network, int port, const ServerDatagramCallback& handler + ) { + SOCKET descriptor = socket(AF_INET, SOCK_DGRAM, 0); + if (descriptor == -1) throw std::runtime_error("could not create udp socket"); + + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + if (bind(descriptor, (sockaddr*)&address, sizeof(address)) < 0) { + closesocket(descriptor); + throw std::runtime_error("could not bind udp port " + std::to_string(port)); + } + + auto server = std::make_shared(id, network, descriptor, port); + server->startListen(std::move(handler)); + return server; + } +}; + +namespace network { + std::shared_ptr connect_tcp( + const std::string& address, int port, runnable callback + ) { + return SocketTcpConnection::connect(address, port, std::move(callback)); + } + + std::shared_ptr open_tcp_server( + u64id_t id, Network* network, int port, ConnectCallback handler + ) { + return SocketTcpServer::openServer(id, network, port, std::move(handler)); + } + + std::shared_ptr connect_udp( + u64id_t id, + const std::string& address, + int port, + ClientDatagramCallback handler, + runnable callback + ) { + return SocketUdpConnection::connect( + id, address, port, std::move(handler), std::move(callback) + ); + } + + std::shared_ptr open_udp_server( + u64id_t id, + Network* network, + int port, + const ServerDatagramCallback& handler + ) { + return SocketUdpServer::openServer(id, network, port, handler); + } +} diff --git a/src/network/commons.hpp b/src/network/commons.hpp new file mode 100644 index 00000000..83881c52 --- /dev/null +++ b/src/network/commons.hpp @@ -0,0 +1,91 @@ +#pragma once + +#include "typedefs.hpp" +#include "settings.hpp" +#include "util/Buffer.hpp" +#include "delegates.hpp" + +#include +#include +#include + +namespace network { + using OnResponse = std::function)>; + using OnReject = std::function)>; + using ConnectCallback = std::function; + using ServerDatagramCallback = std::function; + using ClientDatagramCallback = std::function; + + class Requests { + public: + virtual ~Requests() {} + + virtual void get( + const std::string& url, + OnResponse onResponse, + OnReject onReject=nullptr, + std::vector headers = {}, + long maxSize=0 + ) = 0; + + virtual void post( + const std::string& url, + const std::string& data, + OnResponse onResponse, + OnReject onReject=nullptr, + std::vector headers = {}, + long maxSize=0 + ) = 0; + + [[nodiscard]] virtual size_t getTotalUpload() const = 0; + [[nodiscard]] virtual size_t getTotalDownload() const = 0; + + virtual void update() = 0; + }; + + enum class ConnectionState { + INITIAL, CONNECTING, CONNECTED, CLOSED + }; + + enum class TransportType { + TCP, UDP + }; + + class Connection { + public: + virtual ~Connection() = default; + + virtual void close(bool discardAll=false) = 0; + + virtual int send(const char* buffer, size_t length) = 0; + + virtual size_t pullUpload() = 0; + virtual size_t pullDownload() = 0; + + bool isPrivate() const { return isprivate; } + void setPrivate(bool flag) {isprivate = flag;} + + [[nodiscard]] virtual int getPort() const = 0; + [[nodiscard]] virtual std::string getAddress() const = 0; + + [[nodiscard]] virtual ConnectionState getState() const = 0; + + [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0; + protected: + bool isprivate = false; + }; + + class Server { + public: + virtual ~Server() = default; + virtual void close() = 0; + virtual bool isOpen() = 0; + [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0; + [[nodiscard]] virtual int getPort() const = 0; + + bool isPrivate() const { return isprivate; } + void setPrivate(bool flag) { isprivate = flag; } + protected: + bool isprivate = false; + }; +}