diff --git a/.gitignore b/.gitignore index 6b7df28..55912ec 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,6 @@ install_strip_local_manifest.txt install_strip_local_manifest.txt .vscode/ -my_own_redis \ No newline at end of file + +.venv/ +my_own_redis diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..24ee5b1 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/CMakeLists.txt b/CMakeLists.txt index 623ca27..f40c703 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,14 +3,15 @@ cmake_minimum_required(VERSION 3.10) project(MyOwnRedis) -set(CMAKE_CXX_STANDART 17) -set(CMAKE_CXX_REQUIRED ON) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_REQUIRED_FLAGS -Wall -Wextra -Werror) set(SOURCES src/main.cpp src/utils.cpp src/config.cpp src/server.cpp + src/net_util.cpp ) add_executable(my_own_redis ${SOURCES}) diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/client.py b/client.py new file mode 100644 index 0000000..187b069 --- /dev/null +++ b/client.py @@ -0,0 +1,50 @@ +import socket +import struct +import argparse + + +def recv_exact(sock: socket.socket, n: int) -> bytes: + data = b'' + while len(data) < n: + chunk = sock.recv(n - len(data)) + if len(chunk) == 0: + raise ConnectionError('EOF from server') + data += chunk + return data + +def send_frame(sock: socket.socket, payload: bytes) -> None: + header = struct.pack('!I', len(payload)) + sock.sendall(header + payload) + +def recv_frame(sock: socket.socket) -> bytes: + header = recv_exact(sock, 4) + (length,) = struct.unpack('!I', header) + if length > 10_000_000: + raise ValueError('Message length is too long') + return recv_exact(sock, length) + +def main(): + parser = argparse.ArgumentParser(description='NOT(Redis) client') + parser.add_argument('-H', '--host', type=str, required=False, default='127.0.0.1', help='Server host') + parser.add_argument('-P', '--port', type=int, required=False, default=6379, help='Server port') + parser.add_argument('-M', '--message', type=str, required=False, default='hello', help='Message to send') + args = parser.parse_args() + + try: + with socket.create_connection((args.host, args.port)) as sock: + send_frame(sock, args.message.encode()) + response = recv_frame(sock) + print('Message sent:', args.message) + print('Server says:', response.decode("utf-8", errors='replace')) + except ConnectionRefusedError as e: + print('Connection refused by server:', e) + exit(1) + except ConnectionError as e: + print('Failed to connect to server:', e) + exit(1) + except Exception as e: + print('Unexpected error:', e) + exit(1) + +if __name__ == '__main__': + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4479991 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "own-redis-client" +version = "0.1.0" +description = "Fuck C++ for client I swear bruh" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [] + +[project.scripts] +client = "python client.py" diff --git a/src/constants.hpp b/src/constants.hpp new file mode 100644 index 0000000..f442ad8 --- /dev/null +++ b/src/constants.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include + +const uint32_t k_max_message_size = 4096; diff --git a/src/net_util.cpp b/src/net_util.cpp new file mode 100644 index 0000000..84046c4 --- /dev/null +++ b/src/net_util.cpp @@ -0,0 +1,37 @@ +#include "net_util.hpp" +#include +#include + + +int32_t read_full(int connectionfd, void* buf, size_t len) { + auto buffer = static_cast(buf); + while (len > 0) { + ssize_t rv = ::read(connectionfd, buffer, len); + if (rv == 0) { + return -1; + } + if (rv < 0) { + if (errno == EINTR) { + continue; + } + return -1; + } + buffer += rv; + len -= static_cast(rv); + } + return 0; +} + +int32_t write_all(int connectionfd, const void* buf, size_t len) { + auto* buffer = static_cast(buf); + while (len > 0) { + ssize_t rv = ::write(connectionfd, buffer, len); + if (rv <= 0) { + if (rv < 0 && errno == EINTR) continue; + return -1; + } + buffer += rv; + len -= static_cast(rv); + } + return 0; +} diff --git a/src/net_util.hpp b/src/net_util.hpp new file mode 100644 index 0000000..36596cc --- /dev/null +++ b/src/net_util.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include +#include + +int32_t read_full(int connectionfd, void* buf, size_t len); +int32_t write_all(int connectionfd, const void* buf, size_t len); diff --git a/src/server.cpp b/src/server.cpp index eabd1cd..c89648b 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -1,9 +1,20 @@ -#include "server.hpp" -#include -#include -#include -#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include + +#include "server.hpp" +#include "net_util.hpp" +#include "constants.hpp" +#include "utils.hpp" + Server::Server(const Config& config) : address(config.get_address()), port(config.get_port()), sockfd(-1) { @@ -43,18 +54,41 @@ void Server::setup() { std::cout << "Server started on " << address << ":" << port << "\n"; } -void Server::handle_connection(int connectionfd) { - char rbuf[64] = {}; - ssize_t n = read(connectionfd, rbuf, sizeof(rbuf) - 1); - if (n < 0) { - std::cerr << "Error reading client packet\n"; - return; +int32_t Server::handle_connection(int connectionfd) { + std::cout << "Waiting for data from " << connectionfd << "\n"; + + uint32_t len_net = 0; + errno = 0; + if (read_full(connectionfd, &len_net, sizeof(len_net)) != 0) { + print_error(errno == 0 ? "EOF from client" : "Error reading length"); + return -1; } - std::cout << "Client says: " << std::string(rbuf, n) << "\n"; - ssize_t written = write(connectionfd, "Hello, client!\n", 14); - if (written < 0) { - std::cerr << "Error writing to client\n"; + + uint32_t len = ntohl(len_net); + + if (len > k_max_message_size) { + print_error("Message length is too long"); + std::cerr << len << " > " << k_max_message_size << "\n"; + return -1; } + + std::string message(len, '\0'); + if (len > 0) { + if (read_full(connectionfd, message.data(), len) != 0) { + print_error("Error reading payload"); + return -1; + } + } + + std::cout << "Client says: " << message << "\n"; + + const std::string reply = "world"; + uint32_t reply_len_net = htonl(static_cast(reply.size())); + + if (write_all(connectionfd, &reply_len_net, sizeof(reply_len_net)) != 0) return -1; + if (write_all(connectionfd, reply.data(), reply.size()) != 0) return -1; + + return 0; } void Server::run() { @@ -67,7 +101,14 @@ void Server::run() { if (connectionfd < 0) { continue; } - handle_connection(connectionfd); + std::cout << "Accepted connection from " << connectionfd << "\n"; + while (true) { + int32_t err = handle_connection(connectionfd); + if (err) { + break; + } + } + std::cout << "Closing connection from " << connectionfd << "\n"; close(connectionfd); } } diff --git a/src/server.hpp b/src/server.hpp index e811495..790114f 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include "config.hpp" class Server { @@ -15,6 +16,6 @@ private: std::string address; int port; - void handle_connection(int connectionfd); + int32_t handle_connection(int connectionfd); void setup(); }; diff --git a/src/utils.cpp b/src/utils.cpp index 1d0d4c0..64421ff 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -7,3 +7,7 @@ void stop_program(int signum) { std::cout << "Exiting program...\n"; exit(signum); } + +void print_error(const std::string& msg) { + std::cerr << msg << ": " << std::strerror(errno) << "\n"; +} diff --git a/src/utils.hpp b/src/utils.hpp index 9554f86..6c81c7d 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -1,6 +1,8 @@ #pragma once #include +#include void stop_program(int signum); +void print_error(const std::string& msg); diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..65ca04f --- /dev/null +++ b/uv.lock @@ -0,0 +1,8 @@ +version = 1 +revision = 3 +requires-python = ">=3.13" + +[[package]] +name = "own-redis-client" +version = "0.1.0" +source = { virtual = "." }