ensure client stays referenced while in use by the thread pool

Lion Kortlepel 2024-01-26 10:32:41 +01:00
parent c173ffdbdc
commit d0cc23333c
No known key found for this signature in database
GPG Key ID: 4322FF2B4C71259B
3 changed files with 63 additions and 45 deletions

View File

@@ -28,7 +28,7 @@ using VehicleID = uint16_t;
 using namespace boost::asio;
 
-struct Client {
+struct Client : std::enable_shared_from_this<Client> {
     using StrandPtr = std::shared_ptr<boost::asio::strand<ip::tcp::socket::executor_type>>;
     using Ptr = std::shared_ptr<Client>;
@@ -45,7 +45,8 @@ struct Client {
     /// conjunction with something else. Blocks other writes.
     void tcp_write_file_raw(const std::filesystem::path& path);
 
-    Client(ClientID id, class Network& network, ip::tcp::socket&& tcp_socket);
+    /// Ensures no client is ever created as a non-shared-ptr, so that enable_shared_from_this works.
+    static Client::Ptr make_ptr(ClientID new_id, class Network& network, ip::tcp::socket&& tcp_socket);
 
     ~Client();
 
     ip::tcp::socket& tcp_socket() { return m_tcp_socket; }
@@ -61,6 +62,8 @@ struct Client {
     void set_udp_endpoint(const ip::udp::endpoint& ep) { m_udp_endpoint = ep; }
 
 private:
+    /// Ctor must be private to ensure all clients are constructed as shared_ptr to enable_shared_from_this.
+    Client(ClientID id, class Network& network, ip::tcp::socket&& tcp_socket);
     /// Call this when the client seems to have timed out. Will send a ping and set a flag.
     /// Returns true if try-again, false if the connection was closed.
     [[nodiscard]] bool handle_timeout();
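The pattern introduced in this header is enable_shared_from_this combined with a private constructor and a static factory, so that shared_from_this() can never be called on an instance that is not owned by a shared_ptr. A minimal, self-contained sketch of the same pattern (the Widget name and members are illustrative, not from this codebase):

#include <memory>

// Sketch of the make_ptr pattern: the private constructor forces every instance
// to be owned by a shared_ptr, so shared_from_this() never hits the
// undefined-behaviour case of being called on a stack or raw-pointer instance.
struct Widget : std::enable_shared_from_this<Widget> {
    using Ptr = std::shared_ptr<Widget>;

    static Ptr make_ptr(int id) {
        // std::make_shared cannot reach the private constructor, so construct
        // directly and hand ownership to a shared_ptr right away.
        return Ptr(new Widget(id));
    }

    Ptr self() { return shared_from_this(); } // safe: *this is always shared-owned

private:
    explicit Widget(int id)
        : m_id(id) { }

    int m_id;
};

int main() {
    auto w = Widget::make_ptr(42);
    auto alias = w->self(); // use_count is now 2
}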

View File

@@ -14,7 +14,7 @@
 using json = nlohmann::json;
 
 std::string Http::GET(const std::string& host, int port, const std::string& target, unsigned int* status) {
-    httplib::SSLClient client(host, port);
+    static thread_local httplib::SSLClient client(host, port);
     client.enable_server_certificate_verification(false);
     client.set_address_family(AF_INET);
     auto res = client.Get(target.c_str());
@@ -29,7 +29,7 @@ std::string Http::GET(const std::string& host, int port, const std::string& targ
 }
 
 std::string Http::POST(const std::string& host, int port, const std::string& target, const std::string& body, const std::string& ContentType, unsigned int* status, const httplib::Headers& headers) {
-    httplib::SSLClient client(host, port);
+    static thread_local httplib::SSLClient client(host, port);
     client.set_read_timeout(std::chrono::seconds(10));
     beammp_assert(client.is_valid());
     client.enable_server_certificate_verification(false);
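The change above reuses one SSLClient per thread instead of constructing a new one on every call. Worth keeping in mind is how static thread_local initialization behaves: the initializer runs only the first time a given thread reaches the declaration, so later calls on that thread reuse the object built from the first host and port. A small sketch of that behaviour, using a stand-in Connection type rather than httplib:

#include <iostream>
#include <string>

struct Connection {
    explicit Connection(std::string h)
        : host(std::move(h)) { std::cout << "constructing for " << host << "\n"; }
    std::string host;
};

std::string get(const std::string& host) {
    // Initialized once per thread, on the first call; subsequent calls on the
    // same thread reuse the same Connection even if 'host' differs.
    static thread_local Connection conn(host);
    return conn.host;
}

int main() {
    std::cout << get("a.example") << "\n"; // constructs, prints "a.example"
    std::cout << get("b.example") << "\n"; // reuses, still prints "a.example"
}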

View File

@@ -14,6 +14,7 @@
 #include <boost/asio/buffer.hpp>
 #include <boost/asio/io_context.hpp>
 #include <boost/asio/read.hpp>
+#include <boost/asio/socket_base.hpp>
 #include <boost/asio/thread_pool.hpp>
 #include <boost/chrono/duration.hpp>
 #include <boost/system/detail/errc.hpp>
@@ -35,14 +36,14 @@
 /// Boost::asio + strands + timer magic to make writes timeout after some time.
 template <typename HandlerFn>
-static void async_write_timeout(ip::tcp::socket& stream, const_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler, Client::StrandPtr strand) {
+static void async_write_timeout(Client::Ptr client, const_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler) {
     struct TimeoutHelper : std::enable_shared_from_this<TimeoutHelper> {
         /// Given a socket (stream), buffer and a completion handler, constructs a state machine.
-        TimeoutHelper(ip::tcp::socket& stream, const_buffer buffer, HandlerFn handler)
-            : m_stream(stream)
+        TimeoutHelper(Client::Ptr client, ip::tcp::socket& stream, const_buffer buffer, HandlerFn handler)
+            : m_client(client)
+            , m_stream(stream)
             , m_buffer(std::move(buffer))
-            , m_handler_fn(std::move(handler))
-            , m_strand(std::move(strand)) { }
+            , m_handler_fn(std::move(handler)) { }
 
         /// Kicks off the timer and async_write, which race to cancel each other.
         /// Whichever completes first gets to cancel the other one.
         /// Effectively, the timer will finish before the write if the write is "timing out",
@@ -54,12 +55,12 @@ static void async_write_timeout(ip::tcp::socket& stream, const_buffer&& sequence
             // as a copy so that it can call the timeout handler on this object.
             // the whole thing is wrapped in a strand to avoid this happening on two separate thwrites at the same time,
             // i.e. the timer and write finish at the same time on separate thwrites, or other goofy stuff.
-            m_timer.async_wait(bind_executor(*m_strand, [self = this->shared_from_this()](auto&& ec) {
+            m_timer.async_wait(bind_executor(m_strand, [self = this->shared_from_this()](auto&& ec) {
                 self->handle_timeout(ec);
             }));
             // start the write on the same strand, again giving a copy of a shared_ptr to ourselves so the handler can be
             // called.
-            boost::asio::async_write(m_stream, m_buffer, bind_executor(*m_strand, [self = this->shared_from_this()](auto&& ec, auto size) {
+            boost::asio::async_write(m_stream, m_buffer, bind_executor(m_strand, [self = this->shared_from_this()](auto&& ec, auto size) {
                 self->handle_write(ec, size);
             }));
         }
@@ -89,28 +90,29 @@ static void async_write_timeout(ip::tcp::socket& stream, const_buffer&& sequence
             m_handler_fn(ec, size);
         }
 
+        Client::Ptr m_client;
         ip::tcp::socket& m_stream;
         const_buffer m_buffer;
         HandlerFn m_handler_fn;
-        Client::StrandPtr m_strand;
+        boost::asio::strand<ip::tcp::socket::executor_type> m_strand { m_stream.get_executor() };
         boost::asio::deadline_timer m_timer { m_stream.get_executor() };
         bool m_completed = false;
     };
 
-    auto helper = std::make_shared<TimeoutHelper>(stream,
+    auto helper = std::make_shared<TimeoutHelper>(client, client->tcp_socket(),
         std::forward<const_buffer>(sequence),
-        std::forward<HandlerFn>(handler),
-        strand);
+        std::forward<HandlerFn>(handler));
     helper->start(timeout_ms);
 }
 
 /// Boost::asio + strands + timer magic to make reads timeout after some time.
 template <typename HandlerFn>
-static void async_read_timeout(ip::tcp::socket& stream, mutable_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler) {
+static void async_read_timeout(Client::Ptr client, mutable_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler) {
     struct TimeoutHelper : std::enable_shared_from_this<TimeoutHelper> {
         /// Given a socket (stream), buffer and a completion handler, constructs a state machine.
-        TimeoutHelper(ip::tcp::socket& stream, mutable_buffer buffer, HandlerFn handler)
-            : m_stream(stream)
+        TimeoutHelper(Client::Ptr client, ip::tcp::socket& stream, mutable_buffer buffer, HandlerFn handler)
+            : m_client(client)
+            , m_stream(stream)
             , m_buffer(std::move(buffer))
             , m_handler_fn(std::move(handler)) {
         }
@@ -160,6 +162,7 @@ static void async_read_timeout(ip::tcp::socket& stream, mutable_buffer&& sequenc
             m_handler_fn(ec, size);
         }
 
+        Client::Ptr m_client;
        ip::tcp::socket& m_stream;
        mutable_buffer m_buffer;
        HandlerFn m_handler_fn;
@@ -168,7 +171,7 @@ static void async_read_timeout(ip::tcp::socket& stream, mutable_buffer&& sequenc
         bool m_completed = false;
     };
 
-    auto helper = std::make_shared<TimeoutHelper>(stream,
+    auto helper = std::make_shared<TimeoutHelper>(client, client->tcp_socket(),
         std::forward<mutable_buffer>(sequence),
         std::forward<HandlerFn>(handler));
     helper->start(timeout_ms);
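As the comments in these helpers describe, every timed operation races a deadline_timer against the asynchronous read or write: whichever finishes first cancels the other, both completion handlers run on one strand, and a completed flag keeps the user handler from firing twice. A stripped-down sketch of that race for the write case, assuming an already-connected socket (illustrative only, not the server's actual helper):

#include <boost/asio.hpp>
#include <memory>

using boost::asio::ip::tcp;

// Sketch of the timer-vs-write race behind async_write_timeout.
template <typename Handler>
void write_with_timeout(tcp::socket& socket, boost::asio::const_buffer buf,
                        boost::posix_time::milliseconds timeout, Handler handler) {
    struct State : std::enable_shared_from_this<State> {
        State(tcp::socket& s, boost::asio::const_buffer b, Handler h)
            : socket(s), buffer(b), handler(std::move(h)) { }
        tcp::socket& socket;
        boost::asio::const_buffer buffer;
        Handler handler;
        boost::asio::strand<tcp::socket::executor_type> strand { socket.get_executor() };
        boost::asio::deadline_timer timer { socket.get_executor() };
        bool completed = false;
    };
    auto state = std::make_shared<State>(socket, buf, std::move(handler));
    state->timer.expires_from_now(timeout);
    // Timer branch: if it fires first, cancel the write; the write handler then
    // observes the operation as canceled and reports that to the caller.
    state->timer.async_wait(boost::asio::bind_executor(state->strand,
        [state](const boost::system::error_code& ec) {
            if (!ec && !state->completed)
                state->socket.cancel();
        }));
    // Write branch: if it finishes first, stop the timer and report the result.
    boost::asio::async_write(state->socket, state->buffer,
        boost::asio::bind_executor(state->strand,
            [state](const boost::system::error_code& ec, std::size_t n) {
                if (!state->completed) {
                    state->completed = true;
                    state->timer.cancel();
                    state->handler(ec, n);
                }
            }));
}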
@@ -184,8 +187,6 @@ void Network::send_to(ClientID id, bmp::Packet& packet) {
 void Client::tcp_write(bmp::Packet& packet) {
     beammp_tracef("Sending {} to {}", int(packet.purpose), id);
-    // acquire a lock to avoid writing a header, then being interrupted by another write
-    std::unique_lock lock(m_tcp_write_mtx);
     // finalize the packet (compress etc) and produce header
     auto header = packet.finalize();
     // data has to be a shared_ptr, because we pass it to the async write function which completes later,
@@ -198,13 +199,12 @@ void Client::tcp_write(bmp::Packet& packet) {
     beammp_tracef("Packet of size {} B given a timeout of {}ms ({}s)", data->size(), timeout.total_milliseconds(), timeout.seconds());
     // write header and packet data
     async_write_timeout(
-        m_tcp_socket, buffer(*data), timeout, [data, this](const boost::system::error_code& ec, size_t) {
+        shared_from_this(), buffer(*data), timeout, [data, this](const boost::system::error_code& ec, size_t) {
             if (ec && ec.value() == boost::system::errc::operation_canceled) {
                 // write timeout is fatal
                 m_network.disconnect(id, "Write timeout");
             }
-        },
-        m_tcp_strand);
+        });
 }
 
 void Client::tcp_write_file_raw(const std::filesystem::path& path) {
@@ -232,7 +232,10 @@ void Client::tcp_write_file_raw(const std::filesystem::path& path) {
 Client::~Client() {
     beammp_debugf("Client {} shutting down", id);
-    m_tcp_socket.shutdown(boost::asio::socket_base::shutdown_receive);
+    try {
+        m_tcp_socket.shutdown(boost::asio::socket_base::shutdown_receive);
+    } catch (...) {
+    }
     m_tcp_thread.interrupt();
     beammp_debugf("Client {} shut down", id);
 }
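The try/catch matters because the throwing overload of shutdown() raises boost::system::system_error when the socket is already closed or was never connected, and a destructor must not let that escape. For comparison only, the non-throwing error_code overload achieves the same effect; a small illustrative helper:

#include <boost/asio.hpp>

// Illustration of the non-throwing shutdown overload: it reports failure via an
// error_code instead of throwing boost::system::system_error, which is another
// way to keep a destructor from throwing on an already-closed socket.
static void shutdown_receive_quietly(boost::asio::ip::tcp::socket& socket) {
    boost::system::error_code ec;
    socket.shutdown(boost::asio::socket_base::shutdown_receive, ec);
    // ec is deliberately ignored: an already-closed or never-connected socket is fine here.
}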
@@ -265,7 +268,7 @@ void Client::start_tcp() {
     beammp_tracef("{}", __func__);
     m_header.resize(bmp::Header::SERIALIZED_SIZE);
     beammp_tracef("Header buffer size: {}", m_header.size());
-    async_read_timeout(m_tcp_socket, buffer(m_header), m_read_timeout, [this](const boost::system::error_code& ec, size_t) {
+    async_read_timeout(shared_from_this(), buffer(m_header), m_read_timeout, [this](const boost::system::error_code& ec, size_t) {
         if (ec && ec.value() == boost::system::errc::operation_canceled) {
             beammp_warnf("Client {} possibly timing out", id);
             if (handle_timeout()) {
@@ -288,7 +291,7 @@ void Client::start_tcp() {
         m_packet.flags = hdr.flags;
         m_packet.raw_data.resize(hdr.size);
         beammp_tracef("Raw data buffer size: {}", m_packet.raw_data.size());
-        async_read_timeout(m_tcp_socket, buffer(m_packet.raw_data), m_read_timeout, [this](const boost::system::error_code& ec, size_t bytes) {
+        async_read_timeout(shared_from_this(), buffer(m_packet.raw_data), m_read_timeout, [this](const boost::system::error_code& ec, size_t bytes) {
             if (ec && ec.value() == boost::system::errc::operation_canceled) {
                 beammp_warnf("Client {} possibly timing out after sending header", id);
                 if (handle_timeout()) {
@@ -411,7 +414,7 @@ void Network::handle_accept(const boost::system::error_code& ec) {
     } else {
         auto new_id = new_client_id();
         beammp_debugf("New connection from {}", m_temp_socket.remote_endpoint().address().to_string(), m_temp_socket.remote_endpoint().port());
-        std::shared_ptr<Client> new_client(std::make_shared<Client>(new_id, *this, std::move(m_temp_socket)));
+        Client::Ptr new_client = Client::make_ptr(new_id, *this, std::move(m_temp_socket));
         m_clients->emplace(new_id, new_client);
         new_client->start_tcp();
     }
@@ -493,27 +496,36 @@ void Network::udp_read_main() {
 void Network::disconnect(ClientID id, const std::string& msg) {
     // this has to be scheduled, because the thread which did this cannot!!! do it itself.
     beammp_debugf("Scheduling disconnect for {}", id);
-    post(context(), [id, msg, this] {
-        beammp_infof("Disconnecting client {}: {}", id, msg);
-        // deadlock-free algorithm to acquire a lock on all these
-        // this is a little ugly but saves a headache here in the future
-        auto all = boost::synchronize(m_clients, m_udp_endpoints, m_client_magics);
-        auto& clients = std::get<0>(all);
-        auto& endpoints = std::get<1>(all);
-        auto& magics = std::get<2>(all);
-        if (clients->contains(id)) {
-            auto client = clients->at(id);
+    // grab the client ptr
+    auto clients = m_clients.synchronize();
+    if (clients->contains(id)) {
+        auto client = clients->at(id);
+        post(context(), [client, id, msg, this] {
+            beammp_infof("Disconnecting client {}: {}", id, msg);
+            // deadlock-free algorithm to acquire a lock on all these
+            // this is a little ugly but saves a headache here in the future
+            auto all = boost::synchronize(m_clients, m_udp_endpoints, m_client_magics);
+            auto& clients = std::get<0>(all);
+            auto& endpoints = std::get<1>(all);
+            auto& magics = std::get<2>(all);
             beammp_debugf("Removing client udp magic {}", client->udp_magic);
             magics->erase(client->udp_magic);
-        }
-        std::erase_if(*endpoints, [&](const auto& item) {
-            const auto& [key, value] = item;
-            return value == id;
+            std::erase_if(*endpoints, [&](const auto& item) {
+                const auto& [key, value] = item;
+                return value == id;
+            });
+            // TODO: Despawn vehicles owned by this player
+            clients->erase(id);
+            try {
+                client->tcp_socket().shutdown(boost::asio::socket_base::shutdown_both);
+                client->tcp_socket().close();
+            } catch (...) { }
         });
-        // TODO: Despawn vehicles owned by this player
-        clients->erase(id);
-    });
+    } else {
+        beammp_debugf("Client {} already disconnected", id);
+    }
 }
 
 std::unordered_map<ClientID, Client::Ptr> Network::playing_clients() const {
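This restructuring is what the commit title refers to: the Client::Ptr is captured by value in the lambda handed to post(), so the client object stays referenced until the thread pool has actually executed the disconnect handler, even if it is removed from m_clients before or while that handler runs. The lifetime rule in isolation, with a stand-in Session type:

#include <boost/asio.hpp>
#include <iostream>
#include <memory>

struct Session {
    ~Session() { std::cout << "session destroyed\n"; }
};

int main() {
    boost::asio::thread_pool pool { 2 };
    auto session = std::make_shared<Session>();
    // Capturing the shared_ptr by value bumps the reference count, so the
    // Session is destroyed only after the posted handler has run on the pool.
    boost::asio::post(pool, [session] {
        std::cout << "handler still holds the session\n";
    });
    session.reset(); // our reference is gone; the pool's copy keeps it alive
    pool.join();     // prints "handler still holds the session", then "session destroyed"
}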
@@ -1234,3 +1246,6 @@ void Vehicle::update_status(std::span<const uint8_t> raw_packet) {
     m_status_data.resize(raw_packet.size());
     std::copy(raw_packet.begin(), raw_packet.end(), m_status_data.begin());
 }
+Client::Ptr Client::make_ptr(ClientID new_id, class Network& network, ip::tcp::socket&& tcp_socket) {
+    return Client::Ptr(new Client(new_id, network, std::forward<ip::tcp::socket&&>(tcp_socket)));
+}