#include "internal/quic_reactor.h"
#include "internal/common.h"
#include "internal/thread_arch.h"
#include <assert.h>
#include <limits.h>
#if defined(OPENSSL_SYS_WINDOWS)
#include <winsock2.h>
#include <mstcpip.h>
#include <mswsock.h>
#endif
static void rtor_notify_other_threads(QUIC_REACTOR *rtor);
/*
 * Initialise a reactor.
 *
 *   rtor:                  reactor to initialise; every field read elsewhere
 *                          in this file is given a defined value here.
 *   tick_cb, tick_cb_arg:  callback (and its opaque argument) invoked by
 *                          ossl_quic_reactor_tick().
 *   mutex:                 stored in the reactor; it is released around the
 *                          OS poll call and used for condition variable
 *                          waits.
 *   initial_tick_deadline: initial value reported by
 *                          ossl_quic_reactor_get_tick_deadline().
 *   flags:                 if QUIC_REACTOR_FLAG_USE_NOTIFIER is set, a
 *                          notifier and its condition variable are created.
 *
 * Returns 1 on success and 0 on failure. On failure the reactor is left with
 * have_notifier == 0, so ossl_quic_reactor_cleanup() has nothing to release.
 */
int ossl_quic_reactor_init(QUIC_REACTOR *rtor,
    void (*tick_cb)(QUIC_TICK_RESULT *res, void *arg,
        uint32_t flags),
    void *tick_cb_arg,
    CRYPTO_MUTEX *mutex,
    OSSL_TIME initial_tick_deadline,
    uint64_t flags)
{
    rtor->poll_r.type = BIO_POLL_DESCRIPTOR_TYPE_NONE;
    rtor->poll_w.type = BIO_POLL_DESCRIPTOR_TYPE_NONE;
    rtor->net_read_desired = 0;
    rtor->net_write_desired = 0;
    rtor->can_poll_r = 0;
    rtor->can_poll_w = 0;
    rtor->tick_deadline = initial_tick_deadline;

    rtor->tick_cb = tick_cb;
    rtor->tick_cb_arg = tick_cb_arg;
    rtor->mutex = mutex;

    rtor->cur_blocking_waiters = 0;

    /*
     * Put all notifier state into a known "no notifier" condition before
     * anything can fail. signalled_notifier is read by the notification and
     * blocking-section code, and have_notifier is read by cleanup, so neither
     * may be left indeterminate even when the caller did not zero the
     * structure or when initialisation fails part-way.
     */
    rtor->have_notifier = 0;
    rtor->signalled_notifier = 0;
    rtor->notifier_cv = NULL;

    if ((flags & QUIC_REACTOR_FLAG_USE_NOTIFIER) == 0)
        return 1;

    if (!ossl_rio_notifier_init(&rtor->notifier))
        return 0;

    if ((rtor->notifier_cv = ossl_crypto_condvar_new()) == NULL) {
        /* Undo the notifier so that nothing is leaked on failure. */
        ossl_rio_notifier_cleanup(&rtor->notifier);
        return 0;
    }

    rtor->have_notifier = 1;
    return 1;
}
/*
 * Release the resources owned by a reactor. A NULL reactor, or one without a
 * notifier, is a no-op. Otherwise the notifier is cleaned up, have_notifier
 * is cleared and the notifier condition variable is freed.
 */
void ossl_quic_reactor_cleanup(QUIC_REACTOR *rtor)
{
    if (rtor == NULL || !rtor->have_notifier)
        return;

    ossl_rio_notifier_cleanup(&rtor->notifier);
    rtor->have_notifier = 0;

    ossl_crypto_condvar_free(&rtor->notifier_cv);
}
#if defined(OPENSSL_SYS_WINDOWS)
/* Some MinGW headers lack SIO_UDP_NETRESET; supply the definition if absent. */
#if defined(__MINGW32__) && !defined(SIO_UDP_NETRESET)
#define SIO_UDP_NETRESET _WSAIOW(IOC_VENDOR, 15)
#endif
/*
 * For a socket poll descriptor, issue the SIO_UDP_CONNRESET and
 * SIO_UDP_NETRESET socket ioctls with a value of FALSE. Descriptors of any
 * other type are left alone. The ioctl results are deliberately ignored, so
 * this is a best-effort adjustment.
 */
