complete simple connection implementation
This commit is contained in:
parent
3933baccd2
commit
fb0f4bff52
@ -39,13 +39,67 @@ static int l_get_binary(lua::State* L) {
|
||||
static int l_connect(lua::State* L) {
|
||||
std::string address = lua::require_string(L, 1);
|
||||
int port = lua::tointeger(L, 2);
|
||||
u64id_t id = engine->getNetwork().connect(address, port);
|
||||
lua::pushvalue(L, 3);
|
||||
auto callback = lua::create_lambda(L);
|
||||
u64id_t id = engine->getNetwork().connect(address, port, [callback](u64id_t id) {
|
||||
callback({id});
|
||||
});
|
||||
return lua::pushinteger(L, id);
|
||||
}
|
||||
|
||||
static int l_send(lua::State* L) {
|
||||
u64id_t id = lua::tointeger(L, 1);
|
||||
auto connection = engine->getNetwork().getConnection(id);
|
||||
|
||||
if (lua::istable(L, 2)) {
|
||||
lua::pushvalue(L, 2);
|
||||
size_t size = lua::objlen(L, 2);
|
||||
util::Buffer<char> buffer(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
lua::rawgeti(L, i + 1);
|
||||
buffer[i] = lua::tointeger(L, -1);
|
||||
lua::pop(L);
|
||||
}
|
||||
lua::pop(L);
|
||||
connection->send(buffer.data(), size);
|
||||
} else if (auto bytes = lua::touserdata<lua::LuaBytearray>(L, 2)) {
|
||||
connection->send(
|
||||
reinterpret_cast<char*>(bytes->data().data()), bytes->data().size()
|
||||
);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int l_recv(lua::State* L) {
|
||||
u64id_t id = lua::tointeger(L, 1);
|
||||
int length = lua::tointeger(L, 2);
|
||||
auto connection = engine->getNetwork().getConnection(id);
|
||||
util::Buffer<char> buffer(glm::min(length, connection->available()));
|
||||
|
||||
int size = connection->recv(buffer.data(), length);
|
||||
if (size == -1) {
|
||||
return 0;
|
||||
}
|
||||
if (lua::toboolean(L, 3)) {
|
||||
lua::createtable(L, size, 0);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
lua::pushinteger(L, buffer[i] & 0xFF);
|
||||
lua::rawseti(L, i+1);
|
||||
}
|
||||
} else {
|
||||
lua::newuserdata<lua::LuaBytearray>(L, size);
|
||||
auto bytearray = lua::touserdata<lua::LuaBytearray>(L, -1);
|
||||
bytearray->data().reserve(size);
|
||||
std::memcpy(bytearray->data().data(), buffer.data(), size);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
const luaL_Reg networklib[] = {
|
||||
{"get", lua::wrap<l_get>},
|
||||
{"get_binary", lua::wrap<l_get_binary>},
|
||||
{"__connect", lua::wrap<l_connect>},
|
||||
{"__send", lua::wrap<l_send>},
|
||||
{"__recv", lua::wrap<l_recv>},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
@ -7,6 +7,8 @@
|
||||
#include <stdexcept>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#ifdef _WIN32
|
||||
/// included in curl.h
|
||||
@ -24,9 +26,6 @@
|
||||
using SOCKET = int;
|
||||
#endif // _WIN32
|
||||
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
#include "debug/Logger.hpp"
|
||||
#include "util/stringutil.hpp"
|
||||
|
||||
@ -262,36 +261,87 @@ static std::string to_string(const addrinfo* addr) {
|
||||
return "";
|
||||
}
|
||||
|
||||
class SocketImpl : public Socket {
|
||||
class SocketConnection : public Connection {
|
||||
SOCKET descriptor;
|
||||
bool open = true;
|
||||
addrinfo* addr;
|
||||
size_t totalUpload = 0;
|
||||
size_t totalDownload = 0;
|
||||
public:
|
||||
SocketImpl(SOCKET descriptor, addrinfo* addr)
|
||||
: descriptor(descriptor), addr(addr) {
|
||||
}
|
||||
ConnectionState state = ConnectionState::INITIAL;
|
||||
std::unique_ptr<std::thread> thread = nullptr;
|
||||
std::vector<char> readBatch;
|
||||
util::Buffer<char> buffer;
|
||||
std::mutex mutex;
|
||||
|
||||
~SocketImpl() {
|
||||
closesocket(descriptor);
|
||||
void connectSocket() {
|
||||
state = ConnectionState::CONNECTING;
|
||||
logger.info() << "connecting to " << to_string(addr);
|
||||
int res = connectsocket(descriptor, addr->ai_addr, addr->ai_addrlen);
|
||||
if (res < 0) {
|
||||
auto error = handle_socket_error("Connect failed");
|
||||
closesocket(descriptor);
|
||||
freeaddrinfo(addr);
|
||||
state = ConnectionState::CLOSED;
|
||||
throw error;
|
||||
}
|
||||
logger.info() << "connected to " << to_string(addr);
|
||||
state = ConnectionState::CONNECTED;
|
||||
}
|
||||
public:
|
||||
SocketConnection(SOCKET descriptor, addrinfo* addr)
|
||||
: descriptor(descriptor), addr(addr), buffer(16'384) {}
|
||||
|
||||
~SocketConnection() {
|
||||
if (state != ConnectionState::CLOSED) {
|
||||
shutdown(descriptor, 2);
|
||||
closesocket(descriptor);
|
||||
}
|
||||
thread->join();
|
||||
freeaddrinfo(addr);
|
||||
}
|
||||
|
||||
void connect(runnable callback) override {
|
||||
thread = std::make_unique<std::thread>([this, callback]() {
|
||||
connectSocket();
|
||||
callback();
|
||||
while (state == ConnectionState::CONNECTED) {
|
||||
int size = recvsocket(descriptor, buffer.data(), buffer.size());
|
||||
if (size == 0) {
|
||||
logger.info() << "closed connection " << to_string(addr);
|
||||
closesocket(descriptor);
|
||||
state = ConnectionState::CLOSED;
|
||||
break;
|
||||
} else if (size < 0) {
|
||||
logger.info() << "an error ocurred while receiving from "
|
||||
<< to_string(addr);
|
||||
auto error = handle_socket_error("recv(...) error");
|
||||
closesocket(descriptor);
|
||||
state = ConnectionState::CLOSED;
|
||||
logger.error() << error.what();
|
||||
break;
|
||||
}
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
readBatch.emplace_back(buffer[i]);
|
||||
}
|
||||
totalDownload += size;
|
||||
}
|
||||
logger.info() << "read " << size << " bytes from " << to_string(addr);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
int recv(char* buffer, size_t length) override {
|
||||
int len = recvsocket(descriptor, buffer, length);
|
||||
if (len == 0) {
|
||||
int err = errno;
|
||||
close();
|
||||
throw std::runtime_error(
|
||||
"Read failed [errno=" + std::to_string(err) +
|
||||
"]: " + std::string(strerror(err))
|
||||
);
|
||||
} else if (len == -1) {
|
||||
return 0;
|
||||
std::lock_guard lock(mutex);
|
||||
|
||||
if (state != ConnectionState::CONNECTED && readBatch.empty()) {
|
||||
return -1;
|
||||
}
|
||||
totalDownload += len;
|
||||
return len;
|
||||
int size = std::min(readBatch.size(), length);
|
||||
std::memcpy(buffer, readBatch.data(), size);
|
||||
readBatch.erase(readBatch.begin(), readBatch.begin() + size);
|
||||
return size;
|
||||
}
|
||||
|
||||
int send(const char* buffer, size_t length) override {
|
||||
@ -308,13 +358,17 @@ public:
|
||||
return len;
|
||||
}
|
||||
|
||||
void close() override {
|
||||
closesocket(descriptor);
|
||||
open = false;
|
||||
int available() override {
|
||||
std::lock_guard lock(mutex);
|
||||
return readBatch.size();
|
||||
}
|
||||
|
||||
bool isOpen() const override {
|
||||
return open;
|
||||
void close() override {
|
||||
if (state != ConnectionState::CLOSED) {
|
||||
shutdown(descriptor, 2);
|
||||
closesocket(descriptor);
|
||||
}
|
||||
thread->join();
|
||||
}
|
||||
|
||||
size_t getTotalUpload() const override {
|
||||
@ -325,8 +379,8 @@ public:
|
||||
return totalDownload;
|
||||
}
|
||||
|
||||
static std::shared_ptr<SocketImpl> connect(
|
||||
const std::string& address, int port
|
||||
static std::shared_ptr<SocketConnection> connect(
|
||||
const std::string& address, int port, runnable callback
|
||||
) {
|
||||
addrinfo hints {};
|
||||
|
||||
@ -346,21 +400,18 @@ public:
|
||||
freeaddrinfo(addrinfo);
|
||||
throw std::runtime_error("Could not create socket");
|
||||
}
|
||||
int res = connectsocket(descriptor, addrinfo->ai_addr, addrinfo->ai_addrlen);
|
||||
if (res < 0) {
|
||||
auto error = handle_socket_error("Connect failed");
|
||||
closesocket(descriptor);
|
||||
freeaddrinfo(addrinfo);
|
||||
throw error;
|
||||
}
|
||||
logger.info() << "connected to " << address << " ["
|
||||
<< to_string(addrinfo) << ":" << port << "]";
|
||||
return std::make_shared<SocketImpl>(descriptor, addrinfo);
|
||||
auto socket = std::make_shared<SocketConnection>(descriptor, addrinfo);
|
||||
socket->connect(std::move(callback));
|
||||
return socket;
|
||||
}
|
||||
|
||||
ConnectionState getState() const override {
|
||||
return state;
|
||||
}
|
||||
};
|
||||
|
||||
Network::Network(std::unique_ptr<Requests> requests)
|
||||
: requests(std::move(requests)) {
|
||||
: requests(std::move(requests)) {
|
||||
}
|
||||
|
||||
Network::~Network() = default;
|
||||
@ -374,7 +425,7 @@ void Network::get(
|
||||
requests->get(url, onResponse, onReject, maxSize);
|
||||
}
|
||||
|
||||
Socket* Network::getConnection(u64id_t id) const {
|
||||
Connection* Network::getConnection(u64id_t id) const {
|
||||
const auto& found = connections.find(id);
|
||||
if (found == connections.end()) {
|
||||
return nullptr;
|
||||
@ -382,9 +433,11 @@ Socket* Network::getConnection(u64id_t id) const {
|
||||
return found->second.get();
|
||||
}
|
||||
|
||||
u64id_t Network::connect(const std::string& address, int port) {
|
||||
auto socket = SocketImpl::connect(address, port);
|
||||
u64id_t Network::connect(const std::string& address, int port, consumer<u64id_t> callback) {
|
||||
u64id_t id = nextConnection++;
|
||||
auto socket = SocketConnection::connect(address, port, [id, callback]() {
|
||||
callback(id);
|
||||
});
|
||||
connections[id] = std::move(socket);
|
||||
return id;
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include "typedefs.hpp"
|
||||
#include "settings.hpp"
|
||||
#include "util/Buffer.hpp"
|
||||
#include "delegates.hpp"
|
||||
|
||||
namespace network {
|
||||
using OnResponse = std::function<void(std::vector<char>)>;
|
||||
@ -27,20 +28,27 @@ namespace network {
|
||||
virtual void update() = 0;
|
||||
};
|
||||
|
||||
class Socket {
|
||||
enum class ConnectionState {
|
||||
INITIAL, CONNECTING, CONNECTED, CLOSED
|
||||
};
|
||||
|
||||
class Connection {
|
||||
public:
|
||||
virtual void connect(runnable callback) = 0;
|
||||
virtual int recv(char* buffer, size_t length) = 0;
|
||||
virtual int send(const char* buffer, size_t length) = 0;
|
||||
virtual void close() = 0;
|
||||
virtual bool isOpen() const = 0;
|
||||
virtual int available() = 0;
|
||||
|
||||
virtual size_t getTotalUpload() const = 0;
|
||||
virtual size_t getTotalDownload() const = 0;
|
||||
|
||||
virtual ConnectionState getState() const = 0;
|
||||
};
|
||||
|
||||
class Network {
|
||||
std::unique_ptr<Requests> requests;
|
||||
std::unordered_map<u64id_t, std::shared_ptr<Socket>> connections;
|
||||
std::unordered_map<u64id_t, std::shared_ptr<Connection>> connections;
|
||||
u64id_t nextConnection = 1;
|
||||
public:
|
||||
Network(std::unique_ptr<Requests> requests);
|
||||
@ -53,9 +61,9 @@ namespace network {
|
||||
long maxSize=0
|
||||
);
|
||||
|
||||
Socket* getConnection(u64id_t id) const;
|
||||
[[nodiscard]] Connection* getConnection(u64id_t id) const;
|
||||
|
||||
u64id_t connect(const std::string& address, int port);
|
||||
u64id_t connect(const std::string& address, int port, consumer<u64id_t> callback);
|
||||
|
||||
size_t getTotalUpload() const;
|
||||
size_t getTotalDownload() const;
|
||||
|
||||
@ -18,16 +18,6 @@ TEST(curltest, curltest) {
|
||||
std::cout << value << std::endl;
|
||||
}, [](auto){}
|
||||
);
|
||||
if (true) {
|
||||
auto socket = network->getConnection(network->connect("google.com", 80));
|
||||
const char* string = "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
|
||||
socket->send(string, strlen(string));
|
||||
char data[1024];
|
||||
|
||||
int len = socket->recv(data, 1024);
|
||||
std::cout << len << " " << std::string(data, len) << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "upload: " << network->getTotalUpload() << " B" << std::endl;
|
||||
std::cout << "download: " << network->getTotalDownload() << " B" << std::endl;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user