#include <assert.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fcntl.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <nghttp3/nghttp3.h>
#include <openssl/err.h>
#include <openssl/quic.h>
#include <openssl/ssl.h>
/* Fallback when the platform headers do not define PATH_MAX; it sizes the
 * on-stack buffers holding FILEPREFIX + request path. */
#ifndef PATH_MAX
#define PATH_MAX 255
#endif
/* Element count of a true array (not valid on pointers). */
#define nghttp3_arraylen(A) (sizeof(A) / sizeof(*(A)))
/* Response body used when get_file_length() reports 0 bytes. */
#define NULL_PAYLOAD "12345678901234567890"
/* Points at a string literal: never write through it or free() it
 * (every free of ptr_data in this file checks for nulldata first). */
static uint8_t *nulldata = (uint8_t *)NULL_PAYLOAD;
/* Length of NULL_PAYLOAD without its terminating NUL. */
static size_t nulldata_sz = sizeof(NULL_PAYLOAD) - 1;
/* nghttp3 settings; reset with nghttp3_settings_default() per accepted connection. */
static nghttp3_settings settings;
/* nghttp3 allocator, set to nghttp3_mem_default() in run_quic_server(). */
static const nghttp3_mem *mem;
/* nghttp3 callbacks; the used fields are filled in by run_quic_server(). */
static nghttp3_callbacks callbacks = { 0 };
/* One tracked SSL object: the QUIC listener, the QUIC connection, or a stream. */
struct ssl_id {
SSL *s; /* tracked object; NULL marks a free slot */
uint64_t id; /* stream id, or UINT64_MAX for listener/connection/free slots */
int status; /* bitmask of the status flags defined below */
};
/* Bits stored in struct ssl_id.status. */
#define CLIENTUNIOPEN 0x01 /* accepted client-initiated unidirectional stream */
#define CLIENTCLOSED 0x02 /* exception-on-read reported for a client uni stream */
#define CLIENTBIDIOPEN 0x04 /* NOTE(review): not referenced in this file */
#define SERVERUNIOPEN 0x08 /* NOTE(review): not referenced in this file */
#define SERVERCLOSED 0x10 /* server finished sending the response on this stream */
#define TOBEREMOVED 0x20 /* to be freed by remove_marked_ids() */
#define ISLISTENER 0x40 /* slot holds the QUIC listener */
#define ISCONNECTION 0x80 /* slot holds the current QUIC connection */
/* Number of slots in h3ssl.ssl_ids. */
#define MAXSSL_IDS 20
/* Size in bytes of h3ssl.url. */
#define MAXURL 255
/* Server state, passed to nghttp3 as the connection user_data. */
struct h3ssl {
struct ssl_id ssl_ids[MAXSSL_IDS]; /* listener, connection and stream slots */
int end_headers_received; /* set by on_end_headers() */
int datadone; /* set by step_read_data() when the last chunk is handed out */
int has_uni; /* server uni streams (control + QPACK) already created */
int close_done; /* set on the first SSL_POLL_EVENT_EC */
int close_wait; /* response fully sent; waiting for close or a new request */
int done; /* set by on_end_stream(), a second EC, or ECD */
int new_conn; /* a connection was accepted in the last read_from_ssl_ids() */
int received_from_two; /* once set, quic_server_read() stops feeding nghttp3 */
int restart; /* a new request stream arrived */
uint64_t id_bidi; /* id of the current request stream, UINT64_MAX if none */
char *fileprefix; /* FILEPREFIX environment value prepended to url (may be NULL) */
char url[MAXURL]; /* requested path without its leading '/' */
uint8_t *ptr_data; /* response body: malloc'd buffer or nulldata */
size_t ldata; /* body bytes left to send (INT_MAX for the endless "big" body) */
int offset_data; /* offset in ptr_data of the next chunk */
};
/*
 * Fill *nv with a header name/value pair. The strings are borrowed, not
 * copied, so they must outlive the use of *nv.
 */
static void make_nv(nghttp3_nv *nv, const char *name, const char *value)
{
    size_t name_len = strlen(name);
    size_t value_len = strlen(value);

    nv->flags = NGHTTP3_NV_FLAG_NONE;
    nv->namelen = name_len;
    nv->valuelen = value_len;
    nv->name = (uint8_t *)name;
    nv->value = (uint8_t *)value;
}
/*
 * Reset *h3ssl to an empty state, keeping only the fileprefix pointer.
 * - Frees a heap-allocated response body.
 * - Zeroes every field.
 * - Marks every ssl_ids slot and id_bidi as unused (UINT64_MAX).
 * ptr_data is read before the reset, so it must already be NULL, nulldata
 * or a malloc'd block when this is called.
 */
static void init_ids(struct h3ssl *h3ssl)
{
    char *saved_prefix = h3ssl->fileprefix;
    uint8_t *body = h3ssl->ptr_data;
    int slot;

    if (body != NULL && body != nulldata) {
        free(body);
    }

    memset(h3ssl, 0, sizeof(*h3ssl));
    for (slot = 0; slot < MAXSSL_IDS; slot++) {
        h3ssl->ssl_ids[slot].id = UINT64_MAX;
    }
    h3ssl->id_bidi = UINT64_MAX;
    h3ssl->fileprefix = saved_prefix;
}
static void reuse_h3ssl(struct h3ssl *h3ssl)
{
h3ssl->end_headers_received = 0;
h3ssl->datadone = 0;
h3ssl->close_done = 0;
h3ssl->close_wait = 0;
h3ssl->done = 0;
memset(h3ssl->url, '\0', sizeof(h3ssl->url));
if (h3ssl->ptr_data != NULL && h3ssl->ptr_data != nulldata)
free(h3ssl->ptr_data);
h3ssl->ptr_data = NULL;
h3ssl->offset_data = 0;
h3ssl->ldata = 0;
}
/*
 * Record (ssl, id, status) in the first free slot (s == NULL) of
 * h3ssl->ssl_ids. Terminates the process if every slot is taken.
 */
