#include <err.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "tls13_handshake.h"
/*
 * Number of distinct values a uint8_t flag set can take, so a table indexed
 * by an accumulated flag set has one row for every possible combination.
 */
#define MAX_FLAGS (UINT8_MAX + 1)
/*
 * One node/edge of the handshake state graph (see stateinfo below): the
 * message type reached, plus the flag bits that control whether and how
 * build_table() may traverse to it.
 */
struct child {
/* Message type of this node; an mt of 0 terminates a stateinfo row. */
enum tls13_message_type mt;
/* OR'd into the accumulated flag set when this node is visited. */
uint8_t flag;
/*
 * If nonzero, the edge is only followed when at least one of these bits is
 * already in the accumulated flag set; generate_graphics() draws one edge
 * per bit.
 */
uint8_t forced;
/* If nonzero, the edge is skipped when any of these bits is already set. */
uint8_t illegal;
};
/*
 * Handshake state graph: stateinfo[mt] lists the possible successors of
 * message type mt.  Unused trailing entries of each row are implicitly
 * zero-initialized, and every walker (build_table(), generate_graphics())
 * stops at the first entry whose .mt is 0.
 * NOTE(review): this relies on no real message type having the value 0
 * (presumably INVALID == 0) - confirm against tls13_handshake.h.
 */
static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
[CLIENT_HELLO] = {
{
.mt = SERVER_HELLO_RETRY_REQUEST,
},
{
/* Going straight to SERVER_HELLO records WITHOUT_HRR. */
.mt = SERVER_HELLO,
.flag = WITHOUT_HRR,
},
},
[SERVER_HELLO_RETRY_REQUEST] = {
{
.mt = CLIENT_HELLO_RETRY,
},
},
[CLIENT_HELLO_RETRY] = {
{
.mt = SERVER_HELLO,
},
},
[SERVER_HELLO] = {
{
.mt = SERVER_ENCRYPTED_EXTENSIONS,
},
},
[SERVER_ENCRYPTED_EXTENSIONS] = {
{
.mt = SERVER_CERTIFICATE_REQUEST,
},
/* Skipping the certificate request records WITHOUT_CR. */
{ .mt = SERVER_CERTIFICATE,
.flag = WITHOUT_CR,
},
{
/* Skipping the certificate messages entirely records WITH_PSK. */
.mt = SERVER_FINISHED,
.flag = WITH_PSK,
},
},
[SERVER_CERTIFICATE_REQUEST] = {
{
.mt = SERVER_CERTIFICATE,
},
},
[SERVER_CERTIFICATE] = {
{
.mt = SERVER_CERTIFICATE_VERIFY,
},
},
[SERVER_CERTIFICATE_VERIFY] = {
{
.mt = SERVER_FINISHED,
},
},
[SERVER_FINISHED] = {
{
/* Only reachable directly if WITHOUT_CR or WITH_PSK was recorded... */
.mt = CLIENT_FINISHED,
.forced = WITHOUT_CR | WITH_PSK,
},
{
/* ...and CLIENT_CERTIFICATE is then ruled out. */
.mt = CLIENT_CERTIFICATE,
.illegal = WITHOUT_CR | WITH_PSK,
},
},
[CLIENT_CERTIFICATE] = {
{
.mt = CLIENT_FINISHED,
},
{
.mt = CLIENT_CERTIFICATE_VERIFY,
.flag = WITH_CCV,
},
},
[CLIENT_CERTIFICATE_VERIFY] = {
{
.mt = CLIENT_FINISHED,
},
},
[CLIENT_FINISHED] = {
{
.mt = APPLICATION_DATA,
},
},
/* Terminal state: an explicitly empty successor list. */
[APPLICATION_DATA] = {
{
.mt = 0,
},
},
};
/* Number of rows in stateinfo; not referenced by the code shown here. */
const size_t stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
/* Prototypes for the functions defined later in this file. */
void build_table(enum tls13_message_type
table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
struct child current, struct child end,
struct child path[], uint8_t flags, unsigned int depth);
size_t count_handshakes(void);
void edge(enum tls13_message_type start,
enum tls13_message_type end, uint8_t flag);
const char *flag2str(uint8_t flag);
void flag_label(uint8_t flag);
void forced_edges(enum tls13_message_type start,
enum tls13_message_type end, uint8_t forced);
int generate_graphics(void);
void fprint_entry(FILE *stream,
enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
uint8_t flags);
void fprint_flags(FILE *stream, uint8_t flags);
const char *mt2str(enum tls13_message_type mt);
void usage(void);
int verify_table(enum tls13_message_type
table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print);
/*
 * Return the symbolic name of a single handshake flag bit, or "UNKNOWN"
 * for a bit with no name.  Exits via errx() if more than one bit is set.
 */
