diff --git a/deps/BeamMP-Protocol b/deps/BeamMP-Protocol index b2622c3..51c33c4 160000 --- a/deps/BeamMP-Protocol +++ b/deps/BeamMP-Protocol @@ -1 +1 @@ -Subproject commit b2622c3fac32aa8ba3b602b7735b201ff3311eca +Subproject commit 51c33c4002152463fc600085a1b593fc35465a2a diff --git a/include/Network.h b/include/Network.h index c8548aa..80625f9 100644 --- a/include/Network.h +++ b/include/Network.h @@ -19,9 +19,17 @@ using namespace boost::asio; struct Packet { bmp::Purpose purpose; bmp::Flags flags; - std::vector data; - bmp::Header header() const; + /// Returns data with consideration to flags. + std::vector get_readable_data() const; + + /// Sets flags (e.g. compression flag) if the data is above some threshold, + /// and compresses the data. + /// Returns the header needed to send this packet. + [[nodiscard]] bmp::Header finalize(); + + /// Raw (potentially compressed) data -- do not read directly to deserialize from. + std::vector raw_data; }; struct Client { @@ -37,7 +45,7 @@ struct Client { /// Reads a single packet from the TCP stream. Blocks all other reads (not writes). Packet tcp_read(); /// Writes the packet to the TCP stream. Blocks all other writes. - void tcp_write(const Packet& packet); + void tcp_write(Packet& packet); /// Writes the specified to the TCP stream without a header or any metadata - use in /// conjunction with something else. Blocks other writes. void tcp_write_file_raw(const std::filesystem::path& path); @@ -55,7 +63,6 @@ struct Client { const uint64_t udp_magic; private: - void tcp_main(); std::mutex m_tcp_read_mtx; @@ -83,7 +90,7 @@ public: /// Reads a packet from the given UDP socket, returning the client's endpoint as an out-argument. Packet udp_read(ip::udp::endpoint& out_ep); /// Sends a packet to the specified UDP endpoint via the UDP socket. - void udp_write(const Packet& packet, const ip::udp::endpoint& to_ep); + void udp_write(Packet& packet, const ip::udp::endpoint& to_ep); void disconnect(ClientID id, const std::string& msg); diff --git a/src/Network.cpp b/src/Network.cpp index 7210922..52581c0 100644 --- a/src/Network.cpp +++ b/src/Network.cpp @@ -1,6 +1,7 @@ #include "Network.h" #include "ClientInfo.h" #include "Common.h" +#include "Compression.h" #include "Environment.h" #include "Http.h" #include "LuaAPI.h" @@ -21,6 +22,52 @@ #include #endif +#include + +std::vector Packet::get_readable_data() const { + if ((flags & bmp::Flags::ZstdCompressed) != 0) { + return bmp::zstd_decompress(raw_data); + } else { + return raw_data; + } +} + +TEST_CASE("Packet finalize") { + Packet packet { + .purpose = bmp::Purpose::Invalid, + }; + SUBCASE("No compression, under threshold") { + packet.raw_data = std::vector(bmp::COMPRESSION_THRESHOLD - 1, 5); + (void)packet.finalize(); + // not compressed, still the same + CHECK(std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; })); + ; + // no compression flag + CHECK_EQ(packet.flags & bmp::Flags::ZstdCompressed, 0); + } + SUBCASE("Compression via threshold") { + packet.raw_data = std::vector(bmp::COMPRESSION_THRESHOLD + 1, 5); + (void)packet.finalize(); + // compressed, not the exact same + CHECK(!std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; })); + // decompressable + CHECK_NOTHROW(bmp::zstd_decompress(packet.raw_data)); + // compression flag set + CHECK_NE(packet.flags & bmp::Flags::ZstdCompressed, 0); + } + SUBCASE("Compression flag") { + packet.raw_data = std::vector(bmp::COMPRESSION_THRESHOLD - 1, 5); + packet.flags = bmp::Flags(packet.flags | bmp::Flags::ZstdCompressed); + (void)packet.finalize(); + // compressed, not the exact same + CHECK(!std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; })); + // decompressable + CHECK_NOTHROW(bmp::zstd_decompress(packet.raw_data)); + // compression flag set + CHECK_NE(packet.flags & bmp::Flags::ZstdCompressed, 0); + } +} + Packet Client::tcp_read() { std::unique_lock lock(m_tcp_read_mtx); Packet packet {}; @@ -29,24 +76,28 @@ Packet Client::tcp_read() { bmp::Header hdr {}; hdr.deserialize_from(header_buffer); // vector eaten up by now, recv again - packet.data.resize(hdr.size); - read(m_tcp_socket, buffer(packet.data)); + packet.raw_data.resize(hdr.size); + read(m_tcp_socket, buffer(packet.raw_data)); packet.purpose = hdr.purpose; packet.flags = hdr.flags; return packet; } -void Client::tcp_write(const Packet& packet) { +void Client::tcp_write(Packet& packet) { beammp_tracef("Sending 0x{:x} to {}", int(packet.purpose), id); + // acquire a lock to avoid writing a header, then being interrupted by another write std::unique_lock lock(m_tcp_write_mtx); - auto header = packet.header(); + // finalize the packet (compress etc) and produce header + auto header = packet.finalize(); + // serialize header std::vector header_data(bmp::Header::SERIALIZED_SIZE); if (header.flags != bmp::Flags::None) { beammp_errorf("Flags are not implemented"); } header.serialize_to(header_data); + // write header and packet data write(m_tcp_socket, buffer(header_data)); - write(m_tcp_socket, buffer(packet.data)); + write(m_tcp_socket, buffer(packet.raw_data)); } void Client::tcp_write_file_raw(const std::filesystem::path& path) { @@ -105,12 +156,19 @@ void Client::tcp_main() { beammp_debugf("TCP thread stopped for client {}", id); } -bmp::Header Packet::header() const { +bmp::Header Packet::finalize() { + // the user can force zstd compression on before setting data to force compression, + // otherwise the threshold is used. + if ((flags & bmp::Flags::ZstdCompressed) != 0 + || raw_data.size() > bmp::COMPRESSION_THRESHOLD) { + flags = bmp::Flags(flags | bmp::Flags::ZstdCompressed); + raw_data = bmp::zstd_compress(raw_data); + } return { .purpose = purpose, - .flags = bmp::Flags::None, + .flags = flags, .rsv = 0, - .size = static_cast(data.size()), + .size = static_cast(raw_data.size()), }; } @@ -125,16 +183,16 @@ Packet Network::udp_read(ip::udp::endpoint& out_ep) { beammp_errorf("Flags are not implemented"); return {}; } - packet.data.resize(header.size); - std::copy(s_buffer.begin() + offset, s_buffer.begin() + offset + header.size, packet.data.begin()); + packet.raw_data.resize(header.size); + std::copy(s_buffer.begin() + offset, s_buffer.begin() + offset + header.size, packet.raw_data.begin()); return packet; } -void Network::udp_write(const Packet& packet, const ip::udp::endpoint& to_ep) { - auto header = packet.header(); +void Network::udp_write(Packet& packet, const ip::udp::endpoint& to_ep) { + auto header = packet.finalize(); std::vector data(header.size + bmp::Header::SERIALIZED_SIZE); auto offset = header.serialize_to(data); - std::copy(packet.data.begin(), packet.data.end(), data.begin() + static_cast(offset)); + std::copy(packet.raw_data.begin(), packet.raw_data.end(), data.begin() + static_cast(offset)); m_udp_socket.send_to(buffer(data), to_ep, {}); } @@ -207,7 +265,7 @@ void Network::udp_read_main() { auto& magics = std::get<2>(all); ClientID id = 0xffffffff; uint64_t recv_magic; - bmp::deserialize(recv_magic, packet.data); + bmp::deserialize(recv_magic, packet.get_readable_data()); if (magics->contains(recv_magic)) { id = magics->at(recv_magic); magics->erase(recv_magic); @@ -304,7 +362,7 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar switch (packet.purpose) { case bmp::ProtocolVersion: { struct bmp::ProtocolVersion protocol_version { }; - protocol_version.deserialize_from(packet.data); + protocol_version.deserialize_from(packet.get_readable_data()); if (protocol_version.version.major != 1) { beammp_debugf("{}: Protocol version bad", id); // version bad @@ -325,7 +383,7 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar } case bmp::ClientInfo: { struct bmp::ClientInfo cinfo { }; - cinfo.deserialize_from(packet.data); + cinfo.deserialize_from(packet.get_readable_data()); beammp_debugf("{} is running game version: v{}.{}.{}, mod version: v{}.{}.{}, client implementation '{}' v{}.{}.{}", id, cinfo.game_version.major, @@ -352,9 +410,9 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar }; Packet sinfo_packet { .purpose = bmp::ServerInfo, - .data = std::vector(1024), + .raw_data = std::vector(1024), }; - sinfo.serialize_to(sinfo_packet.data); + sinfo.serialize_to(sinfo_packet.raw_data); client->tcp_write(sinfo_packet); // now transfer to next state Packet auth_state { @@ -413,7 +471,8 @@ void Network::authenticate_user(const std::string& public_key, std::shared_ptr& client) { switch (packet.purpose) { case bmp::Purpose::PlayerPublicKey: { - auto public_key = std::string(packet.data.begin(), packet.data.end()); + auto packet_data = packet.get_readable_data(); + auto public_key = std::string(packet_data.begin(), packet_data.end()); try { authenticate_user(public_key, client); } catch (const std::exception& e) { @@ -422,7 +481,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar beammp_errorf("Client {} failed to authenticate: {}", id, err); Packet auth_fail_packet { .purpose = bmp::Purpose::AuthFailed, - .data = std::vector(err.begin(), err.end()), + .raw_data = std::vector(err.begin(), err.end()), }; client->tcp_write(auth_fail_packet); disconnect(id, err); @@ -454,7 +513,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar } else if (NotAllowedWithReason) { Packet auth_fail_packet { .purpose = bmp::Purpose::PlayerRejected, - .data = std::vector(Reason.begin(), Reason.end()), + .raw_data = std::vector(Reason.begin(), Reason.end()), }; client->tcp_write(auth_fail_packet); disconnect(id, fmt::format("Rejected by a plugin for reason: {}", Reason)); @@ -464,22 +523,22 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar // send auth ok since auth succeeded Packet auth_ok { .purpose = bmp::Purpose::AuthOk, - .data = std::vector(4), + .raw_data = std::vector(4), }; // with the player id - bmp::serialize(client->id, auth_ok.data); + bmp::serialize(client->id, auth_ok.raw_data); client->tcp_write(auth_ok); // save the udp magic - m_udp_magics.emplace(client->udp_magic, client->id); + m_client_magics->emplace(client->udp_magic, client->id); // send the udp start packet, which should get the client to start udp with // this packet as the first message Packet udp_start { .purpose = bmp::Purpose::StartUDP, - .data = std::vector(8), + .raw_data = std::vector(8), }; - bmp::serialize(client->udp_magic, udp_start.data); + bmp::serialize(client->udp_magic, udp_start.raw_data); client->tcp_write(udp_start); // player must start udp to advance now, so no state change break;