static void add_id_status(uint64_t id, SSL *ssl, struct h3ssl *h3ssl, int status)
{
    struct ssl_id *slot = h3ssl->ssl_ids;
    struct ssl_id *const end = h3ssl->ssl_ids + MAXSSL_IDS;

    while (slot < end && slot->s != NULL)
        slot++;

    if (slot == end) {
        printf("Oops too many streams to add!!!\n");
        exit(1);
    }

    slot->s = ssl;
    slot->id = id;
    slot->status = status;
}
/* Track a stream SSL object under its stream id, with no status bits set. */
static void add_id(uint64_t id, SSL *ssl, struct h3ssl *h3ssl)
{
    const int no_status = 0;

    add_id_status(id, ssl, h3ssl, no_status);
}
/* Track the QUIC listener; it has no stream id, so UINT64_MAX is stored. */
static void add_ids_listener(SSL *ssl, struct h3ssl *h3ssl)
{
    const uint64_t no_stream = UINT64_MAX;

    add_id_status(no_stream, ssl, h3ssl, ISLISTENER);
}
/* Track the QUIC connection; it has no stream id, so UINT64_MAX is stored. */
static void add_ids_connection(struct h3ssl *h3ssl, SSL *ssl)
{
    const uint64_t no_stream = UINT64_MAX;

    add_id_status(no_stream, ssl, h3ssl, ISCONNECTION);
}
static SSL *get_ids_connection(struct h3ssl *h3ssl)
{
struct ssl_id *ssl_ids;
int i;
ssl_ids = h3ssl->ssl_ids;
for (i = 0; i < MAXSSL_IDS; i++) {
if (ssl_ids[i].status & ISCONNECTION) {
printf("get_ids_connection\n");
return ssl_ids[i].s;
}
}
return NULL;
}
/* Store newstream in every ISCONNECTION slot that currently holds oldstream. */
static void replace_ids_connection(struct h3ssl *h3ssl, SSL *oldstream, SSL *newstream)
{
    struct ssl_id *entry = h3ssl->ssl_ids;
    int remaining = MAXSSL_IDS;

    while (remaining-- > 0) {
        int is_conn = (entry->status & ISCONNECTION) != 0;

        if (is_conn && entry->s == oldstream) {
            printf("replace_ids_connection\n");
            entry->s = newstream;
        }
        entry++;
    }
}
/*
 * Free every slot whose status has TOBEREMOVED set. Each freed slot is
 * returned to the unused state (s == NULL, id == UINT64_MAX, status == 0).
 * read_from_ssl_ids() can mark several streams in one pass (it counts them
 * in has_ids_to_remove), so every marked slot is released here, not just
 * the first one found.
 */
static void remove_marked_ids(struct h3ssl *h3ssl)
{
    struct ssl_id *ssl_ids = h3ssl->ssl_ids;
    int i;

    for (i = 0; i < MAXSSL_IDS; i++) {
        if ((ssl_ids[i].status & TOBEREMOVED) == 0)
            continue;
        printf("remove_id %llu\n", (unsigned long long)ssl_ids[i].id);
        SSL_free(ssl_ids[i].s);
        ssl_ids[i].s = NULL;
        ssl_ids[i].id = UINT64_MAX;
        ssl_ids[i].status = 0;
    }
}
/*
 * OR `status` into the first slot whose stream id equals `id`.
 * If no slot matches, prints a message and asserts; with NDEBUG it does nothing.
 */
static void set_id_status(uint64_t id, int status, struct h3ssl *h3ssl)
{
    struct ssl_id *match = NULL;
    int i;

    for (i = 0; i < MAXSSL_IDS && match == NULL; i++) {
        if (h3ssl->ssl_ids[i].id == id)
            match = &h3ssl->ssl_ids[i];
    }

    if (match == NULL) {
        printf("Oops can't set status, can't find stream!!!\n");
        assert(0);
        return;
    }

    printf("set_id_status: %llu to %d\n", (unsigned long long)match->id, status);
    match->status |= status;
}
static int get_id_status(uint64_t id, struct h3ssl *h3ssl)
{
struct ssl_id *ssl_ids;
int i;
ssl_ids = h3ssl->ssl_ids;
for (i = 0; i < MAXSSL_IDS; i++) {
if (ssl_ids[i].id == id) {
printf("get_id_status: %llu to %d\n",
(unsigned long long)ssl_ids[i].id, ssl_ids[i].status);
return ssl_ids[i].status;
}
}
printf("Oops can't get status, can't find stream!!!\n");
assert(0);
return -1;
}
/*
 * Check whether every client uni stream (CLIENTUNIOPEN) is also marked
 * CLIENTCLOSED. Streams found closed are freed and their slot is fully reset
 * as a side effect; the scan stops at the first stream still open.
 * Returns 1 when no open client uni stream remains, 0 otherwise.
 */
static int are_all_clientid_closed(struct h3ssl *h3ssl)
{
    struct ssl_id *ssl_ids = h3ssl->ssl_ids;
    int i;

    for (i = 0; i < MAXSSL_IDS; i++) {
        if (ssl_ids[i].id == UINT64_MAX)
            continue;
        printf("are_all_clientid_closed: %llu status %d : %d\n",
               (unsigned long long)ssl_ids[i].id, ssl_ids[i].status,
               CLIENTUNIOPEN | CLIENTCLOSED);
        if (!(ssl_ids[i].status & CLIENTUNIOPEN))
            continue;
        if (!(ssl_ids[i].status & CLIENTCLOSED)) {
            printf("are_all_clientid_closed: %llu open\n",
                   (unsigned long long)ssl_ids[i].id);
            return 0;
        }
        printf("are_all_clientid_closed: %llu closed\n",
               (unsigned long long)ssl_ids[i].id);
        SSL_free(ssl_ids[i].s);
        ssl_ids[i].s = NULL;
        ssl_ids[i].id = UINT64_MAX;
        /* Drop stale bits (e.g. TOBEREMOVED) so status-based scans such as
         * remove_marked_ids() never mistake this free slot for a live one. */
        ssl_ids[i].status = 0;
    }
    return 1;
}
/*
 * Free every tracked stream (any slot with a real stream id) and reset its
 * slot. Slots with id == UINT64_MAX are left untouched: the listener, the
 * connection, and already-free slots.
 */
static void close_all_ids(struct h3ssl *h3ssl)
{
    struct ssl_id *ssl_ids = h3ssl->ssl_ids;
    int i;

    for (i = 0; i < MAXSSL_IDS; i++) {
        if (ssl_ids[i].id == UINT64_MAX)
            continue;
        SSL_free(ssl_ids[i].s);
        ssl_ids[i].s = NULL;
        ssl_ids[i].id = UINT64_MAX;
        /* Drop stale bits (e.g. TOBEREMOVED) so status-based scans such as
         * remove_marked_ids() never mistake this free slot for a live one. */
        ssl_ids[i].status = 0;
    }
}
/*
 * nghttp3 recv_header callback: echo each request header to stdout.
 * For the :path pseudo-header, store the path without its leading '/' in
 * h3ssl->url; a bare "/" maps to "index.html".
 *
 * The path comes from the peer, so it is treated as untrusted:
 * - The header value is length-delimited and is cut at the first embedded
 *   NUL; it is never assumed to be NUL-terminated.
 * - It is truncated to fit url with a terminating NUL.
 * - Paths that are absolute after stripping the leading '/', or that contain
 *   a ".." component, are rejected (url is left empty) so the file lookup
 *   stays below FILEPREFIX.
 */
