diff --git a/README.md b/README.md index e69de29..61a2114 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,6 @@ +# (not) a Redis + +A toy implementation of a Redis-like server in C++. +This is a learning project for me to understand how Redis works internally. It is not intended for production use. + +Thanks to [Build-your-own-X](https://github.com/codecrafters-io/build-your-own-x) repository for helping me get started with this project! diff --git a/src/net_util.cpp b/src/net_util.cpp index 84046c4..527e3b3 100644 --- a/src/net_util.cpp +++ b/src/net_util.cpp @@ -1,8 +1,26 @@ #include "net_util.hpp" #include #include +#include +#include +void fd_set_nb(int fd) { + errno = 0; + int flags = fcntl(fd, F_GETFL, 0); + if (errno) { + // handle error + return; + } + flags |= O_NONBLOCK; + + errno = 0; + (void)fcntl(fd, F_SETFL, flags); + if (errno) { + // handle error + } +} + int32_t read_full(int connectionfd, void* buf, size_t len) { auto buffer = static_cast(buf); while (len > 0) { diff --git a/src/net_util.hpp b/src/net_util.hpp index 36596cc..25f3f2f 100644 --- a/src/net_util.hpp +++ b/src/net_util.hpp @@ -5,3 +5,4 @@ int32_t read_full(int connectionfd, void* buf, size_t len); int32_t write_all(int connectionfd, const void* buf, size_t len); +void fd_set_nb(int fd); diff --git a/src/server.cpp b/src/server.cpp index 31e3d41..625241a 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "server.hpp" #include "net_util.hpp" @@ -25,6 +27,9 @@ Server::~Server() { if (sockfd != -1) { close(sockfd); } + for (auto const& [fd, conn] : fd2conn) { + delete conn; + } } void Server::setup() { @@ -51,71 +56,206 @@ void Server::setup() { if (listen(sockfd, SOMAXCONN) < 0) { throw std::runtime_error("Error listening for connections"); } + + fd_set_nb(sockfd); Logger::log_info("Server started on " + address + ":" + std::to_string(port)); } -int32_t Server::handle_connection(int connectionfd) { - Logger::log_debug("Waiting for data from fd=" + std::to_string(connectionfd)); - - uint32_t len_net = 0; - errno = 0; - if (read_full(connectionfd, &len_net, sizeof(len_net)) != 0) { - if (errno == 0) { - Logger::log_info("Client disconnected (EOF) from fd=" + std::to_string(connectionfd)); - } else { - Logger::log_error("Error reading length from fd=" + std::to_string(connectionfd) + ": " + std::strerror(errno)); - } - return -1; +void Server::accept_new_connection() { + struct sockaddr_in client_addr = {}; + socklen_t addrlen = sizeof(client_addr); + int connectionfd = accept(sockfd, (struct sockaddr *)&client_addr, &addrlen); + if (connectionfd < 0) { + Logger::log_error("accept() error: " + std::string(std::strerror(errno))); + return; } + fd_set_nb(connectionfd); + + Connection* conn = new Connection(); + conn->connectionfd = connectionfd; + conn->state = STATE_REQ; + + fd2conn[connectionfd] = conn; + + struct pollfd pfd = {}; + pfd.fd = connectionfd; + pfd.events = POLLIN; + pfd.revents = 0; + poll_args.push_back(pfd); + + Logger::log_info("Accepted connection from " + std::string(inet_ntoa(client_addr.sin_addr)) + ":" + std::to_string(ntohs(client_addr.sin_port)) + " (fd=" + std::to_string(connectionfd) + ")"); +} + +void Server::state_req(Connection* conn) { + while (true) { + char buf[1024]; + ssize_t rv = read(conn->connectionfd, buf, sizeof(buf)); + if (rv < 0) { + if (errno == EAGAIN) { + break; + } + if (errno == EINTR) { + continue; + } + conn->state = STATE_END; + return; + } + if (rv == 0) { + conn->state = STATE_END; + return; + } + + conn->incoming.insert(conn->incoming.end(), buf, buf + rv); + } + + while (conn->state == STATE_REQ) { + if (conn->incoming.size() < 4) { + break; + } + + uint32_t len_net; + memcpy(&len_net, conn->incoming.data(), 4); + uint32_t len = ntohl(len_net); + + if (len > k_max_message_size) { + Logger::log_error("Message too long"); + conn->state = STATE_END; + return; + } + + if (conn->incoming.size() < 4 + len) { + break; + } + + if (parse_and_execute(conn) != 0) { + conn->state = STATE_END; + return; + } + + conn->incoming.erase(conn->incoming.begin(), conn->incoming.begin() + 4 + len); + } +} + +void Server::state_res(Connection* conn) { + while (!conn->outgoing.empty()) { + ssize_t rv = write(conn->connectionfd, conn->outgoing.data(), conn->outgoing.size()); + if (rv < 0) { + if (errno == EAGAIN) { + break; + } + if (errno == EINTR) { + continue; + } + conn->state = STATE_END; + return; + } + + conn->outgoing.erase(conn->outgoing.begin(), conn->outgoing.begin() + rv); + } + + if (conn->outgoing.empty()) { + conn->state = STATE_REQ; + } +} + +int32_t Server::parse_and_execute(Connection* conn) { + uint32_t len_net; + memcpy(&len_net, conn->incoming.data(), 4); uint32_t len = ntohl(len_net); - - if (len > k_max_message_size) { - Logger::log_error("Message length is too long: " + std::to_string(len) + " (max: " + std::to_string(k_max_message_size) + ")"); - return -1; - } - - std::string message(len, '\0'); - if (len > 0) { - if (read_full(connectionfd, message.data(), len) != 0) { - Logger::log_error("Error reading payload from fd=" + std::to_string(connectionfd)); - return -1; - } - } - - Logger::log_info("Client (fd=" + std::to_string(connectionfd) + ") says: " + message); - + + std::string message(conn->incoming.begin() + 4, conn->incoming.begin() + 4 + len); + Logger::log_info("Client (fd=" + std::to_string(conn->connectionfd) + ") says: " + message); + const std::string reply = "world"; - uint32_t reply_len_net = htonl(static_cast(reply.size())); - - if (write_all(connectionfd, &reply_len_net, sizeof(reply_len_net)) != 0) return -1; - if (write_all(connectionfd, reply.data(), reply.size()) != 0) return -1; - + uint32_t reply_len = static_cast(reply.size()); + uint32_t reply_len_net = htonl(reply_len); + + const char* header_ptr = reinterpret_cast(&reply_len_net); + conn->outgoing.insert(conn->outgoing.end(), header_ptr, header_ptr + 4); + + conn->outgoing.insert(conn->outgoing.end(), reply.begin(), reply.end()); + + conn->state = STATE_RES; return 0; } +void Server::connection_io(Connection* conn) { + if (conn->state == STATE_REQ) { + state_req(conn); + } else if (conn->state == STATE_RES) { + state_res(conn); + } +} + + void Server::run() { setup(); + struct pollfd pfd_listen = {sockfd, POLLIN, 0}; + poll_args.push_back(pfd_listen); + while (true) { - struct sockaddr_in client_addr = {}; - socklen_t addrlen = sizeof(client_addr); - int connectionfd = accept(sockfd, (struct sockaddr *)&client_addr, &addrlen); - if (connectionfd < 0) { - Logger::log_error("accept() error: " + std::string(std::strerror(errno))); - continue; - } - - Logger::log_info("Accepted connection from " + std::string(inet_ntoa(client_addr.sin_addr)) + ":" + std::to_string(ntohs(client_addr.sin_port)) + " (fd=" + std::to_string(connectionfd) + ")"); - - while (true) { - int32_t err = handle_connection(connectionfd); - if (err) { - break; + // Prepare events used in poll + for (auto& pfd : poll_args) { + if (pfd.fd == sockfd) { + pfd.events = POLLIN; + continue; + } + + auto it = fd2conn.find(pfd.fd); + if (it != fd2conn.end()) { + Connection* conn = it->second; + if (conn->state == STATE_REQ) { + pfd.events = POLLIN; + } else if (conn->state == STATE_RES) { + pfd.events = POLLOUT; + } + } + } + + int rv = poll(poll_args.data(), (nfds_t)poll_args.size(), 1000); + if (rv < 0) { + Logger::log_error("poll error"); + break; + } + + for (size_t i = 0; i < poll_args.size(); ++i) { + if (poll_args[i].revents) { + int fd = poll_args[i].fd; + if (fd == sockfd) { + accept_new_connection(); + } else { + auto it = fd2conn.find(fd); + if (it != fd2conn.end()) { + connection_io(it->second); + } + } + } + } + + for (size_t i = 0; i < poll_args.size(); ) { + int fd = poll_args[i].fd; + if (fd == sockfd) { + ++i; + continue; + } + + auto it = fd2conn.find(fd); + if (it == fd2conn.end() || it->second->state == STATE_END) { + if (it != fd2conn.end()) { + close(fd); + delete it->second; + fd2conn.erase(it); + Logger::log_info("Closing connection fd=" + std::to_string(fd)); + } + + poll_args[i] = poll_args.back(); + poll_args.pop_back(); + } else { + ++i; } } - Logger::log_info("Closing connection from " + std::to_string(connectionfd)); - close(connectionfd); } } diff --git a/src/server.hpp b/src/server.hpp index 790114f..64535fd 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -3,6 +3,22 @@ #include #include #include "config.hpp" +#include +#include +#include + +enum { + STATE_REQ = 0, + STATE_RES = 1, + STATE_END = 2, +}; + +struct Connection { + int connectionfd = -1; + uint32_t state = STATE_REQ; + std::vector incoming; + std::vector outgoing; +}; class Server { public: @@ -16,6 +32,13 @@ private: std::string address; int port; - int32_t handle_connection(int connectionfd); + std::vector poll_args; + std::map fd2conn; + void setup(); + void accept_new_connection(); + void connection_io(Connection* conn); + void state_req(Connection* conn); + void state_res(Connection* conn); + int32_t parse_and_execute(Connection* conn); };