static void rtor_configure_winsock(BIO_POLL_DESCRIPTOR *bpd)
{
    BOOL report_resets = FALSE;
    DWORD bytes_returned = 0;

    if (bpd->type != BIO_POLL_DESCRIPTOR_TYPE_SOCK_FD)
        return;

    WSAIoctl(bpd->value.fd, SIO_UDP_CONNRESET, &report_resets,
        sizeof(report_resets), NULL, 0, &bytes_returned, NULL, NULL);
    WSAIoctl(bpd->value.fd, SIO_UDP_NETRESET, &report_resets,
        sizeof(report_resets), NULL, 0, &bytes_returned, NULL, NULL);
}
#endif
/*
 * Set the read-side poll descriptor. The descriptor is copied into the
 * reactor; passing NULL resets the stored descriptor's type to
 * BIO_POLL_DESCRIPTOR_TYPE_NONE. can_poll_r is then recomputed for the
 * stored descriptor.
 */
void ossl_quic_reactor_set_poll_r(QUIC_REACTOR *rtor, const BIO_POLL_DESCRIPTOR *r)
{
    BIO_POLL_DESCRIPTOR *slot = &rtor->poll_r;

    if (r != NULL)
        *slot = *r;
    else
        slot->type = BIO_POLL_DESCRIPTOR_TYPE_NONE;

#if defined(OPENSSL_SYS_WINDOWS)
    rtor_configure_winsock(slot);
#endif

    rtor->can_poll_r
        = ossl_quic_reactor_can_support_poll_descriptor(rtor, slot);
}
/*
 * Set the write-side poll descriptor. The descriptor is copied into the
 * reactor; passing NULL resets the stored descriptor's type to
 * BIO_POLL_DESCRIPTOR_TYPE_NONE. can_poll_w is then recomputed for the
 * stored descriptor.
 */
void ossl_quic_reactor_set_poll_w(QUIC_REACTOR *rtor, const BIO_POLL_DESCRIPTOR *w)
{
    BIO_POLL_DESCRIPTOR *slot = &rtor->poll_w;

    if (w != NULL)
        *slot = *w;
    else
        slot->type = BIO_POLL_DESCRIPTOR_TYPE_NONE;

#if defined(OPENSSL_SYS_WINDOWS)
    rtor_configure_winsock(slot);
#endif

    rtor->can_poll_w
        = ossl_quic_reactor_can_support_poll_descriptor(rtor, slot);
}
/*
 * Returns a pointer to the read-side poll descriptor stored inside the
 * reactor. The pointer refers to reactor-owned storage and is never NULL.
 */
const BIO_POLL_DESCRIPTOR *ossl_quic_reactor_get_poll_r(const QUIC_REACTOR *rtor)
{
return &rtor->poll_r;
}
/*
 * Returns a pointer to the write-side poll descriptor stored inside the
 * reactor. The pointer refers to reactor-owned storage and is never NULL.
 */
const BIO_POLL_DESCRIPTOR *ossl_quic_reactor_get_poll_w(const QUIC_REACTOR *rtor)
{
return &rtor->poll_w;
}
/*
 * Returns 1 if the reactor is able to poll the given descriptor, else 0.
 * Only descriptors of type BIO_POLL_DESCRIPTOR_TYPE_SOCK_FD are supported;
 * the rtor argument is currently not consulted.
 */
int ossl_quic_reactor_can_support_poll_descriptor(const QUIC_REACTOR *rtor,
const BIO_POLL_DESCRIPTOR *d)
{
return d->type == BIO_POLL_DESCRIPTOR_TYPE_SOCK_FD;
}
/*
 * Returns the cached can_poll_r flag, which is computed whenever the
 * read-side poll descriptor is set.
 */
