Event loop for connections implemented + updated README
This commit is contained in:
parent
c0228b0c12
commit
e8bb171df2
@ -0,0 +1,6 @@
|
||||
# (not) a Redis
|
||||
|
||||
A toy implementation of a Redis-like server in C++.
|
||||
This is a learning project for me to understand how Redis works internally. It is not intended for production use.
|
||||
|
||||
Thanks to [Build-your-own-X](https://github.com/codecrafters-io/build-your-own-x) repository for helping me get started with this project!
|
||||
@ -1,8 +1,26 @@
|
||||
#include "net_util.hpp"
|
||||
#include <assert.h>
|
||||
#include <string>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
|
||||
|
||||
void fd_set_nb(int fd) {
|
||||
errno = 0;
|
||||
int flags = fcntl(fd, F_GETFL, 0);
|
||||
if (errno) {
|
||||
// handle error
|
||||
return;
|
||||
}
|
||||
flags |= O_NONBLOCK;
|
||||
|
||||
errno = 0;
|
||||
(void)fcntl(fd, F_SETFL, flags);
|
||||
if (errno) {
|
||||
// handle error
|
||||
}
|
||||
}
|
||||
|
||||
int32_t read_full(int connectionfd, void* buf, size_t len) {
|
||||
auto buffer = static_cast<char*>(buf);
|
||||
while (len > 0) {
|
||||
|
||||
@ -5,3 +5,4 @@
|
||||
|
||||
int32_t read_full(int connectionfd, void* buf, size_t len);
|
||||
int32_t write_all(int connectionfd, const void* buf, size_t len);
|
||||
void fd_set_nb(int fd);
|
||||
|
||||
238
src/server.cpp
238
src/server.cpp
@ -9,6 +9,8 @@
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/ip.h>
|
||||
#include <iostream>
|
||||
#include <fcntl.h>
|
||||
#include <poll.h>
|
||||
|
||||
#include "server.hpp"
|
||||
#include "net_util.hpp"
|
||||
@ -25,6 +27,9 @@ Server::~Server() {
|
||||
if (sockfd != -1) {
|
||||
close(sockfd);
|
||||
}
|
||||
for (auto const& [fd, conn] : fd2conn) {
|
||||
delete conn;
|
||||
}
|
||||
}
|
||||
|
||||
void Server::setup() {
|
||||
@ -51,71 +56,206 @@ void Server::setup() {
|
||||
if (listen(sockfd, SOMAXCONN) < 0) {
|
||||
throw std::runtime_error("Error listening for connections");
|
||||
}
|
||||
|
||||
fd_set_nb(sockfd);
|
||||
|
||||
Logger::log_info("Server started on " + address + ":" + std::to_string(port));
|
||||
}
|
||||
|
||||
int32_t Server::handle_connection(int connectionfd) {
|
||||
Logger::log_debug("Waiting for data from fd=" + std::to_string(connectionfd));
|
||||
|
||||
uint32_t len_net = 0;
|
||||
errno = 0;
|
||||
if (read_full(connectionfd, &len_net, sizeof(len_net)) != 0) {
|
||||
if (errno == 0) {
|
||||
Logger::log_info("Client disconnected (EOF) from fd=" + std::to_string(connectionfd));
|
||||
} else {
|
||||
Logger::log_error("Error reading length from fd=" + std::to_string(connectionfd) + ": " + std::strerror(errno));
|
||||
}
|
||||
return -1;
|
||||
void Server::accept_new_connection() {
|
||||
struct sockaddr_in client_addr = {};
|
||||
socklen_t addrlen = sizeof(client_addr);
|
||||
int connectionfd = accept(sockfd, (struct sockaddr *)&client_addr, &addrlen);
|
||||
if (connectionfd < 0) {
|
||||
Logger::log_error("accept() error: " + std::string(std::strerror(errno)));
|
||||
return;
|
||||
}
|
||||
|
||||
fd_set_nb(connectionfd);
|
||||
|
||||
Connection* conn = new Connection();
|
||||
conn->connectionfd = connectionfd;
|
||||
conn->state = STATE_REQ;
|
||||
|
||||
fd2conn[connectionfd] = conn;
|
||||
|
||||
struct pollfd pfd = {};
|
||||
pfd.fd = connectionfd;
|
||||
pfd.events = POLLIN;
|
||||
pfd.revents = 0;
|
||||
poll_args.push_back(pfd);
|
||||
|
||||
Logger::log_info("Accepted connection from " + std::string(inet_ntoa(client_addr.sin_addr)) + ":" + std::to_string(ntohs(client_addr.sin_port)) + " (fd=" + std::to_string(connectionfd) + ")");
|
||||
}
|
||||
|
||||
void Server::state_req(Connection* conn) {
|
||||
while (true) {
|
||||
char buf[1024];
|
||||
ssize_t rv = read(conn->connectionfd, buf, sizeof(buf));
|
||||
if (rv < 0) {
|
||||
if (errno == EAGAIN) {
|
||||
break;
|
||||
}
|
||||
if (errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
conn->state = STATE_END;
|
||||
return;
|
||||
}
|
||||
if (rv == 0) {
|
||||
conn->state = STATE_END;
|
||||
return;
|
||||
}
|
||||
|
||||
conn->incoming.insert(conn->incoming.end(), buf, buf + rv);
|
||||
}
|
||||
|
||||
while (conn->state == STATE_REQ) {
|
||||
if (conn->incoming.size() < 4) {
|
||||
break;
|
||||
}
|
||||
|
||||
uint32_t len_net;
|
||||
memcpy(&len_net, conn->incoming.data(), 4);
|
||||
uint32_t len = ntohl(len_net);
|
||||
|
||||
if (len > k_max_message_size) {
|
||||
Logger::log_error("Message too long");
|
||||
conn->state = STATE_END;
|
||||
return;
|
||||
}
|
||||
|
||||
if (conn->incoming.size() < 4 + len) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (parse_and_execute(conn) != 0) {
|
||||
conn->state = STATE_END;
|
||||
return;
|
||||
}
|
||||
|
||||
conn->incoming.erase(conn->incoming.begin(), conn->incoming.begin() + 4 + len);
|
||||
}
|
||||
}
|
||||
|
||||
void Server::state_res(Connection* conn) {
|
||||
while (!conn->outgoing.empty()) {
|
||||
ssize_t rv = write(conn->connectionfd, conn->outgoing.data(), conn->outgoing.size());
|
||||
if (rv < 0) {
|
||||
if (errno == EAGAIN) {
|
||||
break;
|
||||
}
|
||||
if (errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
conn->state = STATE_END;
|
||||
return;
|
||||
}
|
||||
|
||||
conn->outgoing.erase(conn->outgoing.begin(), conn->outgoing.begin() + rv);
|
||||
}
|
||||
|
||||
if (conn->outgoing.empty()) {
|
||||
conn->state = STATE_REQ;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t Server::parse_and_execute(Connection* conn) {
|
||||
uint32_t len_net;
|
||||
memcpy(&len_net, conn->incoming.data(), 4);
|
||||
uint32_t len = ntohl(len_net);
|
||||
|
||||
if (len > k_max_message_size) {
|
||||
Logger::log_error("Message length is too long: " + std::to_string(len) + " (max: " + std::to_string(k_max_message_size) + ")");
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string message(len, '\0');
|
||||
if (len > 0) {
|
||||
if (read_full(connectionfd, message.data(), len) != 0) {
|
||||
Logger::log_error("Error reading payload from fd=" + std::to_string(connectionfd));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
Logger::log_info("Client (fd=" + std::to_string(connectionfd) + ") says: " + message);
|
||||
|
||||
|
||||
std::string message(conn->incoming.begin() + 4, conn->incoming.begin() + 4 + len);
|
||||
Logger::log_info("Client (fd=" + std::to_string(conn->connectionfd) + ") says: " + message);
|
||||
|
||||
const std::string reply = "world";
|
||||
uint32_t reply_len_net = htonl(static_cast<uint32_t>(reply.size()));
|
||||
|
||||
if (write_all(connectionfd, &reply_len_net, sizeof(reply_len_net)) != 0) return -1;
|
||||
if (write_all(connectionfd, reply.data(), reply.size()) != 0) return -1;
|
||||
|
||||
uint32_t reply_len = static_cast<uint32_t>(reply.size());
|
||||
uint32_t reply_len_net = htonl(reply_len);
|
||||
|
||||
const char* header_ptr = reinterpret_cast<const char*>(&reply_len_net);
|
||||
conn->outgoing.insert(conn->outgoing.end(), header_ptr, header_ptr + 4);
|
||||
|
||||
conn->outgoing.insert(conn->outgoing.end(), reply.begin(), reply.end());
|
||||
|
||||
conn->state = STATE_RES;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Server::connection_io(Connection* conn) {
|
||||
if (conn->state == STATE_REQ) {
|
||||
state_req(conn);
|
||||
} else if (conn->state == STATE_RES) {
|
||||
state_res(conn);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Server::run() {
|
||||
setup();
|
||||
|
||||
struct pollfd pfd_listen = {sockfd, POLLIN, 0};
|
||||
poll_args.push_back(pfd_listen);
|
||||
|
||||
while (true) {
|
||||
struct sockaddr_in client_addr = {};
|
||||
socklen_t addrlen = sizeof(client_addr);
|
||||
int connectionfd = accept(sockfd, (struct sockaddr *)&client_addr, &addrlen);
|
||||
if (connectionfd < 0) {
|
||||
Logger::log_error("accept() error: " + std::string(std::strerror(errno)));
|
||||
continue;
|
||||
}
|
||||
|
||||
Logger::log_info("Accepted connection from " + std::string(inet_ntoa(client_addr.sin_addr)) + ":" + std::to_string(ntohs(client_addr.sin_port)) + " (fd=" + std::to_string(connectionfd) + ")");
|
||||
|
||||
while (true) {
|
||||
int32_t err = handle_connection(connectionfd);
|
||||
if (err) {
|
||||
break;
|
||||
// Prepare events used in poll
|
||||
for (auto& pfd : poll_args) {
|
||||
if (pfd.fd == sockfd) {
|
||||
pfd.events = POLLIN;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = fd2conn.find(pfd.fd);
|
||||
if (it != fd2conn.end()) {
|
||||
Connection* conn = it->second;
|
||||
if (conn->state == STATE_REQ) {
|
||||
pfd.events = POLLIN;
|
||||
} else if (conn->state == STATE_RES) {
|
||||
pfd.events = POLLOUT;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int rv = poll(poll_args.data(), (nfds_t)poll_args.size(), 1000);
|
||||
if (rv < 0) {
|
||||
Logger::log_error("poll error");
|
||||
break;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < poll_args.size(); ++i) {
|
||||
if (poll_args[i].revents) {
|
||||
int fd = poll_args[i].fd;
|
||||
if (fd == sockfd) {
|
||||
accept_new_connection();
|
||||
} else {
|
||||
auto it = fd2conn.find(fd);
|
||||
if (it != fd2conn.end()) {
|
||||
connection_io(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < poll_args.size(); ) {
|
||||
int fd = poll_args[i].fd;
|
||||
if (fd == sockfd) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = fd2conn.find(fd);
|
||||
if (it == fd2conn.end() || it->second->state == STATE_END) {
|
||||
if (it != fd2conn.end()) {
|
||||
close(fd);
|
||||
delete it->second;
|
||||
fd2conn.erase(it);
|
||||
Logger::log_info("Closing connection fd=" + std::to_string(fd));
|
||||
}
|
||||
|
||||
poll_args[i] = poll_args.back();
|
||||
poll_args.pop_back();
|
||||
} else {
|
||||
++i;
|
||||
}
|
||||
}
|
||||
Logger::log_info("Closing connection from " + std::to_string(connectionfd));
|
||||
close(connectionfd);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,6 +3,22 @@
|
||||
#include <string>
|
||||
#include <stdint.h>
|
||||
#include "config.hpp"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <poll.h>
|
||||
|
||||
enum {
|
||||
STATE_REQ = 0,
|
||||
STATE_RES = 1,
|
||||
STATE_END = 2,
|
||||
};
|
||||
|
||||
struct Connection {
|
||||
int connectionfd = -1;
|
||||
uint32_t state = STATE_REQ;
|
||||
std::vector<uint8_t> incoming;
|
||||
std::vector<uint8_t> outgoing;
|
||||
};
|
||||
|
||||
class Server {
|
||||
public:
|
||||
@ -16,6 +32,13 @@ private:
|
||||
std::string address;
|
||||
int port;
|
||||
|
||||
int32_t handle_connection(int connectionfd);
|
||||
std::vector<struct pollfd> poll_args;
|
||||
std::map<int, Connection*> fd2conn;
|
||||
|
||||
void setup();
|
||||
void accept_new_connection();
|
||||
void connection_io(Connection* conn);
|
||||
void state_req(Connection* conn);
|
||||
void state_res(Connection* conn);
|
||||
int32_t parse_and_execute(Connection* conn);
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user