mirror of
https://github.com/moonlight-stream/moonlight-common-c.git
synced 2025-08-18 09:25:49 +00:00
Switch to poll() to be safe against stack corruption from exceeding FD_SETSIZE
Fortunately, the fd_set definition is not prone to stack corruption on Windows, because FD_SETSIZE is the maximum number of sockets in a fd_set, not the maximum value socket file descriptor that can be stored in a fd_set. https://beesbuzz.biz/code/5739-The-problem-with-select-vs-poll
This commit is contained in:
parent
84f2421fbf
commit
117a7c637f
2
enet
2
enet
@ -1 +1 @@
|
||||
Subproject commit 8d794daa7cce21151ee16cd802bc138f63f7953f
|
||||
Subproject commit 757933e7bc80390edf03d741d1c57e070194f212
|
@ -149,7 +149,7 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref
|
||||
if (err != EWOULDBLOCK && err != EAGAIN) {
|
||||
Limelog("Failed to start async connect to TCP %u: %d\n", LiGetPortFromPortFlagIndex(i), err);
|
||||
|
||||
// Mask off this bit so we don't try to include it in select() below
|
||||
// Mask off this bit so we don't try to include it in pollSockets() below
|
||||
testPortFlags &= ~(1 << i);
|
||||
}
|
||||
}
|
||||
@ -165,7 +165,7 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref
|
||||
err = (int)LastSocketError();
|
||||
Limelog("Failed to send test packet to UDP %u: %d\n", LiGetPortFromPortFlagIndex(i), err);
|
||||
|
||||
// Mask off this bit so we don't try to include it in select() below
|
||||
// Mask off this bit so we don't try to include it in pollSockets() below
|
||||
testPortFlags &= ~(1 << i);
|
||||
|
||||
break;
|
||||
@ -177,103 +177,105 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref
|
||||
}
|
||||
}
|
||||
|
||||
// Continue to call select() until we have no more sockets to wait for,
|
||||
// or our select() call times out.
|
||||
// Continue to call pollSockets() until we have no more sockets to wait for,
|
||||
// or our pollSockets() call times out.
|
||||
while (testPortFlags != 0) {
|
||||
SOCKET nfds;
|
||||
fd_set readfds, writefds, exceptfds;
|
||||
struct timeval tv;
|
||||
int nfds;
|
||||
struct pollfd pfds[PORT_FLAGS_MAX_COUNT];
|
||||
|
||||
nfds = 0;
|
||||
FD_ZERO(&readfds);
|
||||
FD_ZERO(&writefds);
|
||||
FD_ZERO(&exceptfds);
|
||||
|
||||
// Fill out our FD sets
|
||||
for (i = 0; i < PORT_FLAGS_MAX_COUNT; i++) {
|
||||
if (testPortFlags & (1 << i)) {
|
||||
pfds[nfds].fd = sockets[i];
|
||||
|
||||
if (LiGetProtocolFromPortFlagIndex(i) == IPPROTO_UDP) {
|
||||
// Watch for readability on UDP sockets
|
||||
FD_SET(sockets[i], &readfds);
|
||||
if (sockets[i] + 1 > nfds) {
|
||||
nfds = sockets[i] + 1;
|
||||
}
|
||||
pfds[nfds].events = POLLIN;
|
||||
}
|
||||
else {
|
||||
// Watch for writeability or exceptions on TCP sockets
|
||||
FD_SET(sockets[i], &writefds);
|
||||
FD_SET(sockets[i], &exceptfds);
|
||||
if (sockets[i] + 1 > nfds) {
|
||||
nfds = sockets[i] + 1;
|
||||
}
|
||||
// Watch for writeability on TCP sockets
|
||||
pfds[nfds].events = POLLOUT;
|
||||
}
|
||||
|
||||
nfds++;
|
||||
}
|
||||
}
|
||||
|
||||
tv.tv_sec = TEST_PORT_TIMEOUT_SEC;
|
||||
tv.tv_usec = 0;
|
||||
|
||||
// Wait for the to complete or the timeout to elapse.
|
||||
// NB: The timeout resets each time we get a valid response on a port,
|
||||
// but that's probably fine.
|
||||
err = select((int)nfds, &readfds, &writefds, &exceptfds, &tv);
|
||||
err = pollSockets(pfds, nfds, TEST_PORT_TIMEOUT_SEC * 1000);
|
||||
if (err < 0) {
|
||||
// select() failed
|
||||
// pollSockets() failed
|
||||
err = LastSocketError();
|
||||
Limelog("select() failed: %d\n", err);
|
||||
Limelog("pollSockets() failed: %d\n", err);
|
||||
failingPortFlags = ML_TEST_RESULT_INCONCLUSIVE;
|
||||
goto Exit;
|
||||
}
|
||||
else if (err == 0) {
|
||||
// select() timed out
|
||||
Limelog("select() timed out after %d seconds\n", TEST_PORT_TIMEOUT_SEC);
|
||||
// pollSockets() timed out
|
||||
Limelog("Connection timed out after %d seconds\n", TEST_PORT_TIMEOUT_SEC);
|
||||
break;
|
||||
}
|
||||
|
||||
// We know something was signalled. Now we just need to find out what.
|
||||
for (i = 0; i < PORT_FLAGS_MAX_COUNT; i++) {
|
||||
if (testPortFlags & (1 << i)) {
|
||||
if (FD_ISSET(sockets[i], &writefds) || FD_ISSET(sockets[i], &exceptfds)) {
|
||||
// A TCP socket was signalled
|
||||
SOCKADDR_LEN len = sizeof(err);
|
||||
getsockopt(sockets[i], SOL_SOCKET, SO_ERROR, (char*)&err, &len);
|
||||
if (err != 0 || FD_ISSET(sockets[i], &exceptfds)) {
|
||||
// Get the error code
|
||||
err = (err != 0) ? err : LastSocketFail();
|
||||
}
|
||||
for (i = 0; i < nfds; i++) {
|
||||
if (pfds[i].revents != 0) {
|
||||
int portIndex;
|
||||
|
||||
// The TCP test has completed for this port
|
||||
testPortFlags &= ~(1 << i);
|
||||
if (err == 0) {
|
||||
// The TCP test was a success
|
||||
failingPortFlags &= ~(1 << i);
|
||||
|
||||
Limelog("TCP port %u test successful\n", LiGetPortFromPortFlagIndex(i));
|
||||
}
|
||||
else {
|
||||
Limelog("TCP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(i), err);
|
||||
// This socket was signalled. Figure out what port it was.
|
||||
for (portIndex = 0; portIndex < PORT_FLAGS_MAX_COUNT; portIndex++) {
|
||||
if (sockets[portIndex] == pfds[i].fd) {
|
||||
LC_ASSERT(testPortFlags & (1 << portIndex));
|
||||
break;
|
||||
}
|
||||
}
|
||||
else if (FD_ISSET(sockets[i], &readfds)) {
|
||||
|
||||
LC_ASSERT(portIndex != PORT_FLAGS_MAX_COUNT);
|
||||
|
||||
if (LiGetProtocolFromPortFlagIndex(portIndex) == IPPROTO_UDP) {
|
||||
char buf[32];
|
||||
|
||||
// A UDP socket was signalled. This could be because we got
|
||||
// a packet from the test server, or it could be because we
|
||||
// received an ICMP error which will be given to us from
|
||||
// recvfrom().
|
||||
testPortFlags &= ~(1 << i);
|
||||
testPortFlags &= ~(1 << portIndex);
|
||||
|
||||
// Check if the socket can be successfully read now
|
||||
err = recvfrom(sockets[i], buf, sizeof(buf), 0, NULL, NULL);
|
||||
err = recvfrom(sockets[portIndex], buf, sizeof(buf), 0, NULL, NULL);
|
||||
if (err >= 0) {
|
||||
// The UDP test was a success.
|
||||
failingPortFlags &= ~(1 << i);
|
||||
failingPortFlags &= ~(1 << portIndex);
|
||||
|
||||
Limelog("UDP port %u test successful\n", LiGetPortFromPortFlagIndex(i));
|
||||
Limelog("UDP port %u test successful\n", LiGetPortFromPortFlagIndex(portIndex));
|
||||
}
|
||||
else {
|
||||
err = LastSocketError();
|
||||
Limelog("UDP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(i), err);
|
||||
Limelog("UDP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(portIndex), err);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// A TCP socket was signalled
|
||||
SOCKADDR_LEN len = sizeof(err);
|
||||
getsockopt(sockets[portIndex], SOL_SOCKET, SO_ERROR, (char*)&err, &len);
|
||||
if (err != 0 || (pfds[i].revents & POLLERR)) {
|
||||
// Get the error code
|
||||
err = (err != 0) ? err : LastSocketFail();
|
||||
}
|
||||
|
||||
// The TCP test has completed for this port
|
||||
testPortFlags &= ~(1 << portIndex);
|
||||
if (err == 0) {
|
||||
// The TCP test was a success
|
||||
failingPortFlags &= ~(1 << portIndex);
|
||||
|
||||
Limelog("TCP port %u test successful\n", LiGetPortFromPortFlagIndex(portIndex));
|
||||
}
|
||||
else {
|
||||
Limelog("TCP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(portIndex), err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -65,20 +65,77 @@ void setRecvTimeout(SOCKET s, int timeoutSec) {
|
||||
}
|
||||
}
|
||||
|
||||
int recvUdpSocket(SOCKET s, char* buffer, int size, int useSelect) {
|
||||
fd_set readfds;
|
||||
int err;
|
||||
int pollSockets(struct pollfd* pollFds, int pollFdsCount, int timeoutMs) {
|
||||
#ifdef LC_WINDOWS
|
||||
// We could have used WSAPoll() but it has some nasty bugs
|
||||
// https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/
|
||||
//
|
||||
// We'll emulate WSAPoll() with select(). Fortunately, Microsoft's definition
|
||||
// of fd_set does not have the same stack corruption hazards that UNIX does.
|
||||
fd_set readFds, writeFds, exceptFds;
|
||||
int i, err;
|
||||
struct timeval tv;
|
||||
|
||||
FD_ZERO(&readFds);
|
||||
FD_ZERO(&writeFds);
|
||||
FD_ZERO(&exceptFds);
|
||||
|
||||
for (i = 0; i < pollFdsCount; i++) {
|
||||
// Clear revents on input like poll() does
|
||||
pollFds[i].revents = 0;
|
||||
|
||||
if (pollFds[i].events & POLLIN) {
|
||||
FD_SET(pollFds[i].fd, &readFds);
|
||||
}
|
||||
if (pollFds[i].events & POLLOUT) {
|
||||
FD_SET(pollFds[i].fd, &writeFds);
|
||||
|
||||
// Windows signals failed connections as an exception,
|
||||
// while Linux signals them as writeable.
|
||||
FD_SET(pollFds[i].fd, &exceptFds);
|
||||
}
|
||||
}
|
||||
|
||||
tv.tv_sec = timeoutMs / 1000;
|
||||
tv.tv_usec = (timeoutMs % 1000) * 1000;
|
||||
|
||||
// nfds is unused on Windows
|
||||
err = select(0, &readFds, &writeFds, &exceptFds, timeoutMs >= 0 ? &tv : NULL);
|
||||
if (err <= 0) {
|
||||
// Error or timeout
|
||||
return err;
|
||||
}
|
||||
|
||||
for (i = 0; i < pollFdsCount; i++) {
|
||||
if (FD_ISSET(pollFds[i].fd, &readFds)) {
|
||||
pollFds[i].revents |= POLLRDNORM;
|
||||
}
|
||||
|
||||
if (FD_ISSET(pollFds[i].fd, &writeFds)) {
|
||||
pollFds[i].revents |= POLLWRNORM;
|
||||
}
|
||||
|
||||
if (FD_ISSET(pollFds[i].fd, &exceptFds)) {
|
||||
pollFds[i].revents |= POLLERR;
|
||||
}
|
||||
}
|
||||
|
||||
return err;
|
||||
#else
|
||||
return poll(pollFds, pollFdsCount, timeoutMs);
|
||||
#endif
|
||||
}
|
||||
|
||||
int recvUdpSocket(SOCKET s, char* buffer, int size, int useSelect) {
|
||||
int err;
|
||||
|
||||
if (useSelect) {
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(s, &readfds);
|
||||
struct pollfd pfd;
|
||||
|
||||
// Wait up to 100 ms for the socket to be readable
|
||||
tv.tv_sec = 0;
|
||||
tv.tv_usec = UDP_RECV_POLL_TIMEOUT_MS * 1000;
|
||||
|
||||
err = select((int)(s) + 1, &readfds, NULL, NULL, &tv);
|
||||
pfd.fd = s;
|
||||
pfd.events = POLLIN;
|
||||
err = pollSockets(&pfd, 1, UDP_RECV_POLL_TIMEOUT_MS);
|
||||
if (err <= 0) {
|
||||
// Return if an error or timeout occurs
|
||||
return err;
|
||||
@ -263,39 +320,32 @@ SOCKET connectTcpSocket(struct sockaddr_storage* dstaddr, SOCKADDR_LEN addrlen,
|
||||
}
|
||||
|
||||
if (nonBlocking) {
|
||||
fd_set writefds, exceptfds;
|
||||
struct timeval tv;
|
||||
|
||||
FD_ZERO(&writefds);
|
||||
FD_ZERO(&exceptfds);
|
||||
FD_SET(s, &writefds);
|
||||
FD_SET(s, &exceptfds);
|
||||
|
||||
tv.tv_sec = timeoutSec;
|
||||
tv.tv_usec = 0;
|
||||
struct pollfd pfd;
|
||||
|
||||
// Wait for the connection to complete or the timeout to elapse
|
||||
err = select(s + 1, NULL, &writefds, &exceptfds, &tv);
|
||||
pfd.fd = s;
|
||||
pfd.events = POLLOUT;
|
||||
err = pollSockets(&pfd, 1, timeoutSec * 1000);
|
||||
if (err < 0) {
|
||||
// select() failed
|
||||
// pollSockets() failed
|
||||
err = LastSocketError();
|
||||
Limelog("select() failed: %d\n", err);
|
||||
Limelog("pollSockets() failed: %d\n", err);
|
||||
closeSocket(s);
|
||||
SetLastSocketError(err);
|
||||
return INVALID_SOCKET;
|
||||
}
|
||||
else if (err == 0) {
|
||||
// select() timed out
|
||||
Limelog("select() timed out after %d seconds\n", timeoutSec);
|
||||
// pollSockets() timed out
|
||||
Limelog("Connection timed out after %d seconds (TCP port %u)\n", timeoutSec, port);
|
||||
closeSocket(s);
|
||||
SetLastSocketError(ETIMEDOUT);
|
||||
return INVALID_SOCKET;
|
||||
}
|
||||
else if (FD_ISSET(s, &writefds) || FD_ISSET(s, &exceptfds)) {
|
||||
else {
|
||||
// The socket was signalled
|
||||
SOCKADDR_LEN len = sizeof(err);
|
||||
getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
|
||||
if (err != 0 || FD_ISSET(s, &exceptfds)) {
|
||||
if (err != 0 || (pfd.revents & POLLERR)) {
|
||||
// Get the error code
|
||||
err = (err != 0) ? err : LastSocketFail();
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ typedef int SOCKADDR_LEN;
|
||||
#include <netdb.h>
|
||||
#include <errno.h>
|
||||
#include <signal.h>
|
||||
#include <poll.h>
|
||||
|
||||
#define ioctlsocket ioctl
|
||||
#define LastSocketError() errno
|
||||
@ -59,6 +60,7 @@ int setNonFatalRecvTimeoutMs(SOCKET s, int timeoutMs);
|
||||
void setRecvTimeout(SOCKET s, int timeoutSec);
|
||||
void closeSocket(SOCKET s);
|
||||
int isPrivateNetworkAddress(struct sockaddr_storage* address);
|
||||
int pollSockets(struct pollfd* pollFds, int pollFdsCount, int timeoutMs);
|
||||
|
||||
int initializePlatformSockets(void);
|
||||
void cleanupPlatformSockets(void);
|
||||
|
Loading…
x
Reference in New Issue
Block a user