int ossl_quic_reactor_can_poll_r(const QUIC_REACTOR *rtor)
{
return rtor->can_poll_r;
}
/*
 * Returns the cached can_poll_w flag, which is computed whenever the
 * write-side poll descriptor is set.
 */
int ossl_quic_reactor_can_poll_w(const QUIC_REACTOR *rtor)
{
return rtor->can_poll_w;
}
/*
 * Returns the net_read_desired value recorded from the most recent tick
 * result (0 if no tick has happened since initialisation).
 */
int ossl_quic_reactor_net_read_desired(QUIC_REACTOR *rtor)
{
return rtor->net_read_desired;
}
/*
 * Returns the net_write_desired value recorded from the most recent tick
 * result (0 if no tick has happened since initialisation).
 */
int ossl_quic_reactor_net_write_desired(QUIC_REACTOR *rtor)
{
return rtor->net_write_desired;
}
/*
 * Returns the current tick deadline: the value recorded from the most recent
 * tick result, or the initial deadline given at initialisation if no tick has
 * happened yet.
 */
OSSL_TIME ossl_quic_reactor_get_tick_deadline(QUIC_REACTOR *rtor)
{
return rtor->tick_deadline;
}
/*
 * Run one tick: invoke the tick callback with a zeroed result structure and
 * the given flags, then record the network read/write wishes and the next
 * tick deadline it reported. If the callback asks for other threads to be
 * notified, do so after the reactor state has been updated.
 *
 * The callback has no failure path, so this function always returns 1.
 */
int ossl_quic_reactor_tick(QUIC_REACTOR *rtor, uint32_t flags)
{
    QUIC_TICK_RESULT outcome = { 0 };

    rtor->tick_cb(&outcome, rtor->tick_cb_arg, flags);

    /* Publish the callback's results on the reactor. */
    rtor->net_read_desired = outcome.net_read_desired;
    rtor->net_write_desired = outcome.net_write_desired;
    rtor->tick_deadline = outcome.tick_deadline;

    if (outcome.notify_other_threads)
        rtor_notify_other_threads(rtor);

    return 1;
}
/*
 * Returns a pointer to the reactor's embedded notifier, or NULL if the
 * reactor has no notifier. The pointer refers to reactor-owned storage.
 */
RIO_NOTIFIER *ossl_quic_reactor_get0_notifier(QUIC_REACTOR *rtor)
{
    if (!rtor->have_notifier)
        return NULL;

    return &rtor->notifier;
}
/*
 * Block until one of the following happens:
 *
 *   - rfd becomes readable (only if rfd_want_read is nonzero);
 *   - wfd becomes writable (only if wfd_want_write is nonzero);
 *   - notify_rfd becomes readable (if it is a valid descriptor);
 *   - the deadline passes (never, if the deadline is infinite).
 *
 * rfd and wfd may be the same descriptor, and any of the three descriptors may
 * be INVALID_SOCKET, meaning "not present". If mutex is non-NULL (and threads
 * are enabled) it is released for the duration of the OS call and re-acquired
 * before returning. Calls interrupted by EINTR are retried with a freshly
 * computed timeout.
 *
 * Returns 1 if the wait finished without an OS error (this includes a
 * timeout) and 0 on failure, including a request that could never wake up.
 */