const char *
flag2str(uint8_t flag)
{
	/* x & (x - 1) clears the lowest set bit; nonzero means >1 bit. */
	if ((flag & (flag - 1)) != 0)
		errx(1, "more than one bit is set");

	switch (flag) {
	case INITIAL:
		return "INITIAL";
	case NEGOTIATED:
		return "NEGOTIATED";
	case WITHOUT_CR:
		return "WITHOUT_CR";
	case WITHOUT_HRR:
		return "WITHOUT_HRR";
	case WITH_PSK:
		return "WITH_PSK";
	case WITH_CCV:
		return "WITH_CCV";
	case WITH_0RTT:
		return "WITH_0RTT";
	default:
		return "UNKNOWN";
	}
}
/*
 * Return the enumerator name of a TLS 1.3 message type as a string, or
 * "UNKNOWN" for a value with no matching case.
 */
const char *
mt2str(enum tls13_message_type mt)
{
	switch (mt) {
	case INVALID:
		return "INVALID";
	case CLIENT_HELLO:
		return "CLIENT_HELLO";
	case CLIENT_HELLO_RETRY:
		return "CLIENT_HELLO_RETRY";
	case CLIENT_END_OF_EARLY_DATA:
		return "CLIENT_END_OF_EARLY_DATA";
	case CLIENT_CERTIFICATE:
		return "CLIENT_CERTIFICATE";
	case CLIENT_CERTIFICATE_VERIFY:
		return "CLIENT_CERTIFICATE_VERIFY";
	case CLIENT_FINISHED:
		return "CLIENT_FINISHED";
	case SERVER_HELLO:
		return "SERVER_HELLO";
	case SERVER_HELLO_RETRY_REQUEST:
		return "SERVER_HELLO_RETRY_REQUEST";
	case SERVER_ENCRYPTED_EXTENSIONS:
		return "SERVER_ENCRYPTED_EXTENSIONS";
	case SERVER_CERTIFICATE:
		return "SERVER_CERTIFICATE";
	case SERVER_CERTIFICATE_VERIFY:
		return "SERVER_CERTIFICATE_VERIFY";
	case SERVER_CERTIFICATE_REQUEST:
		return "SERVER_CERTIFICATE_REQUEST";
	case SERVER_FINISHED:
		return "SERVER_FINISHED";
	case APPLICATION_DATA:
		return "APPLICATION_DATA";
	case TLS13_NUM_MESSAGE_TYPES:
		return "TLS13_NUM_MESSAGE_TYPES";
	default:
		return "UNKNOWN";
	}
}
/*
 * Write a flag set to stream as "NAME | NAME | ..." in ascending bit order.
 * An empty set is written as flag2str(0).
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char *sep = "";
	unsigned int bit;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(flags));
		return;
	}

	/* Visit each of the 8 bit positions, lowest first. */
	for (bit = 1; bit <= 0x80; bit <<= 1) {
		if ((flags & bit) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str((uint8_t)bit));
		sep = " | ";
	}
}
/*
 * Write one handshake path as a C designated-initializer entry,
 * "\t[FLAGS] = {\n\t\tMSG,\n...\t},\n".  The path ends at the first 0
 * element or after TLS13_NUM_MESSAGE_TYPES elements.
 */
