ensure client stays referenced while referenced in thread pool

This commit is contained in:
Lion Kortlepel 2024-01-26 10:32:41 +01:00
parent c173ffdbdc
commit d0cc23333c
No known key found for this signature in database
GPG Key ID: 4322FF2B4C71259B
3 changed files with 63 additions and 45 deletions

View File

@ -28,7 +28,7 @@ using VehicleID = uint16_t;
using namespace boost::asio;
struct Client {
struct Client : std::enable_shared_from_this<Client> {
using StrandPtr = std::shared_ptr<boost::asio::strand<ip::tcp::socket::executor_type>>;
using Ptr = std::shared_ptr<Client>;
@ -45,7 +45,8 @@ struct Client {
/// conjunction with something else. Blocks other writes.
void tcp_write_file_raw(const std::filesystem::path& path);
Client(ClientID id, class Network& network, ip::tcp::socket&& tcp_socket);
/// Ensures no client is ever created as a non-shared-ptr, so that enable_shared_from_this works.
static Client::Ptr make_ptr(ClientID new_id, class Network& network, ip::tcp::socket&& tcp_socket);
~Client();
ip::tcp::socket& tcp_socket() { return m_tcp_socket; }
@ -61,6 +62,8 @@ struct Client {
void set_udp_endpoint(const ip::udp::endpoint& ep) { m_udp_endpoint = ep; }
private:
/// Ctor must be private to ensure all clients are constructed as shared_ptr to enable_shared_from_this.
Client(ClientID id, class Network& network, ip::tcp::socket&& tcp_socket);
/// Call this when the client seems to have timed out. Will send a ping and set a flag.
/// Returns true if try-again, false if the connection was closed.
[[nodiscard]] bool handle_timeout();

View File

@ -14,7 +14,7 @@
using json = nlohmann::json;
std::string Http::GET(const std::string& host, int port, const std::string& target, unsigned int* status) {
httplib::SSLClient client(host, port);
static thread_local httplib::SSLClient client(host, port);
client.enable_server_certificate_verification(false);
client.set_address_family(AF_INET);
auto res = client.Get(target.c_str());
@ -29,7 +29,7 @@ std::string Http::GET(const std::string& host, int port, const std::string& targ
}
std::string Http::POST(const std::string& host, int port, const std::string& target, const std::string& body, const std::string& ContentType, unsigned int* status, const httplib::Headers& headers) {
httplib::SSLClient client(host, port);
static thread_local httplib::SSLClient client(host, port);
client.set_read_timeout(std::chrono::seconds(10));
beammp_assert(client.is_valid());
client.enable_server_certificate_verification(false);

View File

@ -14,6 +14,7 @@
#include <boost/asio/buffer.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/read.hpp>
#include <boost/asio/socket_base.hpp>
#include <boost/asio/thread_pool.hpp>
#include <boost/chrono/duration.hpp>
#include <boost/system/detail/errc.hpp>
@ -35,14 +36,14 @@
/// Boost::asio + strands + timer magic to make writes timeout after some time.
template <typename HandlerFn>
static void async_write_timeout(ip::tcp::socket& stream, const_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler, Client::StrandPtr strand) {
static void async_write_timeout(Client::Ptr client, const_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler) {
struct TimeoutHelper : std::enable_shared_from_this<TimeoutHelper> {
/// Given a socket (stream), buffer and a completion handler, constructs a state machine.
TimeoutHelper(ip::tcp::socket& stream, const_buffer buffer, HandlerFn handler, Client::StrandPtr strand)
: m_stream(stream)
TimeoutHelper(Client::Ptr client, ip::tcp::socket& stream, const_buffer buffer, HandlerFn handler)
: m_client(client)
, m_stream(stream)
, m_buffer(std::move(buffer))
, m_handler_fn(std::move(handler))
, m_strand(std::move(strand)) { }
, m_handler_fn(std::move(handler)) { }
/// Kicks off the timer and async_write, which race to cancel each other.
/// Whichever completes first gets to cancel the other one.
/// Effectively, the timer will finish before the write if the write is "timing out",
@ -54,12 +55,12 @@ static void async_write_timeout(ip::tcp::socket& stream, const_buffer&& sequence
// as a copy so that it can call the timeout handler on this object.
// the whole thing is wrapped in a strand to avoid this happening on two separate thwrites at the same time,
// i.e. the timer and write finish at the same time on separate thwrites, or other goofy stuff.
m_timer.async_wait(bind_executor(*m_strand, [self = this->shared_from_this()](auto&& ec) {
m_timer.async_wait(bind_executor(m_strand, [self = this->shared_from_this()](auto&& ec) {
self->handle_timeout(ec);
}));
// start the write on the same strand, again giving a copy of a shared_ptr to ourselves so the handler can be
// called.
boost::asio::async_write(m_stream, m_buffer, bind_executor(*m_strand, [self = this->shared_from_this()](auto&& ec, auto size) {
boost::asio::async_write(m_stream, m_buffer, bind_executor(m_strand, [self = this->shared_from_this()](auto&& ec, auto size) {
self->handle_write(ec, size);
}));
}
@ -89,28 +90,29 @@ static void async_write_timeout(ip::tcp::socket& stream, const_buffer&& sequence
m_handler_fn(ec, size);
}
Client::Ptr m_client;
ip::tcp::socket& m_stream;
const_buffer m_buffer;
HandlerFn m_handler_fn;
Client::StrandPtr m_strand;
boost::asio::strand<ip::tcp::socket::executor_type> m_strand { m_stream.get_executor() };
boost::asio::deadline_timer m_timer { m_stream.get_executor() };
bool m_completed = false;
};
auto helper = std::make_shared<TimeoutHelper>(stream,
auto helper = std::make_shared<TimeoutHelper>(client, client->tcp_socket(),
std::forward<const_buffer>(sequence),
std::forward<HandlerFn>(handler),
strand);
std::forward<HandlerFn>(handler));
helper->start(timeout_ms);
}
/// Boost::asio + strands + timer magic to make reads timeout after some time.
template <typename HandlerFn>
static void async_read_timeout(ip::tcp::socket& stream, mutable_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler) {
static void async_read_timeout(Client::Ptr client, mutable_buffer&& sequence, boost::posix_time::milliseconds timeout_ms, HandlerFn&& handler) {
struct TimeoutHelper : std::enable_shared_from_this<TimeoutHelper> {
/// Given a socket (stream), buffer and a completion handler, constructs a state machine.
TimeoutHelper(ip::tcp::socket& stream, mutable_buffer buffer, HandlerFn handler)
: m_stream(stream)
TimeoutHelper(Client::Ptr client, ip::tcp::socket& stream, mutable_buffer buffer, HandlerFn handler)
: m_client(client)
, m_stream(stream)
, m_buffer(std::move(buffer))
, m_handler_fn(std::move(handler)) {
}
@ -160,6 +162,7 @@ static void async_read_timeout(ip::tcp::socket& stream, mutable_buffer&& sequenc
m_handler_fn(ec, size);
}
Client::Ptr m_client;
ip::tcp::socket& m_stream;
mutable_buffer m_buffer;
HandlerFn m_handler_fn;
@ -168,7 +171,7 @@ static void async_read_timeout(ip::tcp::socket& stream, mutable_buffer&& sequenc
bool m_completed = false;
};
auto helper = std::make_shared<TimeoutHelper>(stream,
auto helper = std::make_shared<TimeoutHelper>(client, client->tcp_socket(),
std::forward<mutable_buffer>(sequence),
std::forward<HandlerFn>(handler));
helper->start(timeout_ms);
@ -184,8 +187,6 @@ void Network::send_to(ClientID id, bmp::Packet& packet) {
void Client::tcp_write(bmp::Packet& packet) {
beammp_tracef("Sending {} to {}", int(packet.purpose), id);
// acquire a lock to avoid writing a header, then being interrupted by another write
std::unique_lock lock(m_tcp_write_mtx);
// finalize the packet (compress etc) and produce header
auto header = packet.finalize();
// data has to be a shared_ptr, because we pass it to the async write function which completes later,
@ -198,13 +199,12 @@ void Client::tcp_write(bmp::Packet& packet) {
beammp_tracef("Packet of size {} B given a timeout of {}ms ({}s)", data->size(), timeout.total_milliseconds(), timeout.seconds());
// write header and packet data
async_write_timeout(
m_tcp_socket, buffer(*data), timeout, [data, this](const boost::system::error_code& ec, size_t) {
shared_from_this(), buffer(*data), timeout, [data, this](const boost::system::error_code& ec, size_t) {
if (ec && ec.value() == boost::system::errc::operation_canceled) {
// write timeout is fatal
m_network.disconnect(id, "Write timeout");
}
},
m_tcp_strand);
});
}
void Client::tcp_write_file_raw(const std::filesystem::path& path) {
@ -232,7 +232,10 @@ void Client::tcp_write_file_raw(const std::filesystem::path& path) {
Client::~Client() {
beammp_debugf("Client {} shutting down", id);
m_tcp_socket.shutdown(boost::asio::socket_base::shutdown_receive);
try {
m_tcp_socket.shutdown(boost::asio::socket_base::shutdown_receive);
} catch (...) {
}
m_tcp_thread.interrupt();
beammp_debugf("Client {} shut down", id);
}
@ -265,7 +268,7 @@ void Client::start_tcp() {
beammp_tracef("{}", __func__);
m_header.resize(bmp::Header::SERIALIZED_SIZE);
beammp_tracef("Header buffer size: {}", m_header.size());
async_read_timeout(m_tcp_socket, buffer(m_header), m_read_timeout, [this](const boost::system::error_code& ec, size_t) {
async_read_timeout(shared_from_this(), buffer(m_header), m_read_timeout, [this](const boost::system::error_code& ec, size_t) {
if (ec && ec.value() == boost::system::errc::operation_canceled) {
beammp_warnf("Client {} possibly timing out", id);
if (handle_timeout()) {
@ -288,7 +291,7 @@ void Client::start_tcp() {
m_packet.flags = hdr.flags;
m_packet.raw_data.resize(hdr.size);
beammp_tracef("Raw data buffer size: {}", m_packet.raw_data.size());
async_read_timeout(m_tcp_socket, buffer(m_packet.raw_data), m_read_timeout, [this](const boost::system::error_code& ec, size_t bytes) {
async_read_timeout(shared_from_this(), buffer(m_packet.raw_data), m_read_timeout, [this](const boost::system::error_code& ec, size_t bytes) {
if (ec && ec.value() == boost::system::errc::operation_canceled) {
beammp_warnf("Client {} possibly timing out after sending header", id);
if (handle_timeout()) {
@ -411,7 +414,7 @@ void Network::handle_accept(const boost::system::error_code& ec) {
} else {
auto new_id = new_client_id();
beammp_debugf("New connection from {}", m_temp_socket.remote_endpoint().address().to_string(), m_temp_socket.remote_endpoint().port());
std::shared_ptr<Client> new_client(std::make_shared<Client>(new_id, *this, std::move(m_temp_socket)));
Client::Ptr new_client = Client::make_ptr(new_id, *this, std::move(m_temp_socket));
m_clients->emplace(new_id, new_client);
new_client->start_tcp();
}
@ -493,27 +496,36 @@ void Network::udp_read_main() {
void Network::disconnect(ClientID id, const std::string& msg) {
// this has to be scheduled, because the thread which did this cannot!!! do it itself.
beammp_debugf("Scheduling disconnect for {}", id);
post(context(), [id, msg, this] {
beammp_infof("Disconnecting client {}: {}", id, msg);
// deadlock-free algorithm to acquire a lock on all these
// this is a little ugly but saves a headache here in the future
auto all = boost::synchronize(m_clients, m_udp_endpoints, m_client_magics);
auto& clients = std::get<0>(all);
auto& endpoints = std::get<1>(all);
auto& magics = std::get<2>(all);
// grab the client ptr
auto clients = m_clients.synchronize();
if (clients->contains(id)) {
auto client = clients->at(id);
post(context(), [client, id, msg, this] {
beammp_infof("Disconnecting client {}: {}", id, msg);
// deadlock-free algorithm to acquire a lock on all these
// this is a little ugly but saves a headache here in the future
auto all = boost::synchronize(m_clients, m_udp_endpoints, m_client_magics);
auto& clients = std::get<0>(all);
auto& endpoints = std::get<1>(all);
auto& magics = std::get<2>(all);
if (clients->contains(id)) {
auto client = clients->at(id);
beammp_debugf("Removing client udp magic {}", client->udp_magic);
magics->erase(client->udp_magic);
}
std::erase_if(*endpoints, [&](const auto& item) {
const auto& [key, value] = item;
return value == id;
std::erase_if(*endpoints, [&](const auto& item) {
const auto& [key, value] = item;
return value == id;
});
// TODO: Despawn vehicles owned by this player
clients->erase(id);
try {
client->tcp_socket().shutdown(boost::asio::socket_base::shutdown_both);
client->tcp_socket().close();
} catch (...) { }
});
// TODO: Despawn vehicles owned by this player
clients->erase(id);
});
} else {
beammp_debugf("Client {} already disconnected", id);
}
}
std::unordered_map<ClientID, Client::Ptr> Network::playing_clients() const {
@ -1234,3 +1246,6 @@ void Vehicle::update_status(std::span<const uint8_t> raw_packet) {
m_status_data.resize(raw_packet.size());
std::copy(raw_packet.begin(), raw_packet.end(), m_status_data.begin());
}
Client::Ptr Client::make_ptr(ClientID new_id, class Network& network, ip::tcp::socket&& tcp_socket) {
return Client::Ptr(new Client(new_id, network, std::forward<ip::tcp::socket&&>(tcp_socket)));
}