static int poll_two_fds(int rfd, int rfd_want_read,
    int wfd, int wfd_want_write,
    int notify_rfd,
    OSSL_TIME deadline,
    CRYPTO_MUTEX *mutex)
{
#if defined(OPENSSL_SYS_WINDOWS) || !defined(POLLIN)
    fd_set rfd_set, wfd_set, efd_set;
    OSSL_TIME now, timeout;
    struct timeval tv, *ptv;
    int maxfd, pres;

#ifndef OPENSSL_SYS_WINDOWS
    /*
     * Outside Windows an fd_set is a fixed-size bitmap, so every descriptor
     * placed into a set must be below FD_SETSIZE. That includes the notifier
     * descriptor, which is added to the read and error sets below.
     */
    if (rfd >= FD_SETSIZE || wfd >= FD_SETSIZE || notify_rfd >= FD_SETSIZE)
        return 0;
#endif

    FD_ZERO(&rfd_set);
    FD_ZERO(&wfd_set);
    FD_ZERO(&efd_set);

    if (rfd != INVALID_SOCKET && rfd_want_read)
        openssl_fdset(rfd, &rfd_set);
    if (wfd != INVALID_SOCKET && wfd_want_write)
        openssl_fdset(wfd, &wfd_set);

    /* Error conditions are watched for regardless of the want flags. */
    if (rfd != INVALID_SOCKET)
        openssl_fdset(rfd, &efd_set);
    if (wfd != INVALID_SOCKET)
        openssl_fdset(wfd, &efd_set);

    /* Wake up when the notifier descriptor becomes readable. */
    if (notify_rfd != INVALID_SOCKET) {
        openssl_fdset(notify_rfd, &rfd_set);
        openssl_fdset(notify_rfd, &efd_set);
    }

    maxfd = rfd;
    if (wfd > maxfd)
        maxfd = wfd;
    if (notify_rfd > maxfd)
        maxfd = notify_rfd;

    /* Refuse to block forever with no socket to wait on. */
    if (!ossl_assert(rfd != INVALID_SOCKET || wfd != INVALID_SOCKET
            || !ossl_time_is_infinite(deadline)))
        return 0;

    /* Do not hold the mutex while blocked in the OS. */
#if defined(OPENSSL_THREADS)
    if (mutex != NULL)
        ossl_crypto_mutex_unlock(mutex);
#endif

    do {
        /*
         * select() takes a relative timeout rather than a deadline. Recompute
         * it on every iteration so that a retry after EINTR does not extend
         * the total wait.
         */
        if (ossl_time_is_infinite(deadline)) {
            ptv = NULL;
        } else {
            now = ossl_time_now();
            timeout = ossl_time_subtract(deadline, now);
            tv = ossl_time_to_timeval(timeout);
            ptv = &tv;
        }

        pres = select(maxfd + 1, &rfd_set, &wfd_set, &efd_set, ptv);
    } while (pres == -1 && get_last_socket_error_is_eintr());

#if defined(OPENSSL_THREADS)
    if (mutex != NULL)
        ossl_crypto_mutex_lock(mutex);
#endif

    return pres < 0 ? 0 : 1;
#else
    int pres, timeout_ms;
    uint64_t remaining_ms;
    OSSL_TIME now, timeout;
    struct pollfd pfds[3] = { 0 };
    size_t npfd = 0;

    /*
     * Build the pollfd array. An entry is only kept (npfd advanced) when its
     * descriptor is valid and it has at least one event of interest; an
     * entry that is not kept is simply overwritten by the next one.
     */
    if (rfd == wfd) {
        pfds[npfd].fd = rfd;
        pfds[npfd].events = (rfd_want_read ? POLLIN : 0)
            | (wfd_want_write ? POLLOUT : 0);
        if (rfd >= 0 && pfds[npfd].events != 0)
            ++npfd;
    } else {
        pfds[npfd].fd = rfd;
        pfds[npfd].events = (rfd_want_read ? POLLIN : 0);
        if (rfd >= 0 && pfds[npfd].events != 0)
            ++npfd;

        pfds[npfd].fd = wfd;
        pfds[npfd].events = (wfd_want_write ? POLLOUT : 0);
        if (wfd >= 0 && pfds[npfd].events != 0)
            ++npfd;
    }

    if (notify_rfd >= 0) {
        pfds[npfd].fd = notify_rfd;
        pfds[npfd].events = POLLIN;
        ++npfd;
    }

    /* Refuse to block forever with nothing to wait on. */
    if (!ossl_assert(npfd != 0 || !ossl_time_is_infinite(deadline)))
        return 0;

    /* Do not hold the mutex while blocked in the OS. */
#if defined(OPENSSL_THREADS)
    if (mutex != NULL)
        ossl_crypto_mutex_unlock(mutex);
#endif

    do {
        /*
         * poll() takes a relative timeout in milliseconds as an int, with a
         * negative value meaning "wait forever". Recompute it on every
         * iteration so that a retry after EINTR does not extend the wait.
         */
        if (ossl_time_is_infinite(deadline)) {
            timeout_ms = -1;
        } else {
            now = ossl_time_now();
            timeout = ossl_time_subtract(deadline, now);
            remaining_ms = ossl_time2ms(timeout);

            /*
             * Clamp rather than truncate: a distant deadline must not wrap
             * into a short or negative (infinite) timeout. Waking early is
             * harmless because a timeout is reported as success and the
             * caller re-evaluates its deadline.
             */
            timeout_ms = remaining_ms > (uint64_t)INT_MAX
                ? INT_MAX
                : (int)remaining_ms;
        }

        pres = poll(pfds, npfd, timeout_ms);
    } while (pres == -1 && get_last_socket_error_is_eintr());

#if defined(OPENSSL_THREADS)
    if (mutex != NULL)
        ossl_crypto_mutex_lock(mutex);
#endif

    return pres < 0 ? 0 : 1;
#endif
}
/*
 * Translate a poll descriptor into a socket descriptor.
 *
 * A NULL descriptor, or one of type BIO_POLL_DESCRIPTOR_TYPE_NONE, yields
 * INVALID_SOCKET in *fd and counts as success. A socket descriptor holding a
 * valid fd yields that fd. Anything else (an unsupported type, or a socket
 * descriptor holding INVALID_SOCKET) is a failure, and *fd is not written.
 *
 * Returns 1 on success, 0 on failure.
 */