static int on_recv_header(nghttp3_conn *conn, int64_t stream_id, int32_t token,
                          nghttp3_rcbuf *name, nghttp3_rcbuf *value,
                          uint8_t flags, void *user_data,
                          void *stream_user_data)
{
    nghttp3_vec vname, vvalue;
    struct h3ssl *h3ssl = (struct h3ssl *)user_data;
    const uint8_t *path;
    const uint8_t *nul;
    size_t pathlen;
    size_t i;

    vname = nghttp3_rcbuf_get_buf(name);
    vvalue = nghttp3_rcbuf_get_buf(value);
    fwrite(vname.base, vname.len, 1, stdout);
    fprintf(stdout, ": ");
    fwrite(vvalue.base, vvalue.len, 1, stdout);
    fprintf(stdout, "\n");

    if (token != NGHTTP3_QPACK_TOKEN__PATH)
        return 0;

    memset(h3ssl->url, 0, sizeof(h3ssl->url));
    path = vvalue.base;
    pathlen = vvalue.len;

    /* The C-string view of url ends at the first NUL; check only that part. */
    nul = pathlen > 0 ? memchr(path, '\0', pathlen) : NULL;
    if (nul != NULL)
        pathlen = (size_t)(nul - path);

    if (pathlen > 0 && path[0] == '/') {
        path++;
        pathlen--;
        if (pathlen == 0) {
            snprintf(h3ssl->url, sizeof(h3ssl->url), "%s", "index.html");
            return 0;
        }
    }

    /* Keep room for the terminating NUL. */
    if (pathlen > sizeof(h3ssl->url) - 1)
        pathlen = sizeof(h3ssl->url) - 1;

    /* Reject absolute paths and ".." components (checked after truncation,
     * since truncation itself could produce a trailing ".."). */
    if (pathlen > 0 && path[0] == '/') {
        fprintf(stderr, "rejecting absolute :path\n");
        return 0;
    }
    for (i = 0; i + 1 < pathlen; i++) {
        int at_start = (i == 0 || path[i - 1] == '/');
        int at_end = (i + 2 == pathlen || path[i + 2] == '/');

        if (at_start && at_end && path[i] == '.' && path[i + 1] == '.') {
            fprintf(stderr, "rejecting :path containing \"..\"\n");
            return 0;
        }
    }

    memcpy(h3ssl->url, path, pathlen);
    h3ssl->url[pathlen] = '\0';
    return 0;
}
/* nghttp3 end_headers callback: record that the request header block is complete. */
static int on_end_headers(nghttp3_conn *conn, int64_t stream_id, int fin,
                          void *user_data, void *stream_user_data)
{
    struct h3ssl *state = user_data;

    fprintf(stderr, "on_end_headers!\n");
    state->end_headers_received = 1;
    return 0;
}
/* nghttp3 recv_data callback: log the length and bytes of request body data. */
static int on_recv_data(nghttp3_conn *conn, int64_t stream_id,
                        const uint8_t *data, size_t datalen,
                        void *conn_user_data, void *stream_user_data)
{
    /* %zu matches size_t; %.*s needs a char pointer. */
    fprintf(stderr, "on_recv_data! %zu\n", datalen);
    fprintf(stderr, "on_recv_data! %.*s\n", (int)datalen, (const char *)data);
    return 0;
}
/* nghttp3 end_stream callback: flag the current request as done. */
static int on_end_stream(nghttp3_conn *h3conn, int64_t stream_id,
                         void *conn_user_data, void *stream_user_data)
{
    struct h3ssl *state = conn_user_data;

    printf("on_end_stream!\n");
    state->done = 1;
    return 0;
}
/*
 * Read the data already buffered on `stream` (up to 16000 bytes) and feed it
 * to nghttp3 as data for stream `id`.
 *
 * Returns:
 *   0  nothing pending, or the read would block;
 *   1  data was read (even if nghttp3 consumed only part of it without a
 *      fatal error) or the peer concluded the stream;
 *  -1  fatal SSL or nghttp3 error.
 * Data read on stream id 2 is not passed to nghttp3. Once received_from_two
 * is set, no stream's data is passed to nghttp3.
 * NOTE(review): that looks like a deliberate workaround; confirm intent.
 */
static int quic_server_read(nghttp3_conn *h3conn, SSL *stream, uint64_t id, struct h3ssl *h3ssl)
{
    uint8_t buf[16000];
    int nread, consumed, err;
    int feed_nghttp3;

    if (SSL_has_pending(stream) == 0)
        return 0;

    nread = SSL_read(stream, buf, sizeof(buf));
    if (nread <= 0) {
        fprintf(stderr, "SSL_read %d on %llu failed\n",
                SSL_get_error(stream, nread),
                (unsigned long long)id);
        err = SSL_get_error(stream, nread);
        if (err == SSL_ERROR_WANT_READ)
            return 0;
        if (err == SSL_ERROR_ZERO_RETURN)
            return 1;
        ERR_print_errors_fp(stderr);
        return -1;
    }

    feed_nghttp3 = !h3ssl->received_from_two && id != 2;
    if (feed_nghttp3)
        consumed = nghttp3_conn_read_stream(h3conn, id, buf, nread, 0);
    else
        consumed = nread;

    printf("nghttp3_conn_read_stream used %d of %d on %llu\n", consumed,
           nread, (unsigned long long)id);
    if (consumed == nread)
        return 1;
    if (nghttp3_err_is_fatal(consumed))
        return -1;

    printf("nghttp3_conn_read_stream used %d of %d (not fatal) on %llu\n",
           consumed, nread, (unsigned long long)id);
    if (id == 2)
        h3ssl->received_from_two = 1;
    return 1;
}
/*
 * Open the server's three HTTP/3 unidirectional streams on the current
 * connection: QPACK decoder (rstream), QPACK encoder (pstream) and control
 * (cstream). Bind them in nghttp3 and track them in h3ssl->ssl_ids.
 * Returns 0 on success. Returns -1 when there is no connection, a stream
 * cannot be opened, or binding fails; any opened streams are then freed.
 */
