complete simple connection implementation

This commit is contained in:
MihailRis 2024-11-27 12:10:59 +03:00
parent 3933baccd2
commit fb0f4bff52
4 changed files with 164 additions and 59 deletions

View File

@ -39,13 +39,67 @@ static int l_get_binary(lua::State* L) {
static int l_connect(lua::State* L) {
std::string address = lua::require_string(L, 1);
int port = lua::tointeger(L, 2);
u64id_t id = engine->getNetwork().connect(address, port);
lua::pushvalue(L, 3);
auto callback = lua::create_lambda(L);
u64id_t id = engine->getNetwork().connect(address, port, [callback](u64id_t id) {
callback({id});
});
return lua::pushinteger(L, id);
}
static int l_send(lua::State* L) {
u64id_t id = lua::tointeger(L, 1);
auto connection = engine->getNetwork().getConnection(id);
if (lua::istable(L, 2)) {
lua::pushvalue(L, 2);
size_t size = lua::objlen(L, 2);
util::Buffer<char> buffer(size);
for (size_t i = 0; i < size; i++) {
lua::rawgeti(L, i + 1);
buffer[i] = lua::tointeger(L, -1);
lua::pop(L);
}
lua::pop(L);
connection->send(buffer.data(), size);
} else if (auto bytes = lua::touserdata<lua::LuaBytearray>(L, 2)) {
connection->send(
reinterpret_cast<char*>(bytes->data().data()), bytes->data().size()
);
}
return 0;
}
static int l_recv(lua::State* L) {
u64id_t id = lua::tointeger(L, 1);
int length = lua::tointeger(L, 2);
auto connection = engine->getNetwork().getConnection(id);
util::Buffer<char> buffer(glm::min(length, connection->available()));
int size = connection->recv(buffer.data(), length);
if (size == -1) {
return 0;
}
if (lua::toboolean(L, 3)) {
lua::createtable(L, size, 0);
for (size_t i = 0; i < size; i++) {
lua::pushinteger(L, buffer[i] & 0xFF);
lua::rawseti(L, i+1);
}
} else {
lua::newuserdata<lua::LuaBytearray>(L, size);
auto bytearray = lua::touserdata<lua::LuaBytearray>(L, -1);
bytearray->data().reserve(size);
std::memcpy(bytearray->data().data(), buffer.data(), size);
}
return 1;
}
const luaL_Reg networklib[] = {
{"get", lua::wrap<l_get>},
{"get_binary", lua::wrap<l_get_binary>},
{"__connect", lua::wrap<l_connect>},
{"__send", lua::wrap<l_send>},
{"__recv", lua::wrap<l_recv>},
{NULL, NULL}
};

View File

@ -7,6 +7,8 @@
#include <stdexcept>
#include <limits>
#include <queue>
#include <mutex>
#include <thread>
#ifdef _WIN32
/// included in curl.h
@ -24,9 +26,6 @@
using SOCKET = int;
#endif // _WIN32
#include <chrono>
#include <thread>
#include "debug/Logger.hpp"
#include "util/stringutil.hpp"
@ -262,36 +261,87 @@ static std::string to_string(const addrinfo* addr) {
return "";
}
class SocketImpl : public Socket {
class SocketConnection : public Connection {
SOCKET descriptor;
bool open = true;
addrinfo* addr;
size_t totalUpload = 0;
size_t totalDownload = 0;
public:
SocketImpl(SOCKET descriptor, addrinfo* addr)
: descriptor(descriptor), addr(addr) {
}
ConnectionState state = ConnectionState::INITIAL;
std::unique_ptr<std::thread> thread = nullptr;
std::vector<char> readBatch;
util::Buffer<char> buffer;
std::mutex mutex;
~SocketImpl() {
closesocket(descriptor);
void connectSocket() {
state = ConnectionState::CONNECTING;
logger.info() << "connecting to " << to_string(addr);
int res = connectsocket(descriptor, addr->ai_addr, addr->ai_addrlen);
if (res < 0) {
auto error = handle_socket_error("Connect failed");
closesocket(descriptor);
freeaddrinfo(addr);
state = ConnectionState::CLOSED;
throw error;
}
logger.info() << "connected to " << to_string(addr);
state = ConnectionState::CONNECTED;
}
public:
SocketConnection(SOCKET descriptor, addrinfo* addr)
: descriptor(descriptor), addr(addr), buffer(16'384) {}
~SocketConnection() {
if (state != ConnectionState::CLOSED) {
shutdown(descriptor, 2);
closesocket(descriptor);
}
thread->join();
freeaddrinfo(addr);
}
void connect(runnable callback) override {
thread = std::make_unique<std::thread>([this, callback]() {
connectSocket();
callback();
while (state == ConnectionState::CONNECTED) {
int size = recvsocket(descriptor, buffer.data(), buffer.size());
if (size == 0) {
logger.info() << "closed connection " << to_string(addr);
closesocket(descriptor);
state = ConnectionState::CLOSED;
break;
} else if (size < 0) {
logger.info() << "an error ocurred while receiving from "
<< to_string(addr);
auto error = handle_socket_error("recv(...) error");
closesocket(descriptor);
state = ConnectionState::CLOSED;
logger.error() << error.what();
break;
}
{
std::lock_guard lock(mutex);
for (size_t i = 0; i < size; i++) {
readBatch.emplace_back(buffer[i]);
}
totalDownload += size;
}
logger.info() << "read " << size << " bytes from " << to_string(addr);
}
});
}
int recv(char* buffer, size_t length) override {
int len = recvsocket(descriptor, buffer, length);
if (len == 0) {
int err = errno;
close();
throw std::runtime_error(
"Read failed [errno=" + std::to_string(err) +
"]: " + std::string(strerror(err))
);
} else if (len == -1) {
return 0;
std::lock_guard lock(mutex);
if (state != ConnectionState::CONNECTED && readBatch.empty()) {
return -1;
}
totalDownload += len;
return len;
int size = std::min(readBatch.size(), length);
std::memcpy(buffer, readBatch.data(), size);
readBatch.erase(readBatch.begin(), readBatch.begin() + size);
return size;
}
int send(const char* buffer, size_t length) override {
@ -308,13 +358,17 @@ public:
return len;
}
void close() override {
closesocket(descriptor);
open = false;
int available() override {
std::lock_guard lock(mutex);
return readBatch.size();
}
bool isOpen() const override {
return open;
void close() override {
if (state != ConnectionState::CLOSED) {
shutdown(descriptor, 2);
closesocket(descriptor);
}
thread->join();
}
size_t getTotalUpload() const override {
@ -325,8 +379,8 @@ public:
return totalDownload;
}
static std::shared_ptr<SocketImpl> connect(
const std::string& address, int port
static std::shared_ptr<SocketConnection> connect(
const std::string& address, int port, runnable callback
) {
addrinfo hints {};
@ -346,21 +400,18 @@ public:
freeaddrinfo(addrinfo);
throw std::runtime_error("Could not create socket");
}
int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen);
if (res < 0) {
auto error = handle_socket_error("Connect failed");
closesocket(descriptor);
freeaddrinfo(addrinfo);
throw error;
}
logger.info() << "connected to " << address << " ["
<< to_string(addrinfo) << ":" << port << "]";
return std::make_shared<SocketImpl>(descriptor, addrinfo);
auto socket = std::make_shared<SocketConnection>(descriptor, addrinfo);
socket->connect(std::move(callback));
return socket;
}
ConnectionState getState() const override {
return state;
}
};
Network::Network(std::unique_ptr<Requests> requests)
: requests(std::move(requests)) {
: requests(std::move(requests)) {
}
Network::~Network() = default;
@ -374,7 +425,7 @@ void Network::get(
requests->get(url, onResponse, onReject, maxSize);
}
Socket* Network::getConnection(u64id_t id) const {
Connection* Network::getConnection(u64id_t id) const {
const auto& found = connections.find(id);
if (found == connections.end()) {
return nullptr;
@ -382,9 +433,11 @@ Socket* Network::getConnection(u64id_t id) const {
return found->second.get();
}
u64id_t Network::connect(const std::string& address, int port) {
auto socket = SocketImpl::connect(address, port);
u64id_t Network::connect(const std::string& address, int port, consumer<u64id_t> callback) {
u64id_t id = nextConnection++;
auto socket = SocketConnection::connect(address, port, [id, callback]() {
callback(id);
});
connections[id] = std::move(socket);
return id;
}

View File

@ -6,6 +6,7 @@
#include "typedefs.hpp"
#include "settings.hpp"
#include "util/Buffer.hpp"
#include "delegates.hpp"
namespace network {
using OnResponse = std::function<void(std::vector<char>)>;
@ -27,20 +28,27 @@ namespace network {
virtual void update() = 0;
};
class Socket {
enum class ConnectionState {
INITIAL, CONNECTING, CONNECTED, CLOSED
};
class Connection {
public:
virtual void connect(runnable callback) = 0;
virtual int recv(char* buffer, size_t length) = 0;
virtual int send(const char* buffer, size_t length) = 0;
virtual void close() = 0;
virtual bool isOpen() const = 0;
virtual int available() = 0;
virtual size_t getTotalUpload() const = 0;
virtual size_t getTotalDownload() const = 0;
virtual ConnectionState getState() const = 0;
};
class Network {
std::unique_ptr<Requests> requests;
std::unordered_map<u64id_t, std::shared_ptr<Socket>> connections;
std::unordered_map<u64id_t, std::shared_ptr<Connection>> connections;
u64id_t nextConnection = 1;
public:
Network(std::unique_ptr<Requests> requests);
@ -53,9 +61,9 @@ namespace network {
long maxSize=0
);
Socket* getConnection(u64id_t id) const;
[[nodiscard]] Connection* getConnection(u64id_t id) const;
u64id_t connect(const std::string& address, int port);
u64id_t connect(const std::string& address, int port, consumer<u64id_t> callback);
size_t getTotalUpload() const;
size_t getTotalDownload() const;

View File

@ -18,16 +18,6 @@ TEST(curltest, curltest) {
std::cout << value << std::endl;
}, [](auto){}
);
if (true) {
auto socket = network->getConnection(network->connect("google.com", 80));
const char* string = "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
socket->send(string, strlen(string));
char data[1024];
int len = socket->recv(data, 1024);
std::cout << len << " " << std::string(data, len) << std::endl;
}
std::cout << "upload: " << network->getTotalUpload() << " B" << std::endl;
std::cout << "download: " << network->getTotalDownload() << " B" << std::endl;
}