#include <sys/types.h>
#include <sys/un.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <unistd.h>
#include <errno.h>
#include "ssh.h"
#include "sshbuf.h"
#include "sshkey.h"
#include "authfd.h"
#include "log.h"
#include "misc.h"
#include "atomicio.h"
#include "ssherr.h"
#include "xmalloc.h"
/* Upper bound on the identity count accepted in an identities answer. */
#define MAX_AGENT_IDENTITIES 2048
/* Upper bound, in bytes, on the body of a single agent reply. */
#define MAX_AGENT_REPLY_LEN (256 * 1024)

/*
 * True if message type x is one of the failure replies an agent may send.
 * The argument is parenthesised at every use so that an expression such
 * as agent_failed(a | b) compares the whole value instead of binding to
 * the first operator.
 */
#define agent_failed(x) \
	(((x) == SSH_AGENT_FAILURE) || \
	((x) == SSH_AGENT_EXTENSION_FAILURE) || \
	((x) == SSH_COM_AGENT2_FAILURE) || \
	((x) == SSH2_AGENT_FAILURE))
/*
 * Translate a generic agent reply type into a status code: any failure
 * type gives SSH_ERR_AGENT_FAILURE, SSH_AGENT_SUCCESS gives 0, and
 * anything else is treated as a malformed reply (SSH_ERR_INVALID_FORMAT).
 */
static int
decode_reply(u_char type)
{
	/* Failure types are checked first, exactly as before. */
	if (agent_failed(type))
		return SSH_ERR_AGENT_FAILURE;
	return (type == SSH_AGENT_SUCCESS) ? 0 : SSH_ERR_INVALID_FORMAT;
}
/*
 * Connect to an agent listening on the AF_UNIX stream socket at authsocket.
 *
 * On success returns 0. If fdp is non-NULL, the connected close-on-exec
 * descriptor is stored in *fdp; if fdp is NULL the socket is closed
 * again, so the call only tests that a connection can be made.
 *
 * On failure returns SSH_ERR_SYSTEM_ERROR with errno describing the
 * cause. That includes a path too long for sun_path (ENAMETOOLONG).
 */
int
ssh_get_authentication_socket_path(const char *authsocket, int *fdp)
{
	int sock, oerrno;
	struct sockaddr_un sunaddr;

	debug3_f("path '%s'", authsocket);
	memset(&sunaddr, 0, sizeof(sunaddr));
	sunaddr.sun_family = AF_UNIX;
	/*
	 * Refuse paths that do not fit: silently truncating them would
	 * connect to a different socket than the one requested.
	 */
	if (strlcpy(sunaddr.sun_path, authsocket,
	    sizeof(sunaddr.sun_path)) >= sizeof(sunaddr.sun_path)) {
		errno = ENAMETOOLONG;
		return SSH_ERR_SYSTEM_ERROR;
	}

	if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) == -1)
		return SSH_ERR_SYSTEM_ERROR;

	/* Mark close-on-exec, then connect; preserve errno across close(). */
	if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1 ||
	    connect(sock, (struct sockaddr *)&sunaddr, sizeof(sunaddr)) == -1) {
		oerrno = errno;
		close(sock);
		errno = oerrno;
		return SSH_ERR_SYSTEM_ERROR;
	}
	if (fdp != NULL)
		*fdp = sock;
	else
		close(sock);
	return 0;
}
/*
 * Connect to the agent named by the SSH_AUTHSOCKET_ENV_NAME environment
 * variable. *fdp (if fdp is non-NULL) is set to -1 up front and only
 * receives a descriptor when the connection succeeds. Returns
 * SSH_ERR_AGENT_NOT_PRESENT when the variable is unset or empty;
 * otherwise returns the result of ssh_get_authentication_socket_path().
 */