void
fprint_entry(FILE *stream,
    enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
{
	int n = 0;

	fprintf(stream, "\t[");
	fprint_flags(stream, flags);
	fprintf(stream, "] = {\n");

	while (n < TLS13_NUM_MESSAGE_TYPES && path[n] != 0) {
		fprintf(stream, "\t\t%s,\n", mt2str(path[n]));
		n++;
	}

	fprintf(stream, "\t},\n");
}
/*
 * Emit one Graphviz edge statement "\tSTART -> END[ [label=...]];" on
 * stdout, labelled with flag when flag is nonzero.
 */
void
edge(enum tls13_message_type start, enum tls13_message_type end,
    uint8_t flag)
{
	const char *from = mt2str(start);
	const char *to = mt2str(end);

	printf("\t%s -> %s", from, to);
	flag_label(flag);
	puts(";");
}
/* Print a Graphviz " [label=\"FLAG\"]" attribute for a nonzero flag. */
void
flag_label(uint8_t flag)
{
	/* Unflagged edges carry no label. */
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
/*
 * Emit one start -> end edge per bit set in forced, each labelled with that
 * single bit.  Nothing is emitted when forced is 0.
 */
void
forced_edges(enum tls13_message_type start, enum tls13_message_type end,
    uint8_t forced)
{
	unsigned int mask;

	for (mask = 1; mask <= 0x80; mask <<= 1) {
		if ((forced & mask) != 0)
			edge(start, end, (uint8_t)mask);
	}
}
int
generate_graphics(void)
{
enum tls13_message_type start, end;
unsigned int child;
uint8_t flag;
uint8_t forced;
printf("digraph G {\n");
printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
for (child = 0; stateinfo[start][child].mt != 0; child++) {
end = stateinfo[start][child].mt;
flag = stateinfo[start][child].flag;
forced = stateinfo[start][child].forced;
if (forced == 0)
edge(start, end, flag);
else
forced_edges(start, end, forced);
}
}
printf("}\n");
return 0;
}
/*
 * Reference handshake table and its row count, defined outside this file.
 * NOTE(review): if the library defines these objects as const, these
 * declarations should carry const too so the types match — confirm.
 */
extern enum tls13_message_type handshakes[][TLS13_NUM_MESSAGE_TYPES];
extern size_t handshake_count;

/*
 * Return the number of rows of handshakes[] whose first message is not
 * INVALID, i.e. the number of populated reference handshakes.
 */
size_t
count_handshakes(void)
{
	size_t idx, valid = 0;

	for (idx = 0; idx < handshake_count; idx++)
		valid += (handshakes[idx][0] != INVALID) ? 1 : 0;

	return valid;
}
void
build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
struct child current, struct child end, struct child path[], uint8_t flags,
unsigned int depth)
{
unsigned int i;
if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
errx(1, "recursed too deeply");
path[depth++] = current;
flags |= current.flag;
if (current.mt != end.mt) {
for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
struct child child = stateinfo[current.mt][i];
int forced = stateinfo[current.mt][i].forced;
int illegal = stateinfo[current.mt][i].illegal;
if ((forced == 0 || (forced & flags)) &&
(illegal == 0 || !(illegal & flags)))
build_table(table, child, end, path, flags,
depth);
}
return;
}
if (flags == 0)
errx(1, "path does not set flags");
if (table[flags][0] != 0)
errx(1, "path traversed twice");
for (i = 0; i < depth; i++)
table[flags][i] = path[i].mt;
}
int
verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
int print)
{
int success = 1, i;
size_t num_valid, num_found = 0;
uint8_t flags = 0;
do {
if (table[flags][0] == 0)
continue;
num_found++;
for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
if (table[flags][i] != handshakes[flags][i]) {
fprintf(stderr,
"incorrect entry %d of handshake ", i);
fprint_flags(stderr, flags);
fprintf(stderr, "\n");
success = 0;
}
}
if (print)
fprint_entry(stdout, table[flags], flags);
} while(++flags != 0);
num_valid = count_handshakes();
if (num_valid != num_found) {
fprintf(stderr,
"incorrect number of handshakes: want %zu, got %zu.\n",
num_valid, num_found);
success = 0;
}
return success;
}
/* Print the command-line synopsis on stderr and exit with status 1. */
void
usage(void)
{
	fputs("usage: handshake_table [-C | -g]\n", stderr);
	exit(1);
}
int
main(int argc, char *argv[])
{
static enum tls13_message_type
hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = {
[INITIAL] = {
CLIENT_HELLO,
SERVER_HELLO_RETRY_REQUEST,
CLIENT_HELLO_RETRY,
SERVER_HELLO,
},
};
struct child start = {
.mt = CLIENT_HELLO,
};
struct child end = {
.mt = APPLICATION_DATA,
};
struct child path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
uint8_t flags = NEGOTIATED;
unsigned int depth = 0;
int ch, graphviz = 0, print = 0;
while ((ch = getopt(argc, argv, "Cg")) != -1) {
switch (ch) {
case 'C':
print = 1;
break;
case 'g':
graphviz = 1;
break;
default:
usage();
}
}
argc -= optind;
argv += optind;
if (argc != 0)
usage();
if (graphviz && print)
usage();
if (graphviz)
return generate_graphics();
build_table(hs_table, start, end, path, flags, depth);
if (!verify_table(hs_table, print))
return 1;
return 0;
}