diff --git a/dev/tests/network_tcp.lua b/dev/tests/network_tcp.lua index eb6b1880..37906547 100644 --- a/dev/tests/network_tcp.lua +++ b/dev/tests/network_tcp.lua @@ -1,5 +1,5 @@ for i=1,3 do - print(string.format("iteration %s", i + 1)) + print(string.format("iteration %s", i)) local text = "" local complete = false diff --git a/dev/tests/network_udp.lua b/dev/tests/network_udp.lua new file mode 100644 index 00000000..975321a4 --- /dev/null +++ b/dev/tests/network_udp.lua @@ -0,0 +1,35 @@ +math.randomseed(43172) +for i = 1, 15 do + debug.log(string.format("iteration %s", i)) + local complete = false + + local server = network.udp_open(8645 + i, function (address, port, data, srv) + debug.log(string.format("server received %s byte(s) from %s:%s", #data, address, port)) + srv:send(address, port, "pong") + end) + + app.tick() + network.udp_connect("localhost", 8645 + i, function (data) + debug.log(string.format("client received %s byte(s) from server", #data)) + complete = true + end, function (socket) + debug.log("udp socket opened") + start_coroutine(function() + debug.log("udp data-sender started") + for k = 1, 15 do + local payload = "" + for j = 1, 16 do + payload = payload .. math.random(0, 9) + end + socket:send(payload) + debug.log(string.format("sent packet %s (%s bytes)", k, #payload)) + coroutine.yield() + end + app.sleep_until(function () return complete end, nil, 5) + socket:close() + end, "udp-data-sender") + end) + + app.sleep_until(function () return complete end, nil, 5) + server:close() +end diff --git a/res/scripts/classes.lua b/res/scripts/classes.lua index f5457ef5..49d11268 100644 --- a/res/scripts/classes.lua +++ b/res/scripts/classes.lua @@ -144,7 +144,11 @@ network.udp_connect = function (address, port, datagramHandler, openCallback) socket.id = network.__connect_udp(address, port) _udp_client_datagram_callbacks[socket.id] = datagramHandler - _udp_client_open_callbacks[socket.id] = openCallback + if openCallback then + _udp_client_open_callbacks[socket.id] = function() + openCallback(socket) + end + end return socket end @@ -254,9 +258,15 @@ network.__process_events = function() end elseif etype == DATAGRAM then if side == ON_CLIENT then - _udp_client_datagram_callbacks[cid](data) + local callback = _udp_client_datagram_callbacks[cid] + if callback then + callback(data) + end elseif side == ON_SERVER then - _udp_server_callbacks[sid](addr, port, data) + local callback = _udp_server_callbacks[sid] + if callback then + callback(addr, port, data) + end end elseif etype == RESPONSE then if event[2] / 100 == 2 then diff --git a/src/network/Network.cpp b/src/network/Network.cpp index 7d33536c..b3ff07cf 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -620,6 +620,26 @@ public: } }; +static sockaddr_in resolve_address_dgram(const std::string& address, int port) { + sockaddr_in serverAddr{}; + addrinfo hints {}; + + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_DGRAM; + + addrinfo* addrinfo = nullptr; + if (int res = getaddrinfo( + address.c_str(), nullptr, &hints, &addrinfo + )) { + throw std::runtime_error(gai_strerror(res)); + } + + std::memcpy(&serverAddr, addrinfo->ai_addr, sizeof(sockaddr_in)); + serverAddr.sin_port = htons(port); + freeaddrinfo(addrinfo); + return serverAddr; +} + class SocketUdpConnection : public UdpConnection { u64id_t id; SOCKET descriptor; @@ -652,13 +672,7 @@ public: throw std::runtime_error("could not create udp socket"); } - sockaddr_in serverAddr{}; - serverAddr.sin_family = AF_INET; - if (inet_pton(AF_INET, address.c_str(), &serverAddr.sin_addr) <= 0) { - closesocket(descriptor); - throw std::runtime_error("invalid udp address: " + address); - } - serverAddr.sin_port = htons(port); + sockaddr_in serverAddr = resolve_address_dgram(address, port); if (::connect(descriptor, (sockaddr*)&serverAddr, sizeof(serverAddr)) < 0) { auto err = handle_socket_error("udp connect failed"); @@ -683,6 +697,7 @@ public: while (open) { int size = recv(descriptor, buffer.data(), buffer.size(), 0); if (size <= 0) { + logger.error() <(&client), sizeof(client)); + sockaddr_in client = resolve_address_dgram(addr, port); + if (sendto(descriptor, buffer, length, 0, + reinterpret_cast(&client), sizeof(client)) < 0) { + logger.error() << handle_socket_error("sendto").what(); + } } void close() override {