From 117a7c637f0d8e20aa2121c4bfdc8b3d8b1ed683 Mon Sep 17 00:00:00 2001 From: Cameron Gutman Date: Sun, 26 Jul 2020 19:53:52 -0700 Subject: [PATCH] Switch to poll() to be safe against stack corruption from exceeding FD_SETSIZE Fortunately, the fd_set definition is not prone to stack corruption on Windows, because FD_SETSIZE is the maximum number of sockets in a fd_set, not the maximum value socket file descriptor that can be stored in a fd_set. https://beesbuzz.biz/code/5739-The-problem-with-select-vs-poll --- enet | 2 +- src/ConnectionTester.c | 110 +++++++++++++++++++++-------------------- src/PlatformSockets.c | 104 ++++++++++++++++++++++++++++---------- src/PlatformSockets.h | 2 + 4 files changed, 136 insertions(+), 82 deletions(-) diff --git a/enet b/enet index 8d794da..757933e 160000 --- a/enet +++ b/enet @@ -1 +1 @@ -Subproject commit 8d794daa7cce21151ee16cd802bc138f63f7953f +Subproject commit 757933e7bc80390edf03d741d1c57e070194f212 diff --git a/src/ConnectionTester.c b/src/ConnectionTester.c index b4e6b01..71bebf6 100644 --- a/src/ConnectionTester.c +++ b/src/ConnectionTester.c @@ -149,7 +149,7 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref if (err != EWOULDBLOCK && err != EAGAIN) { Limelog("Failed to start async connect to TCP %u: %d\n", LiGetPortFromPortFlagIndex(i), err); - // Mask off this bit so we don't try to include it in select() below + // Mask off this bit so we don't try to include it in pollSockets() below testPortFlags &= ~(1 << i); } } @@ -165,7 +165,7 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref err = (int)LastSocketError(); Limelog("Failed to send test packet to UDP %u: %d\n", LiGetPortFromPortFlagIndex(i), err); - // Mask off this bit so we don't try to include it in select() below + // Mask off this bit so we don't try to include it in pollSockets() below testPortFlags &= ~(1 << i); break; @@ -177,103 +177,105 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref } } - // Continue to call select() until we have no more sockets to wait for, - // or our select() call times out. + // Continue to call pollSockets() until we have no more sockets to wait for, + // or our pollSockets() call times out. while (testPortFlags != 0) { - SOCKET nfds; - fd_set readfds, writefds, exceptfds; - struct timeval tv; + int nfds; + struct pollfd pfds[PORT_FLAGS_MAX_COUNT]; nfds = 0; - FD_ZERO(&readfds); - FD_ZERO(&writefds); - FD_ZERO(&exceptfds); // Fill out our FD sets for (i = 0; i < PORT_FLAGS_MAX_COUNT; i++) { if (testPortFlags & (1 << i)) { + pfds[nfds].fd = sockets[i]; + if (LiGetProtocolFromPortFlagIndex(i) == IPPROTO_UDP) { // Watch for readability on UDP sockets - FD_SET(sockets[i], &readfds); - if (sockets[i] + 1 > nfds) { - nfds = sockets[i] + 1; - } + pfds[nfds].events = POLLIN; } else { - // Watch for writeability or exceptions on TCP sockets - FD_SET(sockets[i], &writefds); - FD_SET(sockets[i], &exceptfds); - if (sockets[i] + 1 > nfds) { - nfds = sockets[i] + 1; - } + // Watch for writeability on TCP sockets + pfds[nfds].events = POLLOUT; } + + nfds++; } } - tv.tv_sec = TEST_PORT_TIMEOUT_SEC; - tv.tv_usec = 0; - // Wait for the to complete or the timeout to elapse. // NB: The timeout resets each time we get a valid response on a port, // but that's probably fine. - err = select((int)nfds, &readfds, &writefds, &exceptfds, &tv); + err = pollSockets(pfds, nfds, TEST_PORT_TIMEOUT_SEC * 1000); if (err < 0) { - // select() failed + // pollSockets() failed err = LastSocketError(); - Limelog("select() failed: %d\n", err); + Limelog("pollSockets() failed: %d\n", err); failingPortFlags = ML_TEST_RESULT_INCONCLUSIVE; goto Exit; } else if (err == 0) { - // select() timed out - Limelog("select() timed out after %d seconds\n", TEST_PORT_TIMEOUT_SEC); + // pollSockets() timed out + Limelog("Connection timed out after %d seconds\n", TEST_PORT_TIMEOUT_SEC); break; } // We know something was signalled. Now we just need to find out what. - for (i = 0; i < PORT_FLAGS_MAX_COUNT; i++) { - if (testPortFlags & (1 << i)) { - if (FD_ISSET(sockets[i], &writefds) || FD_ISSET(sockets[i], &exceptfds)) { - // A TCP socket was signalled - SOCKADDR_LEN len = sizeof(err); - getsockopt(sockets[i], SOL_SOCKET, SO_ERROR, (char*)&err, &len); - if (err != 0 || FD_ISSET(sockets[i], &exceptfds)) { - // Get the error code - err = (err != 0) ? err : LastSocketFail(); - } + for (i = 0; i < nfds; i++) { + if (pfds[i].revents != 0) { + int portIndex; - // The TCP test has completed for this port - testPortFlags &= ~(1 << i); - if (err == 0) { - // The TCP test was a success - failingPortFlags &= ~(1 << i); - - Limelog("TCP port %u test successful\n", LiGetPortFromPortFlagIndex(i)); - } - else { - Limelog("TCP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(i), err); + // This socket was signalled. Figure out what port it was. + for (portIndex = 0; portIndex < PORT_FLAGS_MAX_COUNT; portIndex++) { + if (sockets[portIndex] == pfds[i].fd) { + LC_ASSERT(testPortFlags & (1 << portIndex)); + break; } } - else if (FD_ISSET(sockets[i], &readfds)) { + + LC_ASSERT(portIndex != PORT_FLAGS_MAX_COUNT); + + if (LiGetProtocolFromPortFlagIndex(portIndex) == IPPROTO_UDP) { char buf[32]; // A UDP socket was signalled. This could be because we got // a packet from the test server, or it could be because we // received an ICMP error which will be given to us from // recvfrom(). - testPortFlags &= ~(1 << i); + testPortFlags &= ~(1 << portIndex); // Check if the socket can be successfully read now - err = recvfrom(sockets[i], buf, sizeof(buf), 0, NULL, NULL); + err = recvfrom(sockets[portIndex], buf, sizeof(buf), 0, NULL, NULL); if (err >= 0) { // The UDP test was a success. - failingPortFlags &= ~(1 << i); + failingPortFlags &= ~(1 << portIndex); - Limelog("UDP port %u test successful\n", LiGetPortFromPortFlagIndex(i)); + Limelog("UDP port %u test successful\n", LiGetPortFromPortFlagIndex(portIndex)); } else { err = LastSocketError(); - Limelog("UDP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(i), err); + Limelog("UDP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(portIndex), err); + } + } + else { + // A TCP socket was signalled + SOCKADDR_LEN len = sizeof(err); + getsockopt(sockets[portIndex], SOL_SOCKET, SO_ERROR, (char*)&err, &len); + if (err != 0 || (pfds[i].revents & POLLERR)) { + // Get the error code + err = (err != 0) ? err : LastSocketFail(); + } + + // The TCP test has completed for this port + testPortFlags &= ~(1 << portIndex); + if (err == 0) { + // The TCP test was a success + failingPortFlags &= ~(1 << portIndex); + + Limelog("TCP port %u test successful\n", LiGetPortFromPortFlagIndex(portIndex)); + } + else { + Limelog("TCP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(portIndex), err); } } } diff --git a/src/PlatformSockets.c b/src/PlatformSockets.c index 59a352b..8c59196 100644 --- a/src/PlatformSockets.c +++ b/src/PlatformSockets.c @@ -65,20 +65,77 @@ void setRecvTimeout(SOCKET s, int timeoutSec) { } } -int recvUdpSocket(SOCKET s, char* buffer, int size, int useSelect) { - fd_set readfds; - int err; +int pollSockets(struct pollfd* pollFds, int pollFdsCount, int timeoutMs) { +#ifdef LC_WINDOWS + // We could have used WSAPoll() but it has some nasty bugs + // https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ + // + // We'll emulate WSAPoll() with select(). Fortunately, Microsoft's definition + // of fd_set does not have the same stack corruption hazards that UNIX does. + fd_set readFds, writeFds, exceptFds; + int i, err; struct timeval tv; + + FD_ZERO(&readFds); + FD_ZERO(&writeFds); + FD_ZERO(&exceptFds); + + for (i = 0; i < pollFdsCount; i++) { + // Clear revents on input like poll() does + pollFds[i].revents = 0; + + if (pollFds[i].events & POLLIN) { + FD_SET(pollFds[i].fd, &readFds); + } + if (pollFds[i].events & POLLOUT) { + FD_SET(pollFds[i].fd, &writeFds); + + // Windows signals failed connections as an exception, + // while Linux signals them as writeable. + FD_SET(pollFds[i].fd, &exceptFds); + } + } + + tv.tv_sec = timeoutMs / 1000; + tv.tv_usec = (timeoutMs % 1000) * 1000; + + // nfds is unused on Windows + err = select(0, &readFds, &writeFds, &exceptFds, timeoutMs >= 0 ? &tv : NULL); + if (err <= 0) { + // Error or timeout + return err; + } + + for (i = 0; i < pollFdsCount; i++) { + if (FD_ISSET(pollFds[i].fd, &readFds)) { + pollFds[i].revents |= POLLRDNORM; + } + + if (FD_ISSET(pollFds[i].fd, &writeFds)) { + pollFds[i].revents |= POLLWRNORM; + } + + if (FD_ISSET(pollFds[i].fd, &exceptFds)) { + pollFds[i].revents |= POLLERR; + } + } + + return err; +#else + return poll(pollFds, pollFdsCount, timeoutMs); +#endif +} + +int recvUdpSocket(SOCKET s, char* buffer, int size, int useSelect) { + int err; if (useSelect) { - FD_ZERO(&readfds); - FD_SET(s, &readfds); + struct pollfd pfd; // Wait up to 100 ms for the socket to be readable - tv.tv_sec = 0; - tv.tv_usec = UDP_RECV_POLL_TIMEOUT_MS * 1000; - - err = select((int)(s) + 1, &readfds, NULL, NULL, &tv); + pfd.fd = s; + pfd.events = POLLIN; + err = pollSockets(&pfd, 1, UDP_RECV_POLL_TIMEOUT_MS); if (err <= 0) { // Return if an error or timeout occurs return err; @@ -263,39 +320,32 @@ SOCKET connectTcpSocket(struct sockaddr_storage* dstaddr, SOCKADDR_LEN addrlen, } if (nonBlocking) { - fd_set writefds, exceptfds; - struct timeval tv; - - FD_ZERO(&writefds); - FD_ZERO(&exceptfds); - FD_SET(s, &writefds); - FD_SET(s, &exceptfds); - - tv.tv_sec = timeoutSec; - tv.tv_usec = 0; - + struct pollfd pfd; + // Wait for the connection to complete or the timeout to elapse - err = select(s + 1, NULL, &writefds, &exceptfds, &tv); + pfd.fd = s; + pfd.events = POLLOUT; + err = pollSockets(&pfd, 1, timeoutSec * 1000); if (err < 0) { - // select() failed + // pollSockets() failed err = LastSocketError(); - Limelog("select() failed: %d\n", err); + Limelog("pollSockets() failed: %d\n", err); closeSocket(s); SetLastSocketError(err); return INVALID_SOCKET; } else if (err == 0) { - // select() timed out - Limelog("select() timed out after %d seconds\n", timeoutSec); + // pollSockets() timed out + Limelog("Connection timed out after %d seconds (TCP port %u)\n", timeoutSec, port); closeSocket(s); SetLastSocketError(ETIMEDOUT); return INVALID_SOCKET; } - else if (FD_ISSET(s, &writefds) || FD_ISSET(s, &exceptfds)) { + else { // The socket was signalled SOCKADDR_LEN len = sizeof(err); getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len); - if (err != 0 || FD_ISSET(s, &exceptfds)) { + if (err != 0 || (pfd.revents & POLLERR)) { // Get the error code err = (err != 0) ? err : LastSocketFail(); } diff --git a/src/PlatformSockets.h b/src/PlatformSockets.h index c24bcd0..a149512 100644 --- a/src/PlatformSockets.h +++ b/src/PlatformSockets.h @@ -29,6 +29,7 @@ typedef int SOCKADDR_LEN; #include #include #include +#include #define ioctlsocket ioctl #define LastSocketError() errno @@ -59,6 +60,7 @@ int setNonFatalRecvTimeoutMs(SOCKET s, int timeoutMs); void setRecvTimeout(SOCKET s, int timeoutSec); void closeSocket(SOCKET s); int isPrivateNetworkAddress(struct sockaddr_storage* address); +int pollSockets(struct pollfd* pollFds, int pollFdsCount, int timeoutMs); int initializePlatformSockets(void); void cleanupPlatformSockets(void);