From c36efb0c451b3cc55d6ebf634af451f9d0b542b9 Mon Sep 17 00:00:00 2001 From: Nick Mathewson Date: Tue, 12 May 2009 16:17:32 -0400 Subject: [PATCH] Use a mutex to protect the count of open sockets. This matters because a cpuworker can close its socket when it finishes. Cpuworker typically runs in another thread, so without a lock here, we can have a race condition and get confused about how many sockets are open. Possible fix for bug 939. --- ChangeLog | 3 +++ src/common/compat.c | 52 ++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/ChangeLog b/ChangeLog index 9359915c1e..c4608212ed 100644 --- a/ChangeLog +++ b/ChangeLog @@ -5,6 +5,9 @@ Changes in version 0.2.1.15??? - ????-??-?? Bugfix on 0.2.0.9-alpha. - Provide a more useful log message if bug 977 (related to buffer freelists) ever reappears, and do not crash right away. + - Protect the count of open sockets with a mutex, so we can't + corrupt it when two threads are closing or opening sockets at once. + Fix for bug 939. Bugfix on 0.2.0.1-alpha. Changes in version 0.2.1.14-rc - 2009-04-12 diff --git a/src/common/compat.c b/src/common/compat.c index 82957722c9..51794c762c 100644 --- a/src/common/compat.c +++ b/src/common/compat.c @@ -676,6 +676,23 @@ static int max_socket = -1; * eventdns and libevent.) */ static int n_sockets_open = 0; +/** Mutex to protect open_sockets, max_socket, and n_sockets_open. */ +static tor_mutex_t *socket_accounting_mutex = NULL; + +static INLINE void +socket_accounting_lock(void) +{ + if (PREDICT_UNLIKELY(!socket_accounting_mutex)) + socket_accounting_mutex = tor_mutex_new(); + tor_mutex_acquire(socket_accounting_mutex); +} + +static INLINE void +socket_accounting_unlock(void) +{ + tor_mutex_release(socket_accounting_mutex); +} + /** As close(), but guaranteed to work for sockets across platforms (including * Windows, where close()ing a socket doesn't work. Returns 0 on success, -1 * on failure. */ @@ -683,15 +700,7 @@ int tor_close_socket(int s) { int r = 0; -#ifdef DEBUG_SOCKET_COUNTING - if (s > max_socket || ! bitarray_is_set(open_sockets, s)) { - log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_" - "socket(), or that was already closed or something.", s); - } else { - tor_assert(open_sockets && s <= max_socket); - bitarray_clear(open_sockets, s); - } -#endif + /* On Windows, you have to call close() on fds returned by open(), * and closesocket() on fds returned by socket(). On Unix, everything * gets close()'d. We abstract this difference by always using @@ -703,6 +712,17 @@ tor_close_socket(int s) #else r = close(s); #endif + + socket_accounting_lock(); +#ifdef DEBUG_SOCKET_COUNTING + if (s > max_socket || ! bitarray_is_set(open_sockets, s)) { + log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_" + "socket(), or that was already closed or something.", s); + } else { + tor_assert(open_sockets && s <= max_socket); + bitarray_clear(open_sockets, s); + } +#endif if (r == 0) { --n_sockets_open; } else { @@ -717,9 +737,11 @@ tor_close_socket(int s) #endif r = -1; } + if (n_sockets_open < 0) log_warn(LD_BUG, "Our socket count is below zero: %d. Please submit a " "bug report.", n_sockets_open); + socket_accounting_unlock(); return r; } @@ -754,8 +776,10 @@ tor_open_socket(int domain, int type, int protocol) { int s = socket(domain, type, protocol); if (s >= 0) { + socket_accounting_lock(); ++n_sockets_open; mark_socket_open(s); + socket_accounting_unlock(); } return s; } @@ -766,8 +790,10 @@ tor_accept_socket(int sockfd, struct sockaddr *addr, socklen_t *len) { int s = accept(sockfd, addr, len); if (s >= 0) { + socket_accounting_lock(); ++n_sockets_open; mark_socket_open(s); + socket_accounting_unlock(); } return s; } @@ -776,7 +802,11 @@ tor_accept_socket(int sockfd, struct sockaddr *addr, socklen_t *len) int get_n_open_sockets(void) { - return n_sockets_open; + int n; + socket_accounting_lock(); + n = n_sockets_open; + socket_accounting_unlock(); + return n; } /** Turn socket into a nonblocking socket. @@ -817,6 +847,7 @@ tor_socketpair(int family, int type, int protocol, int fd[2]) int r; r = socketpair(family, type, protocol, fd); if (r == 0) { + socket_accounting_lock(); if (fd[0] >= 0) { ++n_sockets_open; mark_socket_open(fd[0]); @@ -825,6 +856,7 @@ tor_socketpair(int family, int type, int protocol, int fd[2]) ++n_sockets_open; mark_socket_open(fd[1]); } + socket_accounting_unlock(); } return r < 0 ? -errno : r; #else -- 2.11.4.GIT