implement compression and decompression

This commit is contained in:
Lion Kortlepel
2024-01-17 14:52:09 +01:00
parent dbab9eb894
commit 9e99177fcb
3 changed files with 98 additions and 32 deletions

View File

@@ -19,9 +19,17 @@ using namespace boost::asio;
struct Packet {
bmp::Purpose purpose;
bmp::Flags flags;
std::vector<uint8_t> data;
bmp::Header header() const;
/// Returns data with consideration to flags.
std::vector<uint8_t> get_readable_data() const;
/// Sets flags (e.g. compression flag) if the data is above some threshold,
/// and compresses the data.
/// Returns the header needed to send this packet.
[[nodiscard]] bmp::Header finalize();
/// Raw (potentially compressed) data -- do not read directly to deserialize from.
std::vector<uint8_t> raw_data;
};
struct Client {
@@ -37,7 +45,7 @@ struct Client {
/// Reads a single packet from the TCP stream. Blocks all other reads (not writes).
Packet tcp_read();
/// Writes the packet to the TCP stream. Blocks all other writes.
void tcp_write(const Packet& packet);
void tcp_write(Packet& packet);
/// Writes the specified to the TCP stream without a header or any metadata - use in
/// conjunction with something else. Blocks other writes.
void tcp_write_file_raw(const std::filesystem::path& path);
@@ -55,7 +63,6 @@ struct Client {
const uint64_t udp_magic;
private:
void tcp_main();
std::mutex m_tcp_read_mtx;
@@ -83,7 +90,7 @@ public:
/// Reads a packet from the given UDP socket, returning the client's endpoint as an out-argument.
Packet udp_read(ip::udp::endpoint& out_ep);
/// Sends a packet to the specified UDP endpoint via the UDP socket.
void udp_write(const Packet& packet, const ip::udp::endpoint& to_ep);
void udp_write(Packet& packet, const ip::udp::endpoint& to_ep);
void disconnect(ClientID id, const std::string& msg);

View File

@@ -1,6 +1,7 @@
#include "Network.h"
#include "ClientInfo.h"
#include "Common.h"
#include "Compression.h"
#include "Environment.h"
#include "Http.h"
#include "LuaAPI.h"
@@ -21,6 +22,52 @@
#include <boost/iostreams/device/mapped_file.hpp>
#endif
#include <doctest/doctest.h>
std::vector<uint8_t> Packet::get_readable_data() const {
if ((flags & bmp::Flags::ZstdCompressed) != 0) {
return bmp::zstd_decompress(raw_data);
} else {
return raw_data;
}
}
TEST_CASE("Packet finalize") {
Packet packet {
.purpose = bmp::Purpose::Invalid,
};
SUBCASE("No compression, under threshold") {
packet.raw_data = std::vector<uint8_t>(bmp::COMPRESSION_THRESHOLD - 1, 5);
(void)packet.finalize();
// not compressed, still the same
CHECK(std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; }));
;
// no compression flag
CHECK_EQ(packet.flags & bmp::Flags::ZstdCompressed, 0);
}
SUBCASE("Compression via threshold") {
packet.raw_data = std::vector<uint8_t>(bmp::COMPRESSION_THRESHOLD + 1, 5);
(void)packet.finalize();
// compressed, not the exact same
CHECK(!std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; }));
// decompressable
CHECK_NOTHROW(bmp::zstd_decompress(packet.raw_data));
// compression flag set
CHECK_NE(packet.flags & bmp::Flags::ZstdCompressed, 0);
}
SUBCASE("Compression flag") {
packet.raw_data = std::vector<uint8_t>(bmp::COMPRESSION_THRESHOLD - 1, 5);
packet.flags = bmp::Flags(packet.flags | bmp::Flags::ZstdCompressed);
(void)packet.finalize();
// compressed, not the exact same
CHECK(!std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; }));
// decompressable
CHECK_NOTHROW(bmp::zstd_decompress(packet.raw_data));
// compression flag set
CHECK_NE(packet.flags & bmp::Flags::ZstdCompressed, 0);
}
}
Packet Client::tcp_read() {
std::unique_lock lock(m_tcp_read_mtx);
Packet packet {};
@@ -29,24 +76,28 @@ Packet Client::tcp_read() {
bmp::Header hdr {};
hdr.deserialize_from(header_buffer);
// vector eaten up by now, recv again
packet.data.resize(hdr.size);
read(m_tcp_socket, buffer(packet.data));
packet.raw_data.resize(hdr.size);
read(m_tcp_socket, buffer(packet.raw_data));
packet.purpose = hdr.purpose;
packet.flags = hdr.flags;
return packet;
}
void Client::tcp_write(const Packet& packet) {
void Client::tcp_write(Packet& packet) {
beammp_tracef("Sending 0x{:x} to {}", int(packet.purpose), id);
// acquire a lock to avoid writing a header, then being interrupted by another write
std::unique_lock lock(m_tcp_write_mtx);
auto header = packet.header();
// finalize the packet (compress etc) and produce header
auto header = packet.finalize();
// serialize header
std::vector<uint8_t> header_data(bmp::Header::SERIALIZED_SIZE);
if (header.flags != bmp::Flags::None) {
beammp_errorf("Flags are not implemented");
}
header.serialize_to(header_data);
// write header and packet data
write(m_tcp_socket, buffer(header_data));
write(m_tcp_socket, buffer(packet.data));
write(m_tcp_socket, buffer(packet.raw_data));
}
void Client::tcp_write_file_raw(const std::filesystem::path& path) {
@@ -105,12 +156,19 @@ void Client::tcp_main() {
beammp_debugf("TCP thread stopped for client {}", id);
}
bmp::Header Packet::header() const {
bmp::Header Packet::finalize() {
// the user can force zstd compression on before setting data to force compression,
// otherwise the threshold is used.
if ((flags & bmp::Flags::ZstdCompressed) != 0
|| raw_data.size() > bmp::COMPRESSION_THRESHOLD) {
flags = bmp::Flags(flags | bmp::Flags::ZstdCompressed);
raw_data = bmp::zstd_compress(raw_data);
}
return {
.purpose = purpose,
.flags = bmp::Flags::None,
.flags = flags,
.rsv = 0,
.size = static_cast<uint32_t>(data.size()),
.size = static_cast<uint32_t>(raw_data.size()),
};
}
@@ -125,16 +183,16 @@ Packet Network::udp_read(ip::udp::endpoint& out_ep) {
beammp_errorf("Flags are not implemented");
return {};
}
packet.data.resize(header.size);
std::copy(s_buffer.begin() + offset, s_buffer.begin() + offset + header.size, packet.data.begin());
packet.raw_data.resize(header.size);
std::copy(s_buffer.begin() + offset, s_buffer.begin() + offset + header.size, packet.raw_data.begin());
return packet;
}
void Network::udp_write(const Packet& packet, const ip::udp::endpoint& to_ep) {
auto header = packet.header();
void Network::udp_write(Packet& packet, const ip::udp::endpoint& to_ep) {
auto header = packet.finalize();
std::vector<uint8_t> data(header.size + bmp::Header::SERIALIZED_SIZE);
auto offset = header.serialize_to(data);
std::copy(packet.data.begin(), packet.data.end(), data.begin() + static_cast<long>(offset));
std::copy(packet.raw_data.begin(), packet.raw_data.end(), data.begin() + static_cast<long>(offset));
m_udp_socket.send_to(buffer(data), to_ep, {});
}
@@ -207,7 +265,7 @@ void Network::udp_read_main() {
auto& magics = std::get<2>(all);
ClientID id = 0xffffffff;
uint64_t recv_magic;
bmp::deserialize(recv_magic, packet.data);
bmp::deserialize(recv_magic, packet.get_readable_data());
if (magics->contains(recv_magic)) {
id = magics->at(recv_magic);
magics->erase(recv_magic);
@@ -304,7 +362,7 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar
switch (packet.purpose) {
case bmp::ProtocolVersion: {
struct bmp::ProtocolVersion protocol_version { };
protocol_version.deserialize_from(packet.data);
protocol_version.deserialize_from(packet.get_readable_data());
if (protocol_version.version.major != 1) {
beammp_debugf("{}: Protocol version bad", id);
// version bad
@@ -325,7 +383,7 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar
}
case bmp::ClientInfo: {
struct bmp::ClientInfo cinfo { };
cinfo.deserialize_from(packet.data);
cinfo.deserialize_from(packet.get_readable_data());
beammp_debugf("{} is running game version: v{}.{}.{}, mod version: v{}.{}.{}, client implementation '{}' v{}.{}.{}",
id,
cinfo.game_version.major,
@@ -352,9 +410,9 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar
};
Packet sinfo_packet {
.purpose = bmp::ServerInfo,
.data = std::vector<uint8_t>(1024),
.raw_data = std::vector<uint8_t>(1024),
};
sinfo.serialize_to(sinfo_packet.data);
sinfo.serialize_to(sinfo_packet.raw_data);
client->tcp_write(sinfo_packet);
// now transfer to next state
Packet auth_state {
@@ -413,7 +471,8 @@ void Network::authenticate_user(const std::string& public_key, std::shared_ptr<C
void Network::handle_authentication(ClientID id, const Packet& packet, std::shared_ptr<Client>& client) {
switch (packet.purpose) {
case bmp::Purpose::PlayerPublicKey: {
auto public_key = std::string(packet.data.begin(), packet.data.end());
auto packet_data = packet.get_readable_data();
auto public_key = std::string(packet_data.begin(), packet_data.end());
try {
authenticate_user(public_key, client);
} catch (const std::exception& e) {
@@ -422,7 +481,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar
beammp_errorf("Client {} failed to authenticate: {}", id, err);
Packet auth_fail_packet {
.purpose = bmp::Purpose::AuthFailed,
.data = std::vector<uint8_t>(err.begin(), err.end()),
.raw_data = std::vector<uint8_t>(err.begin(), err.end()),
};
client->tcp_write(auth_fail_packet);
disconnect(id, err);
@@ -454,7 +513,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar
} else if (NotAllowedWithReason) {
Packet auth_fail_packet {
.purpose = bmp::Purpose::PlayerRejected,
.data = std::vector<uint8_t>(Reason.begin(), Reason.end()),
.raw_data = std::vector<uint8_t>(Reason.begin(), Reason.end()),
};
client->tcp_write(auth_fail_packet);
disconnect(id, fmt::format("Rejected by a plugin for reason: {}", Reason));
@@ -464,22 +523,22 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar
// send auth ok since auth succeeded
Packet auth_ok {
.purpose = bmp::Purpose::AuthOk,
.data = std::vector<uint8_t>(4),
.raw_data = std::vector<uint8_t>(4),
};
// with the player id
bmp::serialize(client->id, auth_ok.data);
bmp::serialize(client->id, auth_ok.raw_data);
client->tcp_write(auth_ok);
// save the udp magic
m_udp_magics.emplace(client->udp_magic, client->id);
m_client_magics->emplace(client->udp_magic, client->id);
// send the udp start packet, which should get the client to start udp with
// this packet as the first message
Packet udp_start {
.purpose = bmp::Purpose::StartUDP,
.data = std::vector<uint8_t>(8),
.raw_data = std::vector<uint8_t>(8),
};
bmp::serialize(client->udp_magic, udp_start.data);
bmp::serialize(client->udp_magic, udp_start.raw_data);
client->tcp_write(udp_start);
// player must start udp to advance now, so no state change
break;