int
ssh_get_authentication_socket(int *fdp)
{
	const char *path;

	if (fdp != NULL)
		*fdp = -1;

	path = getenv(SSH_AUTHSOCKET_ENV_NAME);
	if (path != NULL && path[0] != '\0')
		return ssh_get_authentication_socket_path(path, fdp);
	return SSH_ERR_AGENT_NOT_PRESENT;
}
/*
 * Send one length-prefixed request to the agent on sock and read one
 * length-prefixed reply into reply. The request is written before reply
 * is reset, so the same buffer may be passed for both.
 *
 * Returns 0 on success, SSH_ERR_AGENT_COMMUNICATION on short I/O,
 * SSH_ERR_INVALID_FORMAT if the announced reply length exceeds
 * MAX_AGENT_REPLY_LEN, or an sshbuf error code.
 */
static int
ssh_request_reply(int sock, struct sshbuf *request, struct sshbuf *reply)
{
	char buf[1024];
	size_t want, chunk;
	int r;

	/* Frame: 32-bit big-endian length, then the request bytes. */
	want = sshbuf_len(request);
	POKE_U32(buf, want);
	if (atomicio(vwrite, sock, buf, 4) != 4)
		return SSH_ERR_AGENT_COMMUNICATION;
	if (atomicio(vwrite, sock, sshbuf_mutable_ptr(request), want) != want)
		return SSH_ERR_AGENT_COMMUNICATION;

	/* Read the reply length and refuse anything over the limit. */
	if (atomicio(read, sock, buf, 4) != 4)
		return SSH_ERR_AGENT_COMMUNICATION;
	want = PEEK_U32(buf);
	if (want > MAX_AGENT_REPLY_LEN)
		return SSH_ERR_INVALID_FORMAT;

	/* Pull the body through the stack buffer, one buffer-full at a time. */
	sshbuf_reset(reply);
	for (; want != 0; want -= chunk) {
		chunk = (want < sizeof(buf)) ? want : sizeof(buf);
		if (atomicio(read, sock, buf, chunk) != chunk)
			return SSH_ERR_AGENT_COMMUNICATION;
		if ((r = sshbuf_put(reply, buf, chunk)) != 0)
			return r;
	}
	return 0;
}
/*
 * Send request and expect a generic success/failure reply. The reply
 * type is mapped through decode_reply(). Returns 0 only if the agent
 * answered SSH_AGENT_SUCCESS.
 */
static int
ssh_request_reply_decode(int sock, struct sshbuf *request)
{
	struct sshbuf *resp;
	u_char code;
	int rc;

	resp = sshbuf_new();
	if (resp == NULL)
		return SSH_ERR_ALLOC_FAIL;

	rc = ssh_request_reply(sock, request, resp);
	if (rc == 0)
		rc = sshbuf_get_u8(resp, &code);
	if (rc == 0)
		rc = decode_reply(code);

	sshbuf_free(resp);
	return rc;
}
/*
 * Close sock, but only if SSH_AUTHSOCKET_ENV_NAME is present in the
 * environment; otherwise the descriptor is left open.
 * NOTE(review): the rationale for the environment check is not visible
 * here — confirm with callers.
 */
void
ssh_close_authentication_socket(int sock)
{
	if (getenv(SSH_AUTHSOCKET_ENV_NAME) == NULL)
		return;
	close(sock);
}
/*
 * Lock (lock != 0) or unlock (lock == 0) the agent using password.
 * Returns 0 if the agent replied with success, otherwise an SSH_ERR_*
 * code.
 */
