From 159d1f6c02569091c7a48bdb2e2e824b844a1902 Mon Sep 17 00:00:00 2001 From: Rich Felker Date: Tue, 13 Dec 2022 18:39:44 -0500 Subject: semaphores: fix missed wakes from ABA bug in waiter count logic because the has-waiters state in the semaphore value futex word is only representable when the value is zero (the special value -1 represents "0 with potential new waiters"), it's lost if intervening operations make the semaphore value positive again. this creates an ABA issue in sem_post, whereby the post uses a stale waiters count rather than re-evaluating it, skipping the futex wake if the stale count was zero. the fix here is based on a proposal by Alexey Izbyshev, with minor changes to eliminate costly new spurious wake syscalls. the basic idea is to replace the special value -1 with a sticky waiters bit (repurposing the sign bit) preserved under both wait and post. any post that takes place with the waiters bit set will perform a futex wake. to be useful, the waiters bit needs to be removable, and to remove it safely, we perform a broadcast wake instead of a normal single-task wake whenever removing the bit. this lets any un-accounted-for waiters wake and re-add the waiters bit if they still need it. there are multiple possible choices for when to perform this broadcast, but the optimal choice seems to be doing it whenever the observed waiters count is less than two (semantically, this means exactly one, but we might see a stale count of zero). in this case, the expected number of threads to be woken is one, with exactly the same cost as a non-broadcast wake. --- src/thread/sem_getvalue.c | 3 ++- src/thread/sem_post.c | 12 ++++++++---- src/thread/sem_timedwait.c | 10 ++++++---- src/thread/sem_trywait.c | 6 +++--- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/thread/sem_getvalue.c b/src/thread/sem_getvalue.c index d9d83071..c0b7762d 100644 --- a/src/thread/sem_getvalue.c +++ b/src/thread/sem_getvalue.c @@ -1,8 +1,9 @@ #include +#include int sem_getvalue(sem_t *restrict sem, int *restrict valp) { int val = sem->__val[0]; - *valp = val < 0 ? 0 : val; + *valp = val & SEM_VALUE_MAX; return 0; } diff --git a/src/thread/sem_post.c b/src/thread/sem_post.c index 31e3293d..5c2517f2 100644 --- a/src/thread/sem_post.c +++ b/src/thread/sem_post.c @@ -1,17 +1,21 @@ #include +#include #include "pthread_impl.h" int sem_post(sem_t *sem) { - int val, waiters, priv = sem->__val[2]; + int val, new, waiters, priv = sem->__val[2]; do { val = sem->__val[0]; waiters = sem->__val[1]; - if (val == SEM_VALUE_MAX) { + if ((val & SEM_VALUE_MAX) == SEM_VALUE_MAX) { errno = EOVERFLOW; return -1; } - } while (a_cas(sem->__val, val, val+1+(val<0)) != val); - if (val<0 || waiters) __wake(sem->__val, 1, priv); + new = val + 1; + if (waiters <= 1) + new &= ~0x80000000; + } while (a_cas(sem->__val, val, new) != val); + if (val<0) __wake(sem->__val, waiters>1 ? 1 : -1, priv); return 0; } diff --git a/src/thread/sem_timedwait.c b/src/thread/sem_timedwait.c index 58d3ebfe..aa67376c 100644 --- a/src/thread/sem_timedwait.c +++ b/src/thread/sem_timedwait.c @@ -1,4 +1,5 @@ #include +#include #include "pthread_impl.h" static void cleanup(void *p) @@ -13,14 +14,15 @@ int sem_timedwait(sem_t *restrict sem, const struct timespec *restrict at) if (!sem_trywait(sem)) return 0; int spins = 100; - while (spins-- && sem->__val[0] <= 0 && !sem->__val[1]) a_spin(); + while (spins-- && !(sem->__val[0] & SEM_VALUE_MAX) && !sem->__val[1]) + a_spin(); while (sem_trywait(sem)) { - int r; + int r, priv = sem->__val[2]; a_inc(sem->__val+1); - a_cas(sem->__val, 0, -1); + a_cas(sem->__val, 0, 0x80000000); pthread_cleanup_push(cleanup, (void *)(sem->__val+1)); - r = __timedwait_cp(sem->__val, -1, CLOCK_REALTIME, at, sem->__val[2]); + r = __timedwait_cp(sem->__val, 0x80000000, CLOCK_REALTIME, at, priv); pthread_cleanup_pop(1); if (r) { errno = r; diff --git a/src/thread/sem_trywait.c b/src/thread/sem_trywait.c index 04edf46b..beb435da 100644 --- a/src/thread/sem_trywait.c +++ b/src/thread/sem_trywait.c @@ -1,12 +1,12 @@ #include +#include #include "pthread_impl.h" int sem_trywait(sem_t *sem) { int val; - while ((val=sem->__val[0]) > 0) { - int new = val-1-(val==1 && sem->__val[1]); - if (a_cas(sem->__val, val, new)==val) return 0; + while ((val=sem->__val[0]) & SEM_VALUE_MAX) { + if (a_cas(sem->__val, val, val-1)==val) return 0; } errno = EAGAIN; return -1; -- cgit v1.2.3-70-g09d2