static int quic_server_h3streams(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
{
    SSL *rstream = NULL;
    SSL *pstream = NULL;
    SSL *cstream = NULL;
    SSL *conn;
    uint64_t r_streamid, p_streamid, c_streamid;

    conn = get_ids_connection(h3ssl);
    if (conn == NULL) {
        fprintf(stderr, "quic_server_h3streams no connection\n");
        fflush(stderr);
        return -1;
    }

    rstream = SSL_new_stream(conn, SSL_STREAM_FLAG_UNI);
    if (rstream == NULL) {
        fprintf(stderr, "=> Stream == NULL!\n");
        goto err;
    }
    printf("=> Opened on %llu\n",
           (unsigned long long)SSL_get_stream_id(rstream));

    pstream = SSL_new_stream(conn, SSL_STREAM_FLAG_UNI);
    if (pstream == NULL) {
        fprintf(stderr, "=> Stream == NULL!\n");
        goto err;
    }
    printf("=> Opened on %llu\n",
           (unsigned long long)SSL_get_stream_id(pstream));

    cstream = SSL_new_stream(conn, SSL_STREAM_FLAG_UNI);
    if (cstream == NULL) {
        fprintf(stderr, "=> Stream == NULL!\n");
        goto err;
    }
    fprintf(stderr, "=> Opened on %llu\n",
            (unsigned long long)SSL_get_stream_id(cstream));
    fflush(stderr);

    r_streamid = SSL_get_stream_id(rstream);
    p_streamid = SSL_get_stream_id(pstream);
    c_streamid = SSL_get_stream_id(cstream);

    /* Arguments are (encoder stream, decoder stream). */
    if (nghttp3_conn_bind_qpack_streams(h3conn, p_streamid, r_streamid)) {
        fprintf(stderr, "nghttp3_conn_bind_qpack_streams failed!\n");
        goto err;
    }
    if (nghttp3_conn_bind_control_stream(h3conn, c_streamid)) {
        fprintf(stderr, "nghttp3_conn_bind_control_stream failed!\n");
        goto err;
    }
    printf("control: %llu enc %llu dec %llu\n",
           (unsigned long long)c_streamid,
           (unsigned long long)p_streamid,
           (unsigned long long)r_streamid);

    add_id(r_streamid, rstream, h3ssl);
    add_id(p_streamid, pstream, h3ssl);
    add_id(c_streamid, cstream, h3ssl);
    return 0;

err:
    fflush(stderr);
    SSL_free(rstream);
    SSL_free(pstream);
    SSL_free(cstream);
    return -1;
}
/*
 * Poll every tracked SSL object without blocking and process what is ready.
 * - IC (incoming connection): accept it. Any previous connection and its
 *   streams are freed, and a fresh nghttp3 server connection is created in
 *   *curh3conn.
 * - ISB/ISU (incoming bidi/uni stream): accept the stream, track it and read
 *   from it. A new bidi stream becomes the current request stream, and the
 *   previous one is marked TOBEREMOVED.
 * - OSU (outgoing uni stream possible): create the server's uni streams once.
 * - EC/ECD (connection terminating/terminated): record the shutdown in
 *   close_done/done.
 * - R (readable): feed stream data to nghttp3; streams whose read side has
 *   concluded are marked TOBEREMOVED.
 * - ER/EW (read/write side exception): update stream close states.
 * Streams marked TOBEREMOVED are freed before returning.
 *
 * Returns -1 on error, 0 if SSL_poll reported nothing, otherwise the number
 * of handled events that produced work (which may itself be 0).
 */
static int read_from_ssl_ids(nghttp3_conn **curh3conn, struct h3ssl *h3ssl)
{
    int hassomething = 0, i;
    struct ssl_id *ssl_ids = h3ssl->ssl_ids;
    SSL_POLL_ITEM items[MAXSSL_IDS] = { 0 }, *item = items;
    static const struct timeval nz_timeout = { 0, 0 };
    size_t result_count = SIZE_MAX;
    int numitem = 0, ret;
    uint64_t processed_event = 0;
    int has_ids_to_remove = 0;
    /* Set once the previous connection and its streams have been freed. */
    int conn_replaced = 0;
    nghttp3_conn *h3conn = *curh3conn;

    /* Poll every tracked object for every event type. */
    for (i = 0; i < MAXSSL_IDS; i++) {
        if (ssl_ids[i].s != NULL) {
            item->desc = SSL_as_poll_descriptor(ssl_ids[i].s);
            item->events = UINT64_MAX;
            item->revents = UINT64_MAX;
            numitem++;
            item++;
        }
    }

    /* Zero timeout + NO_HANDLE_EVENTS: report readiness only, never block. */
    ret = SSL_poll(items, numitem, sizeof(SSL_POLL_ITEM), &nz_timeout,
                   SSL_POLL_FLAG_NO_HANDLE_EVENTS, &result_count);
    if (!ret) {
        fprintf(stderr, "SSL_poll failed\n");
        printf("SSL_poll failed\n");
        return -1;
    }
    printf("read_from_ssl_ids %zu events\n", result_count);
    if (result_count == 0) {
        return 0;
    }
    h3ssl->new_conn = 0;
    h3ssl->restart = 0;
    h3ssl->done = 0;

    /*
     * Once the previous connection has been replaced, the remaining items[]
     * entries point at SSL objects that were just freed, so stop there.
     */
    for (i = 0, item = items; i < numitem && !conn_replaced; i++, item++) {
        SSL *s;

        if (item->revents == SSL_POLL_EVENT_NONE)
            continue;
        processed_event = 0;
        s = item->desc.value.ssl;

        if (item->revents & SSL_POLL_EVENT_IC) {
            SSL *conn = SSL_accept_connection(item->desc.value.ssl, 0);
            SSL *oldconn;

            printf("SSL_accept_connection\n");
            if (conn == NULL) {
                fprintf(stderr, "error while accepting connection\n");
                ret = -1;
                goto err;
            }
            oldconn = get_ids_connection(h3ssl);
            if (oldconn != NULL) {
                printf("SSL_accept_connection closing previous\n");
                /* Retarget the tracked slot while oldconn is still valid. */
                replace_ids_connection(h3ssl, oldconn, conn);
                reuse_h3ssl(h3ssl);
                /* Free the old connection's streams, then the connection. */
                close_all_ids(h3ssl);
                SSL_free(oldconn);
                h3ssl->id_bidi = UINT64_MAX;
                h3ssl->has_uni = 0;
                conn_replaced = 1;
            } else {
                printf("SSL_accept_connection first connection\n");
                add_ids_connection(h3ssl, conn);
            }
            h3ssl->new_conn = 1;
            nghttp3_conn_del(*curh3conn);
            nghttp3_settings_default(&settings);
            if (nghttp3_conn_server_new(curh3conn, &callbacks, &settings, mem,
                                        h3ssl)) {
                fprintf(stderr, "nghttp3_conn_server_new failed!\n");
                exit(1);
            }
            h3conn = *curh3conn;
            hassomething++;
            if (!SSL_set_incoming_stream_policy(conn,
                                                SSL_INCOMING_STREAM_POLICY_ACCEPT, 0)) {
                fprintf(stderr, "error while setting inccoming stream policy\n");
                ret = -1;
                goto err;
            }
            printf("SSL_accept_connection\n");
            processed_event = processed_event | SSL_POLL_EVENT_IC;
        }

        if ((item->revents & SSL_POLL_EVENT_ISB) || (item->revents & SSL_POLL_EVENT_ISU)) {
            SSL *stream = SSL_accept_stream(item->desc.value.ssl, 0);
            uint64_t new_id;
            int r;

            if (stream == NULL) {
                ret = -1;
                goto err;
            }
            new_id = SSL_get_stream_id(stream);
            printf("=> Received connection on %llu %d\n", (unsigned long long)new_id,
                   SSL_get_stream_type(stream));
            add_id(new_id, stream, h3ssl);
            if (h3ssl->close_wait) {
                printf("in close_wait so we will have a new request\n");
                reuse_h3ssl(h3ssl);
                h3ssl->restart = 1;
            }
            if (SSL_get_stream_type(stream) == SSL_STREAM_TYPE_BIDI) {
                /* A new request stream supersedes the previous one. */
                if (h3ssl->id_bidi != UINT64_MAX) {
                    set_id_status(h3ssl->id_bidi, TOBEREMOVED, h3ssl);
                    has_ids_to_remove++;
                }
                h3ssl->id_bidi = new_id;
                reuse_h3ssl(h3ssl);
                h3ssl->restart = 1;
            } else {
                set_id_status(new_id, CLIENTUNIOPEN, h3ssl);
            }
            r = quic_server_read(h3conn, stream, new_id, h3ssl);
            if (r == -1) {
                ret = -1;
                goto err;
            }
            if (r == 1)
                hassomething++;
            if (item->revents & SSL_POLL_EVENT_ISB)
                processed_event = processed_event | SSL_POLL_EVENT_ISB;
            if (item->revents & SSL_POLL_EVENT_ISU)
                processed_event = processed_event | SSL_POLL_EVENT_ISU;
        }

        if (item->revents & SSL_POLL_EVENT_OSB) {
            processed_event = processed_event | SSL_POLL_EVENT_OSB;
            printf("Create bidi?\n");
        }

        if (item->revents & SSL_POLL_EVENT_OSU) {
            printf("Create uni?\n");
            processed_event = processed_event | SSL_POLL_EVENT_OSU;
            if (!h3ssl->has_uni) {
                printf("Create uni\n");
                ret = quic_server_h3streams(h3conn, h3ssl);
                if (ret == -1) {
                    fprintf(stderr, "quic_server_h3streams failed!\n");
                    goto err;
                }
                h3ssl->has_uni = 1;
                hassomething++;
            }
        }

        if (item->revents & SSL_POLL_EVENT_EC) {
            printf("Connection terminating\n");
            printf("Connection terminating restart %d\n", h3ssl->restart);
            if (!h3ssl->close_done) {
                h3ssl->close_done = 1;
            } else {
                h3ssl->done = 1;
            }
            hassomething++;
            processed_event = processed_event | SSL_POLL_EVENT_EC;
        }

        if (item->revents & SSL_POLL_EVENT_ECD) {
            printf("Connection terminated\n");
            h3ssl->done = 1;
            hassomething++;
            processed_event = processed_event | SSL_POLL_EVENT_ECD;
        }

        if (item->revents & SSL_POLL_EVENT_R) {
            uint64_t id = UINT64_MAX;
            int r;

            id = SSL_get_stream_id(item->desc.value.ssl);
            printf("revent READ on %llu\n", (unsigned long long)id);
            r = quic_server_read(h3conn, s, id, h3ssl);
            if (r == 0) {
                /* Nothing buffered: only a concluded read side is acceptable. */
                uint8_t msg[1];
                size_t l = sizeof(msg);

                r = SSL_read(s, msg, l);
                printf("SSL_read tells %d\n", r);
                if (r > 0) {
                    ret = -1;
                    goto err;
                }
                r = SSL_get_error(s, r);
                if (r != SSL_ERROR_ZERO_RETURN) {
                    ret = -1;
                    goto err;
                }
                set_id_status(id, TOBEREMOVED, h3ssl);
                has_ids_to_remove++;
                continue;
            }
            if (r == -1) {
                ret = -1;
                goto err;
            }
            hassomething++;
            processed_event = processed_event | SSL_POLL_EVENT_R;
        }

        if (item->revents & SSL_POLL_EVENT_ER) {
            uint64_t id = UINT64_MAX;
            int status;

            id = SSL_get_stream_id(item->desc.value.ssl);
            status = get_id_status(id, h3ssl);
            printf("revent exception READ on %llu\n", (unsigned long long)id);
            if (status & CLIENTUNIOPEN) {
                set_id_status(id, CLIENTCLOSED, h3ssl);
                hassomething++;
            }
            processed_event = processed_event | SSL_POLL_EVENT_ER;
        }

        if (item->revents & SSL_POLL_EVENT_W) {
            processed_event = processed_event | SSL_POLL_EVENT_W;
        }

        if (item->revents & SSL_POLL_EVENT_EW) {
            uint64_t id = UINT64_MAX;
            int status;

            id = SSL_get_stream_id(item->desc.value.ssl);
            status = get_id_status(id, h3ssl);
            if (status & SERVERCLOSED) {
                printf("both sides closed on %llu\n", (unsigned long long)id);
                set_id_status(id, TOBEREMOVED, h3ssl);
                has_ids_to_remove++;
                hassomething++;
            }
            processed_event = processed_event | SSL_POLL_EVENT_EW;
        }

        if (item->revents != processed_event) {
            uint64_t id = UINT64_MAX;

            id = SSL_get_stream_id(item->desc.value.ssl);
            printf("revent %llu (%d) on %llu NOT PROCESSED!\n",
                   (unsigned long long)item->revents, (int)SSL_POLL_EVENT_W,
                   (unsigned long long)id);
        }
    }
    ret = hassomething;
err:
    if (has_ids_to_remove)
        remove_marked_ids(h3ssl);
    return ret;
}
/*
 * Let OpenSSL process pending QUIC events on the tracked listener and
 * connection (slots flagged ISLISTENER or ISCONNECTION with a non-NULL SSL).
 * SSL_handle_events() returns 1 on success and 0 on failure, so the error
 * queue is printed only when it fails.
 */
static void handle_events_from_ids(struct h3ssl *h3ssl)
{
    struct ssl_id *ssl_ids = h3ssl->ssl_ids;
    int i;

    for (i = 0; i < MAXSSL_IDS; i++) {
        if (ssl_ids[i].s == NULL
            || (ssl_ids[i].status & (ISCONNECTION | ISLISTENER)) == 0)
            continue;
        if (!SSL_handle_events(ssl_ids[i].s))
            ERR_print_errors_fp(stderr);
    }
}
/*
 * Return the size of the regular file FILEPREFIX + h3ssl->url.
 * - Returns INT_MAX when url is exactly "big" (the endless generated body).
 * - Returns 0 when the combined path does not fit in PATH_MAX bytes, cannot
 *   be stat()ed, or is not a regular file.
 */
static size_t get_file_length(struct h3ssl *h3ssl)
{
    char filename[PATH_MAX];
    struct stat st;
    int n;

    if (strcmp(h3ssl->url, "big") == 0) {
        printf("big!!!\n");
        return (size_t)INT_MAX;
    }

    /*
     * Bounded path build: fileprefix comes from the environment and url from
     * the peer, so their total length is not known in advance. The url read
     * is also capped at its buffer size.
     */
    n = snprintf(filename, sizeof(filename), "%s%.*s",
                 h3ssl->fileprefix != NULL ? h3ssl->fileprefix : "",
                 (int)sizeof(h3ssl->url), h3ssl->url);
    if (n < 0 || (size_t)n >= sizeof(filename)) {
        printf("Can't get_file_length: path too long\n");
        return 0;
    }

    if (stat(filename, &st) == 0 && S_ISREG(st.st_mode)) {
        printf("get_file_length %s %lld\n", filename, (long long)st.st_size);
        return (size_t)st.st_size;
    }
    printf("Can't get_file_length %s\n", filename);
    return 0;
}
/*
 * Load the file FILEPREFIX + h3ssl->url into a new heap buffer of
 * get_file_length() + 1 bytes, ending in a NUL. The caller owns it (free()).
 * Returns NULL in any of these cases:
 * - the reported length is 0;
 * - the path does not fit in PATH_MAX bytes;
 * - allocation fails;
 * - the file cannot be opened or read.
 * If the file is shorter than the length reported earlier, the tail stays
 * zero-filled, so no uninitialized heap memory is ever returned.
 */
static char *get_file_data(struct h3ssl *h3ssl)
{
    char filename[PATH_MAX];
    size_t size = get_file_length(h3ssl);
    size_t got = 0;
    char *res;
    int fd;
    int n;

    if (size == 0)
        return NULL;

    n = snprintf(filename, sizeof(filename), "%s%.*s",
                 h3ssl->fileprefix != NULL ? h3ssl->fileprefix : "",
                 (int)sizeof(h3ssl->url), h3ssl->url);
    if (n < 0 || (size_t)n >= sizeof(filename))
        return NULL;

    /* calloc zero-fills, which also supplies the trailing NUL. */
    res = calloc(size + 1, 1);
    if (res == NULL)
        return NULL;

    fd = open(filename, O_RDONLY);
    if (fd < 0) {
        free(res);
        return NULL;
    }

    /* read() may return fewer bytes than requested: loop until full or EOF. */
    while (got < size) {
        ssize_t r = read(fd, res + got, size - got);

        if (r < 0) {
            close(fd);
            free(res);
            return NULL;
        }
        if (r == 0)
            break;
        got += (size_t)r;
    }
    close(fd);
    printf("read from %s : %zu\n", filename, size);
    return res;
}
/*
 * nghttp3 read_data callback: hand out the response body from
 * h3ssl->ptr_data in chunks of at most 4096 bytes, one vector per call.
 * - Returns 1 while data remains, and 0 once datadone is set.
 * - Sets NGHTTP3_DATA_FLAG_EOF on the final chunk and on every later call.
 * - When ldata == INT_MAX (the "big" URL), returns the same chunk forever
 *   without advancing.
 */
static nghttp3_ssize step_read_data(nghttp3_conn *conn, int64_t stream_id,
                                    nghttp3_vec *vec, size_t veccnt,
                                    uint32_t *pflags, void *user_data,
                                    void *stream_user_data)
{
    enum { STEP_CHUNK = 4096 };
    struct h3ssl *h3ssl = (struct h3ssl *)user_data;

    if (h3ssl->datadone) {
        *pflags = NGHTTP3_DATA_FLAG_EOF;
        return 0;
    }

    printf("step_read_data for %s %zu\n", h3ssl->url, h3ssl->ldata);
    vec[0].base = h3ssl->ptr_data + h3ssl->offset_data;

    if (h3ssl->ldata > STEP_CHUNK) {
        vec[0].len = STEP_CHUNK;
        if (h3ssl->ldata != INT_MAX) {
            h3ssl->offset_data = h3ssl->offset_data + STEP_CHUNK;
            h3ssl->ldata = h3ssl->ldata - STEP_CHUNK;
        } else {
            printf("big = endless!\n");
        }
    } else {
        /* Last (or only) chunk. */
        vec[0].len = h3ssl->ldata;
        h3ssl->datadone++;
        *pflags = NGHTTP3_DATA_FLAG_EOF;
    }
    return 1;
}
/*
 * Write len bytes from buff to the tracked stream whose id is streamid,
 * passing flags (e.g. SSL_WRITE_FLAG_CONCLUDE) to SSL_write_ex2().
 * *written receives the byte count from SSL_write_ex2().
 * Returns 1 only when the whole buffer was written. Returns 0 when the write
 * fails or is short, or when no tracked stream has that id.
 */
static int quic_server_write(struct h3ssl *h3ssl, uint64_t streamid,
                             uint8_t *buff, size_t len, uint64_t flags,
                             size_t *written)
{
    struct ssl_id *ssl_ids = h3ssl->ssl_ids;
    int i;

    for (i = 0; i < MAXSSL_IDS; i++) {
        if (ssl_ids[i].id != streamid)
            continue;
        if (!SSL_write_ex2(ssl_ids[i].s, buff, len, flags, written) || *written != len) {
            fprintf(stderr, "couldn't write on connection\n");
            ERR_print_errors_fp(stderr);
            return 0;
        }
        /* %llu matches the unsigned long long casts. */
        printf("written %llu on %llu flags %llu\n", (unsigned long long)len,
               (unsigned long long)streamid, (unsigned long long)flags);
        return 1;
    }
    printf("quic_server_write %llu on %llu (NOT FOUND!)\n", (unsigned long long)len,
           (unsigned long long)streamid);
    return 0;
}
#define OSSL_NELEM(x) (sizeof(x) / sizeof((x)[0]))
static const unsigned char alpn_ossltest[] = { 5, 'h', '3', '-', '2',
'9', 2, 'h', '3' };
/*
 * ALPN selection callback. It chooses a protocol from alpn_ossltest that the
 * client also offered (in the server list's preference order). The handshake
 * is aborted with a fatal alert when there is no overlap.
 */
static int select_alpn(SSL *ssl, const unsigned char **out,
                       unsigned char *out_len, const unsigned char *in,
                       unsigned int in_len, void *arg)
{
    int outcome = SSL_select_next_proto((unsigned char **)out, out_len,
                                        alpn_ossltest, sizeof(alpn_ossltest),
                                        in, in_len);

    return outcome == OPENSSL_NPN_NEGOTIATED ? SSL_TLSEXT_ERR_OK
                                             : SSL_TLSEXT_ERR_ALERT_FATAL;
}
/*
 * Create a QUIC server SSL_CTX that:
 * - loads the PEM certificate chain from cert_path;
 * - loads the PEM private key from key_path and checks it;
 * - uses select_alpn() for ALPN.
 * Returns NULL on failure; certificate, key and check failures print a message.
 */
static SSL_CTX *create_ctx(const char *cert_path, const char *key_path)
{
    SSL_CTX *ctx = SSL_CTX_new(OSSL_QUIC_server_method());

    if (ctx == NULL)
        return NULL;

    if (SSL_CTX_use_certificate_chain_file(ctx, cert_path) <= 0) {
        fprintf(stderr, "couldn't load certificate file: %s\n", cert_path);
    } else if (SSL_CTX_use_PrivateKey_file(ctx, key_path, SSL_FILETYPE_PEM) <= 0) {
        fprintf(stderr, "couldn't load key file: %s\n", key_path);
    } else if (!SSL_CTX_check_private_key(ctx)) {
        fprintf(stderr, "private key check failed\n");
    } else {
        SSL_CTX_set_alpn_select_cb(ctx, select_alpn, NULL);
        return ctx;
    }

    SSL_CTX_free(ctx);
    return NULL;
}
/*
 * Create an IPv4 UDP socket bound to `port`. sin_addr is left zeroed, so the
 * socket binds to every local address.
 * Returns the descriptor, or -1 on failure.
 */
static int create_socket(uint16_t port)
{
    struct sockaddr_in sa = { 0 };
    int fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);

    if (fd < 0) {
        fprintf(stderr, "cannot create socket");
        return -1;
    }

    sa.sin_family = AF_INET;
    sa.sin_port = htons(port);
    if (bind(fd, (const struct sockaddr *)&sa, sizeof(sa)) >= 0)
        return fd;

    fprintf(stderr, "cannot bind to %u\n", port);
    BIO_closesocket(fd);
    return -1;
}
/*
 * Block in select() on the socket behind `ssl` until one of these happens:
 * - the socket is readable (always watched);
 * - the socket is writable (watched only when OpenSSL wants to write);
 * - OpenSSL's next event timeout expires (no timeout if it is infinite).
 * Returns select()'s result: >0 ready, 0 timeout, -1 error. Also returns -1
 * when no descriptor is attached.
 */