int
ssh_lock_agent(int sock, int lock, const char *password)
{
	struct sshbuf *req;
	int rc;

	if ((req = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;

	rc = sshbuf_put_u8(req, lock ? SSH_AGENTC_LOCK : SSH_AGENTC_UNLOCK);
	if (rc == 0)
		rc = sshbuf_put_cstring(req, password);
	if (rc == 0)
		rc = ssh_request_reply_decode(sock, req);

	sshbuf_free(req);
	return rc;
}
/*
 * Consume one (key blob, comment) pair from ids.
 *
 * On success the parsed key is stored in *keyp. If commentp is non-NULL,
 * the comment is handed to the caller in *commentp (caller frees);
 * otherwise it is discarded. Returns 0 or the first SSH_ERR_* code
 * encountered.
 */
static int
deserialise_identity2(struct sshbuf *ids, struct sshkey **keyp, char **commentp)
{
	const u_char *keyblob;
	size_t keylen;
	char *label = NULL;
	int rc;

	rc = sshbuf_get_string_direct(ids, &keyblob, &keylen);
	if (rc == 0)
		rc = sshbuf_get_cstring(ids, &label, NULL);
	if (rc == 0)
		rc = sshkey_from_blob(keyblob, keylen, keyp);
	if (rc == 0 && commentp != NULL) {
		/* Ownership of the comment moves to the caller. */
		*commentp = label;
		label = NULL;
	}
	free(label);
	return rc;
}
/*
 * Request the agent's identities (SSH2_AGENTC_REQUEST_IDENTITIES).
 *
 * On success returns 0 and stores a newly allocated list in *idlp, which
 * the caller releases with ssh_free_identitylist(). Keys whose type this
 * build cannot parse (SSH_ERR_KEY_TYPE_UNKNOWN) are skipped.
 *
 * Returns SSH_ERR_AGENT_NO_IDENTITIES if the agent reports zero keys,
 * SSH_ERR_AGENT_FAILURE if the agent refused, SSH_ERR_INVALID_FORMAT for
 * a malformed or oversized answer, or another SSH_ERR_* code. *idlp is
 * only written on success.
 */
int
ssh_fetch_identitylist(int sock, struct ssh_identitylist **idlp)
{
	u_char type;
	uint32_t num, i;
	struct sshbuf *msg;
	struct ssh_identitylist *idl = NULL;
	int r;

	if ((msg = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if ((r = sshbuf_put_u8(msg, SSH2_AGENTC_REQUEST_IDENTITIES)) != 0)
		goto out;
	/* The request buffer is reused to receive the reply. */
	if ((r = ssh_request_reply(sock, msg, msg)) != 0)
		goto out;
	if ((r = sshbuf_get_u8(msg, &type)) != 0)
		goto out;
	if (agent_failed(type)) {
		r = SSH_ERR_AGENT_FAILURE;
		goto out;
	} else if (type != SSH2_AGENT_IDENTITIES_ANSWER) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	if ((r = sshbuf_get_u32(msg, &num)) != 0)
		goto out;
	if (num > MAX_AGENT_IDENTITIES) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	if (num == 0) {
		r = SSH_ERR_AGENT_NO_IDENTITIES;
		goto out;
	}

	if ((idl = calloc(1, sizeof(*idl))) == NULL ||
	    (idl->keys = calloc(num, sizeof(*idl->keys))) == NULL ||
	    (idl->comments = calloc(num, sizeof(*idl->comments))) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}
	for (i = 0; i < num;) {
		r = deserialise_identity2(msg, &idl->keys[i],
		    &idl->comments[i]);
		if (r == SSH_ERR_KEY_TYPE_UNKNOWN) {
			/* Skip unsupported key types; slot i is reused. */
			num--;
			continue;
		}
		if (r != 0) {
			/*
			 * Entries [0, i) are fully populated. Record them so
			 * ssh_free_identitylist() below releases them; nkeys
			 * is otherwise still 0 from calloc and they would leak.
			 */
			idl->nkeys = i;
			goto out;
		}
		i++;
	}
	idl->nkeys = num;
	*idlp = idl;
	idl = NULL;
	r = 0;
 out:
	sshbuf_free(msg);
	if (idl != NULL)
		ssh_free_identitylist(idl);
	return r;
}
/*
 * Release an identity list from ssh_fetch_identitylist(): the first
 * idl->nkeys keys and comments, both arrays, and the list itself.
 * NULL is accepted and ignored, as are NULL keys/comments arrays.
 */
void
ssh_free_identitylist(struct ssh_identitylist *idl)
{
	size_t n;

	if (idl == NULL)
		return;
	if (idl->keys != NULL) {
		for (n = 0; n < idl->nkeys; n++)
			sshkey_free(idl->keys[n]);
	}
	if (idl->comments != NULL) {
		for (n = 0; n < idl->nkeys; n++)
			free(idl->comments[n]);
	}
	free(idl->keys);
	free(idl->comments);
	free(idl);
}
/*
 * Return 0 if the agent holds an identity whose public part equals key,
 * SSH_ERR_KEY_NOT_FOUND if it does not, or the error from
 * ssh_fetch_identitylist() if the list could not be obtained.
 */
int
ssh_agent_has_key(int sock, const struct sshkey *key)
{
	struct ssh_identitylist *ids = NULL;
	size_t n;
	int rc;

	if ((rc = ssh_fetch_identitylist(sock, &ids)) != 0)
		return rc;

	/* Scan until the first match; rc becomes 0 when one is found. */
	rc = SSH_ERR_KEY_NOT_FOUND;
	for (n = 0; rc != 0 && n < ids->nkeys; n++) {
		if (sshkey_equal_public(ids->keys[n], key))
			rc = 0;
	}

	ssh_free_identitylist(ids);
	return rc;
}
/*
 * Map a requested signature algorithm name to agent sign-request flags.
 * Only RSA keys (plain or certificate) have non-zero flags, selecting the
 * SHA-256 or SHA-512 variant. Everything else, including alg == NULL,
 * maps to 0.
 */
static u_int
agent_encode_alg(const struct sshkey *key, const char *alg)
{
	static const struct {
		const char *name;
		u_int flag;
	} rsa_algs[] = {
		{ "rsa-sha2-256", SSH_AGENT_RSA_SHA2_256 },
		{ "rsa-sha2-256-cert-v01@openssh.com", SSH_AGENT_RSA_SHA2_256 },
		{ "rsa-sha2-512", SSH_AGENT_RSA_SHA2_512 },
		{ "rsa-sha2-512-cert-v01@openssh.com", SSH_AGENT_RSA_SHA2_512 },
	};
	size_t n;

	if (alg == NULL || sshkey_type_plain(key->type) != KEY_RSA)
		return 0;
	for (n = 0; n < sizeof(rsa_algs) / sizeof(rsa_algs[0]); n++) {
		if (strcmp(alg, rsa_algs[n].name) == 0)
			return rsa_algs[n].flag;
	}
	return 0;
}
/*
 * Ask the agent to sign datalen bytes at data with the private half of
 * key. For RSA keys, alg is also turned into request flags via
 * agent_encode_alg().
 *
 * On success returns 0 and hands the caller a newly allocated signature
 * in *sigp / *lenp. On any failure *sigp is NULL and *lenp is 0.
 *
 * Unless key is a security-key type, the returned signature's type is
 * checked against alg. compat is accepted but not used here.
 */
int
ssh_agent_sign(int sock, const struct sshkey *key,
    u_char **sigp, size_t *lenp,
    const u_char *data, size_t datalen, const char *alg, u_int compat)
{
	struct sshbuf *msg;
	u_char *sig = NULL;
	u_char rtype = 0;
	size_t siglen = 0;
	u_int flags;
	int r;

	*sigp = NULL;
	*lenp = 0;
	if (datalen > SSH_KEY_MAX_SIGN_DATA_SIZE)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((msg = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	flags = agent_encode_alg(key, alg);

	/* Build and send the sign request, then read the answer type. */
	if ((r = sshbuf_put_u8(msg, SSH2_AGENTC_SIGN_REQUEST)) != 0 ||
	    (r = sshkey_puts(key, msg)) != 0 ||
	    (r = sshbuf_put_string(msg, data, datalen)) != 0 ||
	    (r = sshbuf_put_u32(msg, flags)) != 0 ||
	    (r = ssh_request_reply(sock, msg, msg)) != 0 ||
	    (r = sshbuf_get_u8(msg, &rtype)) != 0)
		goto out;
	if (agent_failed(rtype)) {
		r = SSH_ERR_AGENT_FAILURE;
		goto out;
	}
	if (rtype != SSH2_AGENT_SIGN_RESPONSE) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	if ((r = sshbuf_get_string(msg, &sig, &siglen)) != 0)
		goto out;
	/* Confirm the agent used the requested algorithm (not for sk keys). */
	if (!sshkey_is_sk(key)) {
		if ((r = sshkey_check_sigtype(sig, siglen, alg)) != 0)
			goto out;
	}
	/* Transfer ownership of the signature to the caller. */
	*sigp = sig;
	*lenp = siglen;
	sig = NULL;
	siglen = 0;
	r = 0;
 out:
	freezero(sig, siglen);
	sshbuf_free(msg);
	return r;
}
/*
 * Append one destination-constraint hop to m as a single string.
 * Its contents are: user, hostname, an empty string, then each of the
 * dch->nkeys keys followed by a 0/1 byte from key_is_ca.
 * Returns 0 or the first SSH_ERR_* code encountered.
 */
static int
encode_dest_constraint_hop(struct sshbuf *m,
    const struct dest_constraint_hop *dch)
{
	struct sshbuf *hop;
	u_int k;
	int r;

	if ((hop = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;

	r = sshbuf_put_cstring(hop, dch->user);
	if (r == 0)
		r = sshbuf_put_cstring(hop, dch->hostname);
	if (r == 0)
		r = sshbuf_put_string(hop, NULL, 0);
	for (k = 0; r == 0 && k < dch->nkeys; k++) {
		r = sshkey_puts(dch->keys[k], hop);
		if (r == 0)
			r = sshbuf_put_u8(hop, dch->key_is_ca[k] != 0);
	}
	if (r == 0)
		r = sshbuf_put_stringb(m, hop);

	sshbuf_free(hop);
	return r;
}
/*
 * Append one destination constraint to m as a single string.
 * Its contents are the encoded "from" hop, the encoded "to" hop, then an
 * empty string. Returns 0 or the first SSH_ERR_* code encountered.
 */
static int
encode_dest_constraint(struct sshbuf *m, const struct dest_constraint *dc)
{
	struct sshbuf *body;
	int rc;

	if ((body = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;

	rc = encode_dest_constraint_hop(body, &dc->from);
	if (rc == 0)
		rc = encode_dest_constraint_hop(body, &dc->to);
	if (rc == 0)
		rc = sshbuf_put_string(body, NULL, 0);
	if (rc == 0)
		rc = sshbuf_put_stringb(m, body);

	sshbuf_free(body);
	return rc;
}
/*
 * Append the requested key-usage constraints to message m, in this order.
 * Each constraint is a constraint-type byte followed by its payload:
 *   life != 0             SSH_AGENT_CONSTRAIN_LIFETIME, u32 life
 *   confirm != 0          SSH_AGENT_CONSTRAIN_CONFIRM (no payload)
 *   provider != NULL      extension "sk-provider@openssh.com", provider
 *   dest_constraints set  extension "restrict-destination-v00@openssh.com",
 *     (and ndest > 0)     then one string of encoded destination
 *                         constraints
 *   ncerts != 0           extension "associated-certs-v00@openssh.com",
 *                         a cert_only 0/1 byte, then one string of certs
 * Returns 0 on success or the first SSH_ERR_* code; m may already hold
 * some constraints when an error is returned.
 */
static int
encode_constraints(struct sshbuf *m, u_int life, u_int confirm,
const char *provider,
struct dest_constraint **dest_constraints, size_t ndest_constraints,
int cert_only, struct sshkey **certs, size_t ncerts)
{
int r;
/* Scratch buffer for the nested string payloads; freed at "out". */
struct sshbuf *b = NULL;
size_t i;
if (life != 0) {
if ((r = sshbuf_put_u8(m, SSH_AGENT_CONSTRAIN_LIFETIME)) != 0 ||
(r = sshbuf_put_u32(m, life)) != 0)
goto out;
}
if (confirm != 0) {
if ((r = sshbuf_put_u8(m, SSH_AGENT_CONSTRAIN_CONFIRM)) != 0)
goto out;
}
if (provider != NULL) {
if ((r = sshbuf_put_u8(m,
SSH_AGENT_CONSTRAIN_EXTENSION)) != 0 ||
(r = sshbuf_put_cstring(m,
"sk-provider@openssh.com")) != 0 ||
(r = sshbuf_put_cstring(m, provider)) != 0)
goto out;
}
if (dest_constraints != NULL && ndest_constraints > 0) {
/* Encode every constraint into b first, then emit b as one string. */
if ((b = sshbuf_new()) == NULL) {
r = SSH_ERR_ALLOC_FAIL;
goto out;
}
for (i = 0; i < ndest_constraints; i++) {
if ((r = encode_dest_constraint(b,
dest_constraints[i])) != 0)
goto out;
}
if ((r = sshbuf_put_u8(m,
SSH_AGENT_CONSTRAIN_EXTENSION)) != 0 ||
(r = sshbuf_put_cstring(m,
"restrict-destination-v00@openssh.com")) != 0 ||
(r = sshbuf_put_stringb(m, b)) != 0)
goto out;
/* Reset so b can be reallocated below without a double free at "out". */
sshbuf_free(b);
b = NULL;
}
if (ncerts != 0) {
/* Same pattern: all certificates go into b, sent as one string. */
if ((b = sshbuf_new()) == NULL) {
r = SSH_ERR_ALLOC_FAIL;
goto out;
}
for (i = 0; i < ncerts; i++) {
if ((r = sshkey_puts(certs[i], b)) != 0)
goto out;
}
if ((r = sshbuf_put_u8(m,
SSH_AGENT_CONSTRAIN_EXTENSION)) != 0 ||
(r = sshbuf_put_cstring(m,
"associated-certs-v00@openssh.com")) != 0 ||
(r = sshbuf_put_u8(m, cert_only != 0)) != 0 ||
(r = sshbuf_put_stringb(m, b)) != 0)
goto out;
sshbuf_free(b);
b = NULL;
}
r = 0;
out:
sshbuf_free(b);
return r;
}
/*
 * Add a private key, with the given comment, to the agent.
 *
 * If any of life, confirm, provider or dest_constraints is set, the key
 * is sent as SSH2_AGENTC_ADD_ID_CONSTRAINED followed by the constraints
 * from encode_constraints(). Otherwise it is sent as plain
 * SSH2_AGENTC_ADD_IDENTITY.
 *
 * Key types not listed in the switch are rejected with
 * SSH_ERR_INVALID_ARGUMENT before anything is sent. RSA and ECDSA types
 * are only accepted when built WITH_OPENSSL.
 *
 * Returns 0 if the agent replied with success, otherwise an SSH_ERR_*
 * code.
 */
int
ssh_add_identity_constrained(int sock, struct sshkey *key,
const char *comment, u_int life, u_int confirm,
const char *provider, struct dest_constraint **dest_constraints,
size_t ndest_constraints)
{
struct sshbuf *msg;
/*
 * NOTE(review): a non-NULL dest_constraints with ndest_constraints == 0
 * still selects the CONSTRAINED message, even though encode_constraints()
 * then emits no destination constraint for it.
 */
int r, constrained = (life || confirm || provider || dest_constraints);
u_char type;
if ((msg = sshbuf_new()) == NULL)
return SSH_ERR_ALLOC_FAIL;
switch (key->type) {
#ifdef WITH_OPENSSL
case KEY_RSA:
case KEY_RSA_CERT:
case KEY_ECDSA:
case KEY_ECDSA_CERT:
case KEY_ECDSA_SK:
case KEY_ECDSA_SK_CERT:
#endif
case KEY_ED25519:
case KEY_ED25519_CERT:
case KEY_ED25519_SK:
case KEY_ED25519_SK_CERT:
type = constrained ?
SSH2_AGENTC_ADD_ID_CONSTRAINED :
SSH2_AGENTC_ADD_IDENTITY;
/* Message body: type byte, serialized private key, comment. */
if ((r = sshbuf_put_u8(msg, type)) != 0 ||
(r = sshkey_private_serialize(key, msg)) != 0 ||
(r = sshbuf_put_cstring(msg, comment)) != 0)
goto out;
break;
default:
r = SSH_ERR_INVALID_ARGUMENT;
goto out;
}
/* No certificate constraints are sent from this entry point. */
if (constrained &&
(r = encode_constraints(msg, life, confirm, provider,
dest_constraints, ndest_constraints, 0, NULL, 0)) != 0)
goto out;
if ((r = ssh_request_reply_decode(sock, msg)) != 0)
goto out;
r = 0;
out:
sshbuf_free(msg);
return r;
}
/*
 * Ask the agent to remove the identity whose public key blob matches key.
 * Returns SSH_ERR_INVALID_ARGUMENT for a KEY_UNSPEC key, 0 if the agent
 * replied with success, or another SSH_ERR_* code.
 */
int
ssh_remove_identity(int sock, const struct sshkey *key)
{
	struct sshbuf *req;
	u_char *pub = NULL;
	size_t publen = 0;
	int rc;

	if ((req = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if (key->type == KEY_UNSPEC) {
		rc = SSH_ERR_INVALID_ARGUMENT;
		goto done;
	}
	if ((rc = sshkey_to_blob(key, &pub, &publen)) != 0 ||
	    (rc = sshbuf_put_u8(req, SSH2_AGENTC_REMOVE_IDENTITY)) != 0 ||
	    (rc = sshbuf_put_string(req, pub, publen)) != 0 ||
	    (rc = ssh_request_reply_decode(sock, req)) != 0)
		goto done;
	rc = 0;
 done:
	if (pub != NULL)
		freezero(pub, publen);
	sshbuf_free(req);
	return rc;
}
/*
 * Ask the agent to add (add != 0) or remove (add == 0) the smartcard
 * keys identified by reader_id, supplying pin.
 *
 * When adding with any of life, confirm, dest_constraints or certs set,
 * SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED is sent followed by the
 * encoded constraints.
 *
 * A remove request never carries constraints. SSH_AGENTC_REMOVE_SMARTCARD_KEY
 * consists only of the reader id and PIN, so trailing constraint data
 * would be malformed.
 *
 * Returns 0 if the agent replied with success, otherwise an SSH_ERR_*
 * code.
 */
int
ssh_update_card(int sock, int add, const char *reader_id, const char *pin,
    u_int life, u_int confirm,
    struct dest_constraint **dest_constraints, size_t ndest_constraints,
    int cert_only, struct sshkey **certs, size_t ncerts)
{
	struct sshbuf *msg;
	int r, constrained;
	u_char type;

	/* Constraints are only meaningful (and only encoded) when adding. */
	constrained = add &&
	    (life || confirm || dest_constraints || certs);
	if (add) {
		type = constrained ?
		    SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED :
		    SSH_AGENTC_ADD_SMARTCARD_KEY;
	} else
		type = SSH_AGENTC_REMOVE_SMARTCARD_KEY;

	if ((msg = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if ((r = sshbuf_put_u8(msg, type)) != 0 ||
	    (r = sshbuf_put_cstring(msg, reader_id)) != 0 ||
	    (r = sshbuf_put_cstring(msg, pin)) != 0)
		goto out;
	if (constrained &&
	    (r = encode_constraints(msg, life, confirm, NULL,
	    dest_constraints, ndest_constraints,
	    cert_only, certs, ncerts)) != 0)
		goto out;
	if ((r = ssh_request_reply_decode(sock, msg)) != 0)
		goto out;
	r = 0;
 out:
	sshbuf_free(msg);
	return r;
}
/*
 * Ask the agent to remove all identities. version == 1 sends
 * SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES; any other value sends
 * SSH2_AGENTC_REMOVE_ALL_IDENTITIES. Returns 0 if the agent replied
 * with success, otherwise an SSH_ERR_* code.
 */
int
ssh_remove_all_identities(int sock, int version)
{
	struct sshbuf *req;
	int rc;

	if ((req = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;

	if (version == 1)
		rc = sshbuf_put_u8(req, SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES);
	else
		rc = sshbuf_put_u8(req, SSH2_AGENTC_REMOVE_ALL_IDENTITIES);
	if (rc == 0)
		rc = ssh_request_reply_decode(sock, req);

	sshbuf_free(req);
	return rc;
}
/*
 * Send the "session-bind@openssh.com" extension request.
 * The request carries key, session_id and signature, then a 0/1
 * forwarding flag. Any NULL input yields SSH_ERR_INVALID_ARGUMENT.
 * Returns 0 if the agent replied with success, otherwise an SSH_ERR_*
 * code.
 */
int
ssh_agent_bind_hostkey(int sock, const struct sshkey *key,
    const struct sshbuf *session_id, const struct sshbuf *signature,
    int forwarding)
{
	struct sshbuf *req;
	int rc;

	if (key == NULL || session_id == NULL || signature == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((req = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;

	rc = sshbuf_put_u8(req, SSH_AGENTC_EXTENSION);
	if (rc == 0)
		rc = sshbuf_put_cstring(req, "session-bind@openssh.com");
	if (rc == 0)
		rc = sshkey_puts(key, req);
	if (rc == 0)
		rc = sshbuf_put_stringb(req, session_id);
	if (rc == 0)
		rc = sshbuf_put_stringb(req, signature);
	if (rc == 0)
		rc = sshbuf_put_u8(req, forwarding != 0);
	if (rc == 0)
		rc = ssh_request_reply_decode(sock, req);

	sshbuf_free(req);
	return rc;
}
/*
 * Send the "query" extension request and collect the extension names the
 * agent lists in its SSH_AGENT_EXTENSION_RESPONSE.
 *
 * On success returns 0 and stores in *exts a newly allocated,
 * NULL-terminated array of names (possibly empty), which the caller
 * releases with stringlist_free().
 *
 * On failure *exts is NULL. The return value is SSH_ERR_AGENT_FAILURE if
 * the agent refused, SSH_ERR_INVALID_FORMAT for an unexpected or
 * malformed reply, or another SSH_ERR_* code.
 */
int
ssh_agent_query_extensions(int sock, char ***exts)
{
	struct sshbuf *msg;
	int r;
	u_char type;
	char *cp = NULL, **ret = NULL;
	size_t i = 0;

	*exts = NULL;

	if ((msg = sshbuf_new()) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if ((r = sshbuf_put_u8(msg, SSH_AGENTC_EXTENSION)) != 0 ||
	    (r = sshbuf_put_cstring(msg, "query")) != 0)
		goto out;
	if ((r = ssh_request_reply(sock, msg, msg)) != 0)
		goto out;
	if ((r = sshbuf_get_u8(msg, &type)) != 0)
		goto out;
	if (agent_failed(type)) {
		r = SSH_ERR_AGENT_FAILURE;
		goto out;
	}
	/* The reply must echo the "query" extension name. */
	if (type != SSH_AGENT_EXTENSION_RESPONSE ||
	    (r = sshbuf_get_cstring(msg, &cp, NULL)) != 0 ||
	    strcmp(cp, "query") != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/*
	 * Use the checked xcalloc() (as with xrecallocarray() below): an
	 * unchecked calloc() failure would otherwise report success with
	 * *exts == NULL when the reply lists no extensions.
	 */
	ret = xcalloc(1, sizeof(*ret));
	while (sshbuf_len(msg)) {
		/*
		 * Grow by one slot. The new slot is expected to be zeroed,
		 * keeping the array NULL-terminated for stringlist_free().
		 */
		ret = xrecallocarray(ret, i + 1, i + 2, sizeof(*ret));
		if ((r = sshbuf_get_cstring(msg, ret + i, NULL)) != 0) {
			r = SSH_ERR_INVALID_FORMAT;
			goto out;
		}
		i++;
	}
	r = 0;
	*exts = ret;
	ret = NULL;
 out:
	free(cp);
	stringlist_free(ret);
	sshbuf_free(msg);
	return r;
}