#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/sysctl.h>
#include <sys/un.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <err.h>
#include <errno.h>
#include <pthread.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
/*
 * Lock taken by therr(), therrx() and therrc() before they print a
 * diagnostic.  None of them ever unlocks it: the verr*() functions they call
 * exit the process.  Presumably this keeps a second failing thread from
 * interleaving its own message while the first one is exiting.
 */
static pthread_mutex_t therr_mtx = PTHREAD_MUTEX_INITIALIZER;
/*
 * Fatal error report for use from any thread: hands the formatted message to
 * verr(3), which appends the text for the current errno value and exits with
 * status eval.
 *
 * therr_mtx is acquired first and intentionally never released, because
 * verr() does not return.
 */
static void
therr(int eval, const char *fmt, ...)
{
	va_list args;

	pthread_mutex_lock(&therr_mtx);

	va_start(args, fmt);
	verr(eval, fmt, args);
	/* Not reached; kept so va_start() stays paired with va_end(). */
	va_end(args);
}
/*
 * Fatal error report for use from any thread, without an errno description:
 * hands the formatted message to verrx(3), which exits with status eval.
 *
 * therr_mtx is acquired first and intentionally never released, because
 * verrx() does not return.
 */
static void
therrx(int eval, const char *fmt, ...)
{
	va_list args;

	pthread_mutex_lock(&therr_mtx);

	va_start(args, fmt);
	verrx(eval, fmt, args);
	/* Not reached; kept so va_start() stays paired with va_end(). */
	va_end(args);
}
/*
 * Fatal error report for use from any thread with an explicit error number:
 * hands the formatted message to verrc(3), which describes `code` instead of
 * errno and exits with status eval.  main() uses it for pthread_create(),
 * whose error is delivered as a return value.
 *
 * therr_mtx is acquired first and intentionally never released, because
 * verrc() does not return.
 */
static void
therrc(int eval, int code, const char *fmt, ...)
{
	va_list args;

	pthread_mutex_lock(&therr_mtx);

	va_start(args, fmt);
	verrc(eval, code, fmt, args);
	/* Not reached; kept so va_start() stays paired with va_end(). */
	va_end(args);
}
/* Number of file descriptors carried by each SCM_RIGHTS control message. */
#define PASSFD_NUM (4)
/*
 * Ancillary-data buffer large enough for one control message holding
 * PASSFD_NUM descriptors.  The cmsghdr member lets the header be accessed
 * directly and, being part of the union, gives the char buffer alignment
 * suitable for a struct cmsghdr.
 */
union msg_control {
struct cmsghdr cmsgh;
char control[CMSG_SPACE(sizeof(int) * PASSFD_NUM)];
};
/*
 * Argument shared by the thr_send, thr_recv and thr_dispose threads of one
 * group.  thr_pass_args points to an array that main() allocates with one
 * element per online CPU.
 */
static struct thr_pass_arg {
/* Socket pair: thr_send sends on s[0]; thr_recv and thr_dispose read s[1]. */
int s[2];
/* Extra descriptor placed twice in every message; main() sets it to s[0]. */
int passfd;
} *thr_pass_args;
/*
 * Argument for the single thr_gc thread; allocated and filled in by main().
 */
static struct thr_gc_arg {
/* Descriptor placed twice in every message; main() sets it to thr_pass_args[0].s[0]. */
int passfd;
} *thr_gc_arg;
static void *
thr_send(void *arg)
{
union msg_control msg_control;
int iov_buf;
struct iovec iov;
struct msghdr msgh;
struct cmsghdr *cmsgh;
int *s = ((struct thr_pass_arg *)arg)->s;
int passfd = ((struct thr_pass_arg *)arg)->passfd;
while (1) {
iov_buf = 0;
iov.iov_base = &iov_buf;
iov.iov_len = sizeof(iov_buf);
msgh.msg_control = msg_control.control;
msgh.msg_controllen = sizeof(msg_control.control);
msgh.msg_iov = &iov;
msgh.msg_iovlen = 1;
msgh.msg_name = NULL;
msgh.msg_namelen = 0;
cmsgh = CMSG_FIRSTHDR(&msgh);
cmsgh->cmsg_len = CMSG_LEN(sizeof(int) * PASSFD_NUM);
cmsgh->cmsg_level = SOL_SOCKET;
cmsgh->cmsg_type = SCM_RIGHTS;
*((int *)CMSG_DATA(cmsgh) + 0) = s[0];
*((int *)CMSG_DATA(cmsgh) + 1) = s[1];
*((int *)CMSG_DATA(cmsgh) + 2) = passfd;
*((int *)CMSG_DATA(cmsgh) + 3) = passfd;
if (sendmsg(s[0], &msgh, 0) < 0) {
switch (errno) {
case EMFILE:
case ENOBUFS:
break;
default:
therr(1, "sendmsg");
}
}
}
return NULL;
}
static void *
thr_recv(void *arg)
{
union msg_control msg_control;
int iov_buf;
struct iovec iov;
struct msghdr msgh;
struct cmsghdr *cmsgh;
int i, fd;
int *s = ((struct thr_pass_arg *)arg)->s;
while (1) {
msg_control.cmsgh.cmsg_level = SOL_SOCKET;
msg_control.cmsgh.cmsg_type = SCM_RIGHTS;
msg_control.cmsgh.cmsg_len =
CMSG_LEN(sizeof(int) * PASSFD_NUM);
iov.iov_base = &iov_buf;
iov.iov_len = sizeof(iov_buf);
msgh.msg_control = msg_control.control;
msgh.msg_controllen = sizeof(msg_control.control);
msgh.msg_iov = &iov;
msgh.msg_iovlen = 1;
msgh.msg_name = NULL;
msgh.msg_namelen = 0;
if(recvmsg(s[1], &msgh, 0) < 0)
therr(1, "recvmsg");
if(!(cmsgh = CMSG_FIRSTHDR(&msgh)))
therrx(1, "bad cmsg header");
if(cmsgh->cmsg_level != SOL_SOCKET)
therrx(1, "bad cmsg level");
if(cmsgh->cmsg_type != SCM_RIGHTS)
therrx(1, "bad cmsg type");
if(cmsgh->cmsg_len != CMSG_LEN(sizeof(fd) * PASSFD_NUM))
therrx(1, "bad cmsg length");
for (i = 0; i < PASSFD_NUM; ++i) {
fd = *((int *)CMSG_DATA(cmsgh) + i);
close(fd);
}
}
return NULL;
}
/*
 * Disposer thread.  arg points to a struct thr_pass_arg.
 *
 * Loops forever pulling datagrams off s[1] with plain read(2) into a scratch
 * buffer, so no ancillary data is requested for them.  It reads the same
 * socket as thr_recv.  A read() failure ends the process through therr();
 * the trailing return is never reached.
 */