static int poll_descriptor_to_fd(const BIO_POLL_DESCRIPTOR *d, int *fd)
{
    int is_absent = (d == NULL || d->type == BIO_POLL_DESCRIPTOR_TYPE_NONE);

    if (is_absent) {
        *fd = INVALID_SOCKET;
        return 1;
    }

    if (d->type == BIO_POLL_DESCRIPTOR_TYPE_SOCK_FD
        && d->value.fd != INVALID_SOCKET) {
        *fd = d->value.fd;
        return 1;
    }

    return 0;
}
/*
 * Poll-descriptor front end for poll_two_fds(): convert the read and write
 * poll descriptors to socket descriptors and wait on them. All other
 * arguments are passed through unchanged.
 *
 * Returns 0 if either descriptor cannot be converted; otherwise returns the
 * result of poll_two_fds().
 */
static int poll_two_descriptors(const BIO_POLL_DESCRIPTOR *r, int r_want_read,
    const BIO_POLL_DESCRIPTOR *w, int w_want_write,
    int notify_rfd,
    OSSL_TIME deadline,
    CRYPTO_MUTEX *mutex)
{
    int read_fd, write_fd;

    if (!poll_descriptor_to_fd(r, &read_fd))
        return 0;

    if (!poll_descriptor_to_fd(w, &write_fd))
        return 0;

    return poll_two_fds(read_fd, r_want_read, write_fd, w_want_write,
        notify_rfd, deadline, mutex);
}
/*
 * Wake any threads currently inside a blocking section of this reactor.
 *
 * Does nothing if the reactor has no notifier or there are no blocking
 * waiters. Otherwise the notifier is signalled (unless it already is), and
 * this function waits on the notifier condition variable until
 * signalled_notifier has been cleared again; that happens in
 * ossl_quic_reactor_leave_blocking_section() when the last waiter leaves.
 *
 * NOTE(review): the condition variable wait uses rtor->mutex, so the caller
 * is presumably expected to hold that mutex - confirm against callers.
 */
static void rtor_notify_other_threads(QUIC_REACTOR *rtor)
{
    if (!rtor->have_notifier || rtor->cur_blocking_waiters == 0)
        return;

    /* Raise the notifier only once per wake-up cycle. */
    if (!rtor->signalled_notifier) {
        ossl_rio_notifier_signal(&rtor->notifier);
        rtor->signalled_notifier = 1;
    }

    /* Stay here until the last waiter out has unsignalled the notifier. */
    while (rtor->signalled_notifier)
        ossl_crypto_condvar_wait(rtor->notifier_cv, rtor->mutex);
}
/*
 * Repeatedly tick the reactor and wait for network activity until the
 * predicate pred(pred_arg) returns a nonzero value.
 *
 * Each iteration ticks the reactor (the very first tick is skipped if
 * SKIP_FIRST_TICK is set in flags), evaluates the predicate, and, if it is
 * not yet satisfied, blocks on the reactor's poll descriptors and notifier
 * until the tick deadline.
 *
 * Returns the predicate's nonzero result once it is satisfied. Returns 0 if
 * there is nothing to wait for (no network read or write desired and an
 * infinite tick deadline) or if polling fails.
 */