static int wait_for_activity(SSL *ssl)
{
    fd_set readable, writable;
    struct timeval deadline;
    struct timeval *timeout = NULL;
    int isinfinite;
    int sock = SSL_get_fd(ssl);

    if (sock == -1) {
        fprintf(stderr, "Unable to get file descriptor");
        return -1;
    }

    FD_ZERO(&readable);
    FD_ZERO(&writable);
    if (SSL_net_write_desired(ssl))
        FD_SET(sock, &writable);
    if (SSL_net_read_desired(ssl))
        FD_SET(sock, &readable);
    /* NOTE(review): readability is requested unconditionally, which makes
     * the SSL_net_read_desired() check above redundant; confirm intent. */
    FD_SET(sock, &readable);

    if (SSL_get_event_timeout(ssl, &deadline, &isinfinite) && !isinfinite)
        timeout = &deadline;

    return select(sock + 1, &readable, &writable, NULL, timeout);
}
/*
 * Serve HTTP/3 over QUIC on the UDP socket fd.
 *
 * Each outer iteration:
 * 1. resets the per-connection state;
 * 2. pumps events until a request header block has arrived;
 * 3. submits a 200 response whose body is the file FILEPREFIX + path. A
 *    fixed placeholder is used when the file reports 0 bytes, and an endless
 *    run of 'A' bytes for the path "big";
 * 4. writes everything nghttp3 produces;
 * 5. waits for the peer to close, send a new request (restart), or open a
 *    new connection (newconn).
 *
 * Returns 0 on failure. The outer loop has no normal exit, so the success
 * value is not reached in practice.
 */
