Switch to poll() to be safe against stack corruption from exceeding FD_SETSIZE

Fortunately, the fd_set definition is not prone to stack corruption on Windows,
because FD_SETSIZE is the maximum number of sockets in a fd_set, not
the maximum value socket file descriptor that can be stored in a fd_set.

https://beesbuzz.biz/code/5739-The-problem-with-select-vs-poll
This commit is contained in:
Cameron Gutman
2020-07-26 19:53:52 -07:00
parent 84f2421fbf
commit 117a7c637f
4 changed files with 136 additions and 82 deletions

View File

@@ -149,7 +149,7 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref
if (err != EWOULDBLOCK && err != EAGAIN) {
Limelog("Failed to start async connect to TCP %u: %d\n", LiGetPortFromPortFlagIndex(i), err);
// Mask off this bit so we don't try to include it in select() below
// Mask off this bit so we don't try to include it in pollSockets() below
testPortFlags &= ~(1 << i);
}
}
@@ -165,7 +165,7 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref
err = (int)LastSocketError();
Limelog("Failed to send test packet to UDP %u: %d\n", LiGetPortFromPortFlagIndex(i), err);
// Mask off this bit so we don't try to include it in select() below
// Mask off this bit so we don't try to include it in pollSockets() below
testPortFlags &= ~(1 << i);
break;
@@ -177,103 +177,105 @@ unsigned int LiTestClientConnectivity(const char* testServer, unsigned short ref
}
}
// Continue to call select() until we have no more sockets to wait for,
// or our select() call times out.
// Continue to call pollSockets() until we have no more sockets to wait for,
// or our pollSockets() call times out.
while (testPortFlags != 0) {
SOCKET nfds;
fd_set readfds, writefds, exceptfds;
struct timeval tv;
int nfds;
struct pollfd pfds[PORT_FLAGS_MAX_COUNT];
nfds = 0;
FD_ZERO(&readfds);
FD_ZERO(&writefds);
FD_ZERO(&exceptfds);
// Fill out our FD sets
for (i = 0; i < PORT_FLAGS_MAX_COUNT; i++) {
if (testPortFlags & (1 << i)) {
pfds[nfds].fd = sockets[i];
if (LiGetProtocolFromPortFlagIndex(i) == IPPROTO_UDP) {
// Watch for readability on UDP sockets
FD_SET(sockets[i], &readfds);
if (sockets[i] + 1 > nfds) {
nfds = sockets[i] + 1;
}
pfds[nfds].events = POLLIN;
}
else {
// Watch for writeability or exceptions on TCP sockets
FD_SET(sockets[i], &writefds);
FD_SET(sockets[i], &exceptfds);
if (sockets[i] + 1 > nfds) {
nfds = sockets[i] + 1;
}
// Watch for writeability on TCP sockets
pfds[nfds].events = POLLOUT;
}
nfds++;
}
}
tv.tv_sec = TEST_PORT_TIMEOUT_SEC;
tv.tv_usec = 0;
// Wait for the to complete or the timeout to elapse.
// NB: The timeout resets each time we get a valid response on a port,
// but that's probably fine.
err = select((int)nfds, &readfds, &writefds, &exceptfds, &tv);
err = pollSockets(pfds, nfds, TEST_PORT_TIMEOUT_SEC * 1000);
if (err < 0) {
// select() failed
// pollSockets() failed
err = LastSocketError();
Limelog("select() failed: %d\n", err);
Limelog("pollSockets() failed: %d\n", err);
failingPortFlags = ML_TEST_RESULT_INCONCLUSIVE;
goto Exit;
}
else if (err == 0) {
// select() timed out
Limelog("select() timed out after %d seconds\n", TEST_PORT_TIMEOUT_SEC);
// pollSockets() timed out
Limelog("Connection timed out after %d seconds\n", TEST_PORT_TIMEOUT_SEC);
break;
}
// We know something was signalled. Now we just need to find out what.
for (i = 0; i < PORT_FLAGS_MAX_COUNT; i++) {
if (testPortFlags & (1 << i)) {
if (FD_ISSET(sockets[i], &writefds) || FD_ISSET(sockets[i], &exceptfds)) {
// A TCP socket was signalled
SOCKADDR_LEN len = sizeof(err);
getsockopt(sockets[i], SOL_SOCKET, SO_ERROR, (char*)&err, &len);
if (err != 0 || FD_ISSET(sockets[i], &exceptfds)) {
// Get the error code
err = (err != 0) ? err : LastSocketFail();
}
for (i = 0; i < nfds; i++) {
if (pfds[i].revents != 0) {
int portIndex;
// The TCP test has completed for this port
testPortFlags &= ~(1 << i);
if (err == 0) {
// The TCP test was a success
failingPortFlags &= ~(1 << i);
Limelog("TCP port %u test successful\n", LiGetPortFromPortFlagIndex(i));
}
else {
Limelog("TCP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(i), err);
// This socket was signalled. Figure out what port it was.
for (portIndex = 0; portIndex < PORT_FLAGS_MAX_COUNT; portIndex++) {
if (sockets[portIndex] == pfds[i].fd) {
LC_ASSERT(testPortFlags & (1 << portIndex));
break;
}
}
else if (FD_ISSET(sockets[i], &readfds)) {
LC_ASSERT(portIndex != PORT_FLAGS_MAX_COUNT);
if (LiGetProtocolFromPortFlagIndex(portIndex) == IPPROTO_UDP) {
char buf[32];
// A UDP socket was signalled. This could be because we got
// a packet from the test server, or it could be because we
// received an ICMP error which will be given to us from
// recvfrom().
testPortFlags &= ~(1 << i);
testPortFlags &= ~(1 << portIndex);
// Check if the socket can be successfully read now
err = recvfrom(sockets[i], buf, sizeof(buf), 0, NULL, NULL);
err = recvfrom(sockets[portIndex], buf, sizeof(buf), 0, NULL, NULL);
if (err >= 0) {
// The UDP test was a success.
failingPortFlags &= ~(1 << i);
failingPortFlags &= ~(1 << portIndex);
Limelog("UDP port %u test successful\n", LiGetPortFromPortFlagIndex(i));
Limelog("UDP port %u test successful\n", LiGetPortFromPortFlagIndex(portIndex));
}
else {
err = LastSocketError();
Limelog("UDP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(i), err);
Limelog("UDP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(portIndex), err);
}
}
else {
// A TCP socket was signalled
SOCKADDR_LEN len = sizeof(err);
getsockopt(sockets[portIndex], SOL_SOCKET, SO_ERROR, (char*)&err, &len);
if (err != 0 || (pfds[i].revents & POLLERR)) {
// Get the error code
err = (err != 0) ? err : LastSocketFail();
}
// The TCP test has completed for this port
testPortFlags &= ~(1 << portIndex);
if (err == 0) {
// The TCP test was a success
failingPortFlags &= ~(1 << portIndex);
Limelog("TCP port %u test successful\n", LiGetPortFromPortFlagIndex(portIndex));
}
else {
Limelog("TCP port %u test failed: %d\n", LiGetPortFromPortFlagIndex(portIndex), err);
}
}
}