From fb0f4bff5292fd2e28dd57132d239f075b9e9afc Mon Sep 17 00:00:00 2001 From: MihailRis Date: Wed, 27 Nov 2024 12:10:59 +0300 Subject: [PATCH] complete simple connection implementation --- src/logic/scripting/lua/libs/libnetwork.cpp | 56 +++++++- src/network/Network.cpp | 139 ++++++++++++++------ src/network/Network.hpp | 18 ++- test/network/curltest.cpp | 10 -- 4 files changed, 164 insertions(+), 59 deletions(-) diff --git a/src/logic/scripting/lua/libs/libnetwork.cpp b/src/logic/scripting/lua/libs/libnetwork.cpp index 4b734a62..2245dc10 100644 --- a/src/logic/scripting/lua/libs/libnetwork.cpp +++ b/src/logic/scripting/lua/libs/libnetwork.cpp @@ -39,13 +39,67 @@ static int l_get_binary(lua::State* L) { static int l_connect(lua::State* L) { std::string address = lua::require_string(L, 1); int port = lua::tointeger(L, 2); - u64id_t id = engine->getNetwork().connect(address, port); + lua::pushvalue(L, 3); + auto callback = lua::create_lambda(L); + u64id_t id = engine->getNetwork().connect(address, port, [callback](u64id_t id) { + callback({id}); + }); return lua::pushinteger(L, id); } +static int l_send(lua::State* L) { + u64id_t id = lua::tointeger(L, 1); + auto connection = engine->getNetwork().getConnection(id); + + if (lua::istable(L, 2)) { + lua::pushvalue(L, 2); + size_t size = lua::objlen(L, 2); + util::Buffer buffer(size); + for (size_t i = 0; i < size; i++) { + lua::rawgeti(L, i + 1); + buffer[i] = lua::tointeger(L, -1); + lua::pop(L); + } + lua::pop(L); + connection->send(buffer.data(), size); + } else if (auto bytes = lua::touserdata(L, 2)) { + connection->send( + reinterpret_cast(bytes->data().data()), bytes->data().size() + ); + } + return 0; +} + +static int l_recv(lua::State* L) { + u64id_t id = lua::tointeger(L, 1); + int length = lua::tointeger(L, 2); + auto connection = engine->getNetwork().getConnection(id); + util::Buffer buffer(glm::min(length, connection->available())); + + int size = connection->recv(buffer.data(), length); + if (size == -1) { + return 0; + } + if (lua::toboolean(L, 3)) { + lua::createtable(L, size, 0); + for (size_t i = 0; i < size; i++) { + lua::pushinteger(L, buffer[i] & 0xFF); + lua::rawseti(L, i+1); + } + } else { + lua::newuserdata(L, size); + auto bytearray = lua::touserdata(L, -1); + bytearray->data().reserve(size); + std::memcpy(bytearray->data().data(), buffer.data(), size); + } + return 1; +} + const luaL_Reg networklib[] = { {"get", lua::wrap}, {"get_binary", lua::wrap}, {"__connect", lua::wrap}, + {"__send", lua::wrap}, + {"__recv", lua::wrap}, {NULL, NULL} }; diff --git a/src/network/Network.cpp b/src/network/Network.cpp index eb50d0c5..23d0ec87 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #ifdef _WIN32 /// included in curl.h @@ -24,9 +26,6 @@ using SOCKET = int; #endif // _WIN32 -#include -#include - #include "debug/Logger.hpp" #include "util/stringutil.hpp" @@ -262,36 +261,87 @@ static std::string to_string(const addrinfo* addr) { return ""; } -class SocketImpl : public Socket { +class SocketConnection : public Connection { SOCKET descriptor; bool open = true; addrinfo* addr; size_t totalUpload = 0; size_t totalDownload = 0; -public: - SocketImpl(SOCKET descriptor, addrinfo* addr) - : descriptor(descriptor), addr(addr) { - } + ConnectionState state = ConnectionState::INITIAL; + std::unique_ptr thread = nullptr; + std::vector readBatch; + util::Buffer buffer; + std::mutex mutex; - ~SocketImpl() { - closesocket(descriptor); + void connectSocket() { + state = ConnectionState::CONNECTING; + logger.info() << "connecting to " << to_string(addr); + int res = connectsocket(descriptor, addr->ai_addr, addr->ai_addrlen); + if (res < 0) { + auto error = handle_socket_error("Connect failed"); + closesocket(descriptor); + freeaddrinfo(addr); + state = ConnectionState::CLOSED; + throw error; + } + logger.info() << "connected to " << to_string(addr); + state = ConnectionState::CONNECTED; + } +public: + SocketConnection(SOCKET descriptor, addrinfo* addr) + : descriptor(descriptor), addr(addr), buffer(16'384) {} + + ~SocketConnection() { + if (state != ConnectionState::CLOSED) { + shutdown(descriptor, 2); + closesocket(descriptor); + } + thread->join(); freeaddrinfo(addr); } + void connect(runnable callback) override { + thread = std::make_unique([this, callback]() { + connectSocket(); + callback(); + while (state == ConnectionState::CONNECTED) { + int size = recvsocket(descriptor, buffer.data(), buffer.size()); + if (size == 0) { + logger.info() << "closed connection " << to_string(addr); + closesocket(descriptor); + state = ConnectionState::CLOSED; + break; + } else if (size < 0) { + logger.info() << "an error ocurred while receiving from " + << to_string(addr); + auto error = handle_socket_error("recv(...) error"); + closesocket(descriptor); + state = ConnectionState::CLOSED; + logger.error() << error.what(); + break; + } + { + std::lock_guard lock(mutex); + for (size_t i = 0; i < size; i++) { + readBatch.emplace_back(buffer[i]); + } + totalDownload += size; + } + logger.info() << "read " << size << " bytes from " << to_string(addr); + } + }); + } + int recv(char* buffer, size_t length) override { - int len = recvsocket(descriptor, buffer, length); - if (len == 0) { - int err = errno; - close(); - throw std::runtime_error( - "Read failed [errno=" + std::to_string(err) + - "]: " + std::string(strerror(err)) - ); - } else if (len == -1) { - return 0; + std::lock_guard lock(mutex); + + if (state != ConnectionState::CONNECTED && readBatch.empty()) { + return -1; } - totalDownload += len; - return len; + int size = std::min(readBatch.size(), length); + std::memcpy(buffer, readBatch.data(), size); + readBatch.erase(readBatch.begin(), readBatch.begin() + size); + return size; } int send(const char* buffer, size_t length) override { @@ -308,13 +358,17 @@ public: return len; } - void close() override { - closesocket(descriptor); - open = false; + int available() override { + std::lock_guard lock(mutex); + return readBatch.size(); } - bool isOpen() const override { - return open; + void close() override { + if (state != ConnectionState::CLOSED) { + shutdown(descriptor, 2); + closesocket(descriptor); + } + thread->join(); } size_t getTotalUpload() const override { @@ -325,8 +379,8 @@ public: return totalDownload; } - static std::shared_ptr connect( - const std::string& address, int port + static std::shared_ptr connect( + const std::string& address, int port, runnable callback ) { addrinfo hints {}; @@ -346,21 +400,18 @@ public: freeaddrinfo(addrinfo); throw std::runtime_error("Could not create socket"); } - int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen); - if (res < 0) { - auto error = handle_socket_error("Connect failed"); - closesocket(descriptor); - freeaddrinfo(addrinfo); - throw error; - } - logger.info() << "connected to " << address << " [" - << to_string(addrinfo) << ":" << port << "]"; - return std::make_shared(descriptor, addrinfo); + auto socket = std::make_shared(descriptor, addrinfo); + socket->connect(std::move(callback)); + return socket; + } + + ConnectionState getState() const override { + return state; } }; Network::Network(std::unique_ptr requests) - : requests(std::move(requests)) { +: requests(std::move(requests)) { } Network::~Network() = default; @@ -374,7 +425,7 @@ void Network::get( requests->get(url, onResponse, onReject, maxSize); } -Socket* Network::getConnection(u64id_t id) const { +Connection* Network::getConnection(u64id_t id) const { const auto& found = connections.find(id); if (found == connections.end()) { return nullptr; @@ -382,9 +433,11 @@ Socket* Network::getConnection(u64id_t id) const { return found->second.get(); } -u64id_t Network::connect(const std::string& address, int port) { - auto socket = SocketImpl::connect(address, port); +u64id_t Network::connect(const std::string& address, int port, consumer callback) { u64id_t id = nextConnection++; + auto socket = SocketConnection::connect(address, port, [id, callback]() { + callback(id); + }); connections[id] = std::move(socket); return id; } diff --git a/src/network/Network.hpp b/src/network/Network.hpp index 450b8640..f4b6ceef 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -6,6 +6,7 @@ #include "typedefs.hpp" #include "settings.hpp" #include "util/Buffer.hpp" +#include "delegates.hpp" namespace network { using OnResponse = std::function)>; @@ -27,20 +28,27 @@ namespace network { virtual void update() = 0; }; - class Socket { + enum class ConnectionState { + INITIAL, CONNECTING, CONNECTED, CLOSED + }; + + class Connection { public: + virtual void connect(runnable callback) = 0; virtual int recv(char* buffer, size_t length) = 0; virtual int send(const char* buffer, size_t length) = 0; virtual void close() = 0; - virtual bool isOpen() const = 0; + virtual int available() = 0; virtual size_t getTotalUpload() const = 0; virtual size_t getTotalDownload() const = 0; + + virtual ConnectionState getState() const = 0; }; class Network { std::unique_ptr requests; - std::unordered_map> connections; + std::unordered_map> connections; u64id_t nextConnection = 1; public: Network(std::unique_ptr requests); @@ -53,9 +61,9 @@ namespace network { long maxSize=0 ); - Socket* getConnection(u64id_t id) const; + [[nodiscard]] Connection* getConnection(u64id_t id) const; - u64id_t connect(const std::string& address, int port); + u64id_t connect(const std::string& address, int port, consumer callback); size_t getTotalUpload() const; size_t getTotalDownload() const; diff --git a/test/network/curltest.cpp b/test/network/curltest.cpp index afca9677..e365651e 100644 --- a/test/network/curltest.cpp +++ b/test/network/curltest.cpp @@ -18,16 +18,6 @@ TEST(curltest, curltest) { std::cout << value << std::endl; }, [](auto){} ); - if (true) { - auto socket = network->getConnection(network->connect("google.com", 80)); - const char* string = "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; - socket->send(string, strlen(string)); - char data[1024]; - - int len = socket->recv(data, 1024); - std::cout << len << " " << std::string(data, len) << std::endl; - } - std::cout << "upload: " << network->getTotalUpload() << " B" << std::endl; std::cout << "download: " << network->getTotalDownload() << " B" << std::endl; }