static void *
thr_dispose(void *arg)
{
	struct thr_pass_arg *pass = arg;
	int *pair = pass->s;
	uint8_t scratch[sizeof(union msg_control)];

	for (;;) {
		if (read(pair[1], scratch, sizeof(scratch)) < 0)
			therr(1, "read");
	}

	return NULL;
}
static void *
thr_gc(void *arg)
{
union msg_control msg_control;
int iov_buf;
struct iovec iov;
struct msghdr msgh;
struct cmsghdr *cmsgh;
int s[2], passfd = ((struct thr_gc_arg *)arg)->passfd;
while (1) {
if (socketpair(AF_UNIX, SOCK_DGRAM, 0, s) < 0)
therr(1, "socketpair");
iov_buf = 0;
iov.iov_base = &iov_buf;
iov.iov_len = sizeof(iov_buf);
msgh.msg_control = msg_control.control;
msgh.msg_controllen = sizeof(msg_control.control);
msgh.msg_iov = &iov;
msgh.msg_iovlen = 1;
msgh.msg_name = NULL;
msgh.msg_namelen = 0;
cmsgh = CMSG_FIRSTHDR(&msgh);
cmsgh->cmsg_len = CMSG_LEN(sizeof(int) * PASSFD_NUM);
cmsgh->cmsg_level = SOL_SOCKET;
cmsgh->cmsg_type = SCM_RIGHTS;
*((int *)CMSG_DATA(cmsgh) + 0) = s[0];
*((int *)CMSG_DATA(cmsgh) + 1) = s[1];
*((int *)CMSG_DATA(cmsgh) + 2) = passfd;
*((int *)CMSG_DATA(cmsgh) + 3) = passfd;
if (sendmsg(s[0], &msgh, 0) < 0) {
switch (errno) {
case EMFILE:
case ENOBUFS:
break;
default:
therr(1, "sendmsg");
}
}
close(s[0]);
close(s[1]);
}
return NULL;
}
/*
 * Test driver.
 *
 * Reads the CTL_HW/HW_NCPUONLINE sysctl and, for each reported CPU, creates
 * one AF_UNIX datagram socket pair and three threads working on it
 * (thr_send, thr_recv, thr_dispose).  A single thr_gc thread is started in
 * addition.  The main thread then sleeps for 60 seconds -- or for
 * 10 * 365 * 86400 seconds when the only argument is "--infinite" -- and
 * returns 0; the threads are never joined.
 *
 * Setup failures before any thread exists are reported with err()/errx();
 * pthread_create() failures go through therrc().
 */
int
main(int argc, char *argv[])
{
	/* Per-CPU workers, listed in the order they are started. */
	void *(*const workers[])(void *) = { thr_send, thr_recv, thr_dispose };
	struct timespec testtime = { .tv_sec = 60, .tv_nsec = 0 };
	int mib[2] = { CTL_HW, HW_NCPUONLINE };
	int ncpu;
	size_t len = sizeof(ncpu);
	size_t w;
	pthread_t thr;
	int i, error;

	if (argc == 2 && strcmp(argv[1], "--infinite") == 0)
		testtime.tv_sec = (10 * 365 * 86400);

	if (sysctl(mib, 2, &ncpu, &len, NULL, 0) < 0)
		err(1, "sysctl");
	if (ncpu <= 0)
		errx(1, "Wrong number of CPUs online: %d", ncpu);

	thr_pass_args = calloc(ncpu, sizeof(*thr_pass_args));
	if (thr_pass_args == NULL)
		err(1, "malloc");

	thr_gc_arg = malloc(sizeof(*thr_gc_arg));
	if (thr_gc_arg == NULL)
		err(1, "malloc");

	/* Every group passes the sending end of its own socket pair. */
	for (i = 0; i < ncpu; ++i) {
		if (socketpair(AF_UNIX, SOCK_DGRAM, 0, thr_pass_args[i].s) < 0)
			err(1, "socketpair");
		thr_pass_args[i].passfd = thr_pass_args[i].s[0];
	}
	thr_gc_arg->passfd = thr_pass_args[0].s[0];

	for (i = 0; i < ncpu; ++i) {
		for (w = 0; w < sizeof(workers) / sizeof(workers[0]); ++w) {
			error = pthread_create(&thr, NULL, workers[w],
			    &thr_pass_args[i]);
			if (error != 0)
				therrc(1, error, "pthread_create");
		}
	}

	error = pthread_create(&thr, NULL, thr_gc, thr_gc_arg);
	if (error != 0)
		therrc(1, error, "pthread_create");

	/* Let the workers run for the test duration, then exit the process. */
	nanosleep(&testtime, NULL);

	return 0;
}