int ossl_quic_reactor_block_until_pred(QUIC_REACTOR *rtor,
    int (*pred)(void *arg), void *pred_arg,
    uint32_t flags)
{
    int rc, want_read, want_write;
    int wake_fd = INVALID_SOCKET;
    OSSL_TIME deadline;

    if (rtor->have_notifier)
        wake_fd = ossl_rio_notifier_as_fd(&rtor->notifier);

    for (;;) {
        /* Tick, unless the caller asked for the first tick to be skipped. */
        if ((flags & SKIP_FIRST_TICK) == 0)
            ossl_quic_reactor_tick(rtor, 0);
        else
            flags &= ~SKIP_FIRST_TICK;

        rc = pred(pred_arg);
        if (rc != 0)
            return rc;

        want_read = ossl_quic_reactor_net_read_desired(rtor);
        want_write = ossl_quic_reactor_net_write_desired(rtor);
        deadline = ossl_quic_reactor_get_tick_deadline(rtor);

        /* With nothing to wait for, blocking could never end. */
        if (!want_read && !want_write && ossl_time_is_infinite(deadline))
            return 0;

        ossl_quic_reactor_enter_blocking_section(rtor);

        rc = poll_two_descriptors(ossl_quic_reactor_get_poll_r(rtor),
            want_read,
            ossl_quic_reactor_get_poll_w(rtor),
            want_write,
            wake_fd,
            deadline,
            rtor->mutex);

        /* Always leave the blocking section, even when polling failed. */
        ossl_quic_reactor_leave_blocking_section(rtor);

        if (!rc)
            return 0;
    }

    /* Not reached: the loop above only exits via return. */
    return rc;
}
/*
 * Records that a thread is about to block on this reactor by incrementing
 * cur_blocking_waiters. Each call is expected to be paired with a later call
 * to ossl_quic_reactor_leave_blocking_section(), which decrements the counter.
 */
void ossl_quic_reactor_enter_blocking_section(QUIC_REACTOR *rtor)
{
++rtor->cur_blocking_waiters;
}
/*
 * Records that a thread has finished blocking on this reactor by
 * decrementing cur_blocking_waiters, then takes part in the notifier wake-up
 * handshake if one is in progress (signalled_notifier is set):
 *
 *   - the last waiter to leave unsignals the notifier, clears
 *     signalled_notifier and broadcasts on the notifier condition variable;
 *   - any other waiter blocks on that condition variable until
 *     signalled_notifier has been cleared.
 *
 * NOTE(review): the condition variable wait uses rtor->mutex, so the caller
 * is presumably expected to hold that mutex - confirm against callers.
 */
void ossl_quic_reactor_leave_blocking_section(QUIC_REACTOR *rtor)
{
    assert(rtor->cur_blocking_waiters > 0);
    --rtor->cur_blocking_waiters;

    /* No wake-up cycle in progress: nothing more to do. */
    if (!rtor->have_notifier || !rtor->signalled_notifier)
        return;

    if (rtor->cur_blocking_waiters != 0) {
        /* Not the last one out; wait for whoever is. */
        while (rtor->signalled_notifier)
            ossl_crypto_condvar_wait(rtor->notifier_cv, rtor->mutex);
        return;
    }

    /* Last one out: end the cycle and release everyone waiting on the CV. */
    ossl_rio_notifier_unsignal(&rtor->notifier);
    rtor->signalled_notifier = 0;
    ossl_crypto_condvar_broadcast(rtor->notifier_cv);
}