static int run_quic_server(SSL_CTX *ctx, int fd)
{
    int ok = 0;
    int hassomething = 0;
    SSL *listener = NULL;
    nghttp3_conn *h3conn = NULL;
    struct h3ssl h3ssl;
    SSL *ssl;
    char *fileprefix = getenv("FILEPREFIX");

    /*
     * init_ids() frees h3ssl.ptr_data before clearing the struct, so the
     * struct must not start out holding indeterminate stack contents.
     */
    memset(&h3ssl, 0, sizeof(h3ssl));

    if ((listener = SSL_new_listener(ctx, 0)) == NULL)
        goto err;
    if (!SSL_set_fd(listener, fd))
        goto err;
    if (!SSL_listen(listener))
        goto err;
    if (!SSL_set_blocking_mode(listener, 0))
        goto err;

    callbacks.recv_header = on_recv_header;
    callbacks.end_headers = on_end_headers;
    callbacks.recv_data = on_recv_data;
    callbacks.end_stream = on_end_stream;
    mem = nghttp3_mem_default();

    for (;;) {
        nghttp3_nv resp[10];
        size_t num_nv;
        nghttp3_data_reader dr;
        int ret;
        int numtimeout;
        char slength[22];
        int hasnothing;

        /* Fresh state; the listener takes the first (now free) slot. */
        init_ids(&h3ssl);
        h3ssl.fileprefix = fileprefix;
        printf("listener: %p\n", (void *)listener);
        add_ids_listener(listener, &h3ssl);

        if (!hassomething) {
            printf("waiting on socket\n");
            fflush(stdout);
            ret = wait_for_activity(listener);
            if (ret == -1) {
                fprintf(stderr, "wait_for_activity failed!\n");
                goto err;
            }
        }
    newconn:
        printf("process_server starting...\n");
        fflush(stdout);
    restart:
        numtimeout = 0;
        num_nv = 0;
        /* Pump events until the request header block is complete. */
        while (!h3ssl.end_headers_received) {
            if (!hassomething) {
                if (wait_for_activity(listener) == 0) {
                    printf("waiting for end_headers_received timeout %d\n", numtimeout);
                    numtimeout++;
                    if (numtimeout == 25)
                        goto err;
                }
                handle_events_from_ids(&h3ssl);
            }
            hassomething = read_from_ssl_ids(&h3conn, &h3ssl);
            if (hassomething == -1) {
                fprintf(stderr, "read_from_ssl_ids hassomething failed\n");
                goto err;
            } else if (hassomething == 0) {
                printf("read_from_ssl_ids hassomething nothing...\n");
            } else {
                numtimeout = 0;
                printf("read_from_ssl_ids hassomething %d...\n", hassomething);
                if (h3ssl.close_done) {
                    break;
                }
                h3ssl.restart = 0;
            }
        }
        if (h3ssl.close_done) {
            printf("Other side close without request\n");
            goto wait_close;
        }
        printf("end_headers_received!!!\n");
        if (!h3ssl.has_uni) {
            printf("Create uni\n");
            if (quic_server_h3streams(h3conn, &h3ssl) == -1) {
                fprintf(stderr, "quic_server_h3streams failed!\n");
                goto err;
            }
            h3ssl.has_uni = 1;
        }

        /* Build the response headers and select the body source. */
        make_nv(&resp[num_nv++], ":status", "200");
        h3ssl.ldata = get_file_length(&h3ssl);
        if (h3ssl.ldata == 0) {
            h3ssl.ptr_data = nulldata;
            h3ssl.ldata = nulldata_sz;
            snprintf(slength, sizeof(slength), "%zu", h3ssl.ldata);
            make_nv(&resp[num_nv++], "content-type", "text/html");
        } else if (h3ssl.ldata == INT_MAX) {
            /* "big": step_read_data() repeats this 4096-byte block forever. */
            snprintf(slength, sizeof(slength), "%zu", h3ssl.ldata);
            h3ssl.ptr_data = (uint8_t *)malloc(4096);
            if (h3ssl.ptr_data == NULL) {
                fprintf(stderr, "malloc failed!\n");
                goto err;
            }
            memset(h3ssl.ptr_data, 'A', 4096);
        } else {
            snprintf(slength, sizeof(slength), "%zu", h3ssl.ldata);
            h3ssl.ptr_data = (uint8_t *)get_file_data(&h3ssl);
            if (h3ssl.ptr_data == NULL)
                abort();
            printf("before nghttp3_conn_submit_response on %llu for %s ...\n",
                   (unsigned long long)h3ssl.id_bidi, h3ssl.url);
            if (strstr(h3ssl.url, ".png"))
                make_nv(&resp[num_nv++], "content-type", "image/png");
            else if (strstr(h3ssl.url, ".ico"))
                make_nv(&resp[num_nv++], "content-type", "image/vnd.microsoft.icon");
            else if (strstr(h3ssl.url, ".htm"))
                make_nv(&resp[num_nv++], "content-type", "text/html");
            else
                make_nv(&resp[num_nv++], "content-type", "application/octet-stream");
            make_nv(&resp[num_nv++], "content-length", slength);
        }

        dr.read_data = step_read_data;
        if (nghttp3_conn_submit_response(h3conn, h3ssl.id_bidi, resp, num_nv, &dr)) {
            fprintf(stderr, "nghttp3_conn_submit_response failed!\n");
            goto err;
        }
        printf("nghttp3_conn_submit_response on %llu...\n", (unsigned long long)h3ssl.id_bidi);

        /* Drain everything nghttp3 wants to send onto the QUIC streams. */
        for (;;) {
            nghttp3_vec vec[256];
            nghttp3_ssize sveccnt;
            /* Initialized: they are printed even when writev fails. */
            int fin = 0, i;
            int64_t streamid = -1;

            sveccnt = nghttp3_conn_writev_stream(h3conn, &streamid, &fin, vec,
                                                 nghttp3_arraylen(vec));
            if (sveccnt <= 0) {
                printf("nghttp3_conn_writev_stream done: %ld stream: %llu fin %d\n",
                       (long int)sveccnt,
                       (unsigned long long)streamid,
                       fin);
                if (streamid != -1 && fin) {
                    printf("Sending end data on %llu fin %d\n",
                           (unsigned long long)streamid, fin);
                    nghttp3_conn_add_write_offset(h3conn, streamid, 0);
                    continue;
                }
                if (!h3ssl.datadone)
                    goto err;
                else
                    break;
            }
            printf("nghttp3_conn_writev_stream: %ld fin: %d\n", (long int)sveccnt, fin);
            for (i = 0; i < sveccnt; i++) {
                size_t numbytes = vec[i].len;
                int flagwrite = 0;

                printf("quic_server_write on %llu for %lu\n",
                       (unsigned long long)streamid, (unsigned long)vec[i].len);
                if (fin && i == sveccnt - 1)
                    flagwrite = SSL_WRITE_FLAG_CONCLUDE;
                if (!quic_server_write(&h3ssl, streamid, vec[i].base,
                                       vec[i].len, flagwrite, &numbytes)) {
                    fprintf(stderr, "quic_server_write failed!\n");
                    goto err;
                }
            }
            if (nghttp3_conn_add_write_offset(
                    h3conn, streamid,
                    (size_t)nghttp3_vec_len(vec, (size_t)sveccnt))) {
                fprintf(stderr, "nghttp3_conn_add_write_offset failed!\n");
                goto err;
            }
        }
        printf("nghttp3_conn_submit_response DONE!!!\n");

        if (h3ssl.datadone) {
            if (!h3ssl.close_done) {
                set_id_status(h3ssl.id_bidi, SERVERCLOSED, &h3ssl);
                h3ssl.close_wait = 1;
            }
        } else {
            printf("nghttp3_conn_submit_response still not finished\n");
        }

    wait_close:
        /* Wait for the peer to close, or for a new request/connection. */
        hasnothing = 0;
        for (;;) {
            if (!hasnothing) {
                SSL *newssl = get_ids_connection(&h3ssl);

                printf("hasnothing nothing WAIT %d!!!\n", h3ssl.close_done);
                if (newssl == NULL)
                    newssl = listener;
                ret = wait_for_activity(newssl);
                if (ret == -1)
                    goto err;
                if (ret == 0)
                    printf("hasnothing timeout\n");
                handle_events_from_ids(&h3ssl);
            }
            hasnothing = read_from_ssl_ids(&h3conn, &h3ssl);
            if (hasnothing == -1) {
                printf("hasnothing failed\n");
                break;
            } else if (hasnothing == 0) {
                printf("hasnothing nothing\n");
                continue;
            } else {
                printf("hasnothing something\n");
                if (h3ssl.done) {
                    printf("hasnothing something... DONE\n");
                    hassomething = 1;
                    break;
                }
                if (h3ssl.new_conn) {
                    printf("hasnothing something... NEW CONN\n");
                    h3ssl.new_conn = 0;
                    goto newconn;
                }
                if (h3ssl.restart) {
                    printf("hasnothing something... RESTART\n");
                    h3ssl.restart = 0;
                    goto restart;
                }
                if (are_all_clientid_closed(&h3ssl)) {
                    printf("hasnothing something... DONE other side closed\n");
                    hassomething = 0;
                    break;
                }
            }
        }

        /* Tear down the connection's streams and the connection itself. */
        close_all_ids(&h3ssl);
        ssl = get_ids_connection(&h3ssl);
        if (ssl != NULL) {
            SSL_free(ssl);
            replace_ids_connection(&h3ssl, ssl, NULL);
        }
        hassomething = 0;
    }
    ok = 1;
err:
    if (!ok)
        ERR_print_errors_fp(stderr);
    /* Release the HTTP/3 session (NULL-safe) and any heap response body. */
    nghttp3_conn_del(h3conn);
    if (h3ssl.ptr_data != NULL && h3ssl.ptr_data != nulldata)
        free(h3ssl.ptr_data);
    SSL_free(listener);
    return ok;
}
/*
 * Entry point. Usage: <prog> <port> <server.crt> <server.key>.
 * Returns 0 on success, or 1 on a usage, setup or server error (OpenSSL
 * errors are printed).
 */
int main(int argc, char **argv)
{
    int rc = 1;
    SSL_CTX *ctx = NULL;
    int fd = -1;
    unsigned long port;
    char *end = NULL;

    if (argc < 4) {
        fprintf(stderr, "usage: %s <port> <server.crt> <server.key>\n",
                argv[0]);
        goto err;
    }
    if ((ctx = create_ctx(argv[2], argv[3])) == NULL)
        goto err;

    /*
     * The whole argument must be a number in 1..65535. Without the endptr
     * check, strtoul would accept trailing junk such as "443x".
     */
    port = strtoul(argv[1], &end, 0);
    if (end == argv[1] || *end != '\0' || port == 0 || port > UINT16_MAX) {
        fprintf(stderr, "invalid port: %s\n", argv[1]);
        goto err;
    }
    if ((fd = create_socket((uint16_t)port)) < 0)
        goto err;
    if (!run_quic_server(ctx, fd))
        goto err;
    rc = 0;
err:
    if (rc != 0)
        ERR_print_errors_fp(stderr);
    SSL_CTX_free(ctx);
    if (fd != -1)
        BIO_closesocket(fd);
    return rc;
}