#include <sm/gen.h>
SM_RCSID("@(#)$Id: sfsasl.c,v 8.118 2008/07/22 15:12:48 ca Exp $")
#include <stdlib.h>
#include <sendmail.h>
#include <sm/time.h>
#include <errno.h>
/*
**  DEAL_WITH_ERROR_SSL: when nonzero (the default), tls_read() returns 0
**  without logging for an SSL_ERROR_SSL result with r == 0 and errno == 0,
**  and tls_read()/tls_write() turn any other r == 0 SSL_ERROR_SSL result
**  into -1.
*/
#ifndef DEAL_WITH_ERROR_SSL
# define DEAL_WITH_ERROR_SSL 1
#endif
#if SASL
# include "sfsasl.h"
/* Per-stream state of a SASL file pointer; stored in fp->f_cookie by sasl_open(). */
struct sasl_obj
{
SM_FILE_T *fp; /* underlying file pointer carrying the SASL-encoded data */
sasl_conn_t *conn; /* SASL connection used for sasl_encode()/sasl_decode() */
};
/* Open argument handed to sasl_open() through sm_io_open(); copied into a sasl_obj. */
struct sasl_info
{
SM_FILE_T *fp; /* underlying file pointer */
sasl_conn_t *conn; /* SASL connection */
};
static int sasl_getinfo __P((SM_FILE_T *, int, void *));

/*
**  SASL_GETINFO -- answer information requests for a SASL file pointer.
**
**	Parameters:
**		fp -- the SASL file pointer (cookie is a struct sasl_obj).
**		what -- which information is requested.
**		valp -- passed through to the underlying stream.
**
**	Returns:
**		SM_IO_WHAT_FD: fd of the underlying stream, -1 if detached.
**		SM_IO_IS_READABLE: result of asking the underlying stream,
**			0 if detached.
**		anything else: -1.
*/
static int
sasl_getinfo(fp, what, valp)
	SM_FILE_T *fp;
	int what;
	void *valp;
{
	struct sasl_obj *obj = (struct sasl_obj *) fp->f_cookie;

	if (what == SM_IO_WHAT_FD)
		return (obj->fp != NULL) ? obj->fp->f_file : -1;
	if (what == SM_IO_IS_READABLE)
		return (obj->fp != NULL) ? sm_io_getinfo(obj->fp, what, valp) : 0;
	return -1;
}
static int sasl_open __P((SM_FILE_T *, const void *, int, const void *));
static int
sasl_open(fp, info, flags, rpool)
SM_FILE_T *fp;
const void *info;
int flags;
const void *rpool;
{
struct sasl_obj *so;
struct sasl_info *si = (struct sasl_info *) info;
so = (struct sasl_obj *) sm_malloc(sizeof(struct sasl_obj));
if (so == NULL)
{
errno = ENOMEM;
return -1;
}
so->fp = si->fp;
so->conn = si->conn;
(void) sm_io_setvbuf(so->fp, SM_TIME_DEFAULT, NULL, SM_IO_NOW, 0);
fp->f_cookie = so;
return 0;
}
static int sasl_close __P((SM_FILE_T *));

/*
**  SASL_CLOSE -- close a SASL file pointer and its underlying stream.
**
**	Parameters:
**		fp -- the SASL file pointer.
**
**	Returns:
**		0 (always).
*/
static int
sasl_close(fp)
	SM_FILE_T *fp;
{
	struct sasl_obj *obj = (struct sasl_obj *) fp->f_cookie;

	if (obj != NULL)
	{
		if (obj->fp != NULL)
		{
			(void) sm_io_close(obj->fp, SM_TIME_DEFAULT);
			obj->fp = NULL;
		}
		sm_free(obj);
	}
	return 0;
}
/*
**  Defined elsewhere.  SASL_DEALLOC() is used below (only when
**  SASL < 20000) to release the output buffers obtained from
**  sasl_decode()/sasl_encode().
*/
extern void sm_sasl_free __P((void *));
# define SASL_DEALLOC(b) sm_sasl_free(b)
static ssize_t sasl_read __P((SM_FILE_T *, char *, size_t));

/*
**  SASL_READ -- read and decode SASL-protected data.
**
**	Reads raw bytes from the underlying stream into buf, runs them
**	through sasl_decode(), and copies the decoded bytes back into buf.
**	Decoded data that does not fit is kept in static state
**	(outbuf/outlen/offset) and handed out by subsequent calls.
**
**	NOTE(review): that pending state is static, i.e., shared by all
**	SASL streams of the process; this assumes one active SASL input
**	stream at a time.
**
**	Parameters:
**		fp -- the SASL file pointer (cookie is a struct sasl_obj).
**		buf -- caller's buffer; also used as scratch for the raw read.
**		size -- size of buf.
**
**	Returns:
**		> 0: number of decoded bytes copied into buf.
**		the sm_io_read() result if that was <= 0.
**		-1 if sasl_decode() failed, or if it reported data without
**		a buffer (errno = EIO in that case).
*/
static ssize_t
sasl_read(fp, buf, size)
	SM_FILE_T *fp;
	char *buf;
	size_t size;
{
	int result;
	ssize_t len;
# if SASL >= 20000
	static const char *outbuf = NULL;
# else
	static char *outbuf = NULL;
# endif
	static unsigned int outlen = 0;
	static unsigned int offset = 0;
	struct sasl_obj *so = (struct sasl_obj *) fp->f_cookie;

	/* refill only when no decoded data is pending */
# if SASL >= 20000
	while (outlen == 0)
# else
	while (outbuf == NULL && outlen == 0)
# endif
	{
		len = sm_io_read(so->fp, SM_TIME_DEFAULT, buf, size);
		if (len <= 0)
			return len;
		result = sasl_decode(so->conn, buf,
				     (unsigned int) len, &outbuf, &outlen);
		if (result != SASL_OK)
		{
			if (LogLevel > 7)
				sm_syslog(LOG_WARNING, NOQID,
					"AUTH: sasl_decode error=%d", result);
			outbuf = NULL;
			offset = 0;
			outlen = 0;
			return -1;
		}
	}

	if (outbuf == NULL)
	{
		/* be paranoid: outbuf == NULL but outlen != 0 */
		syserr("@sasl_read failure: outbuf == NULL but outlen != 0");

		/*
		**  syserr() with '@' does not exit, so never fall through
		**  to the memcpy() below from a NULL buffer; drop the
		**  inconsistent state and fail this read instead.
		*/
		offset = 0;
		outlen = 0;
		errno = EIO;
		return -1;
	}

	if (outlen - offset > size)
	{
		/* more decoded data than fits: hand out a chunk, keep the rest */
		(void) memcpy(buf, outbuf + offset, size);
		offset += size;
		len = size;
	}
	else
	{
		/* the remainder fits: copy it and reset the pending state */
		len = outlen - offset;
		(void) memcpy(buf, outbuf + offset, (size_t) len);
# if SASL < 20000
		SASL_DEALLOC(outbuf);
# endif
		outbuf = NULL;
		offset = 0;
		outlen = 0;
	}
	return len;
}
static ssize_t sasl_write __P((SM_FILE_T *, const char *, size_t));

/*
**  SASL_WRITE -- encode data with SASL and write it out.
**
**	Encodes (a prefix of) buf and writes the complete encoded result
**	to the underlying stream.
**
**	Parameters:
**		fp -- the SASL file pointer (cookie is a struct sasl_obj).
**		buf -- data to send.
**		size -- number of bytes in buf.
**
**	Returns:
**		number of bytes of buf consumed; this can be less than size
**		when SASL_MAXOUTBUF is smaller, so callers must write the
**		remainder again.
**		0 if the underlying write made no progress.
**		-1 if sasl_encode() failed.
*/
static ssize_t
sasl_write(fp, buf, size)
	SM_FILE_T *fp;
	const char *buf;
	size_t size;
{
	int result;
# if SASL >= 20000
	const char *outbuf;
# else
	char *outbuf;
# endif
	unsigned int outlen, *maxencode;
	size_t ret = 0, total = 0;
	struct sasl_obj *so = (struct sasl_obj *) fp->f_cookie;

	/* clamp the input to SASL_MAXOUTBUF if available and nonzero */
	result = sasl_getprop(so->conn, SASL_MAXOUTBUF,
			      (const void **) &maxencode);
	if (result == SASL_OK && size > *maxencode && *maxencode > 0)
		size = *maxencode;

	result = sasl_encode(so->conn, buf,
			     (unsigned int) size, &outbuf, &outlen);
	if (result != SASL_OK)
	{
		if (LogLevel > 7)
			sm_syslog(LOG_WARNING, NOQID,
				"AUTH: sasl_encode error=%d", result);
		return -1;
	}

	if (outbuf != NULL)
	{
		while (outlen > 0)
		{
			errno = 0;
			ret = sm_io_write(so->fp, SM_TIME_DEFAULT,
					  &outbuf[total], outlen);

			/* ret is unsigned, so "no progress" means ret == 0 */
			if (ret == 0)
			{
# if SASL < 20000
				/* don't leak the encoder's buffer on failure */
				SASL_DEALLOC(outbuf);
# endif
				return ret;
			}
			outlen -= ret;
			total += ret;
		}
# if SASL < 20000
		SASL_DEALLOC(outbuf);
# endif
	}
	return size;
}
int
sfdcsasl(fin, fout, conn, tmo)
SM_FILE_T **fin;
SM_FILE_T **fout;
sasl_conn_t *conn;
int tmo;
{
SM_FILE_T *newin, *newout;
SM_FILE_T SM_IO_SET_TYPE(sasl_vector, "sasl", sasl_open, sasl_close,
sasl_read, sasl_write, NULL, sasl_getinfo, NULL,
SM_TIME_DEFAULT);
struct sasl_info info;
if (conn == NULL)
{
return 0;
}
SM_IO_INIT_TYPE(sasl_vector, "sasl", sasl_open, sasl_close,
sasl_read, sasl_write, NULL, sasl_getinfo, NULL,
SM_TIME_DEFAULT);
info.fp = *fin;
info.conn = conn;
newin = sm_io_open(&sasl_vector, SM_TIME_DEFAULT, &info,
SM_IO_RDONLY_B, NULL);
if (newin == NULL)
return -1;
info.fp = *fout;
info.conn = conn;
newout = sm_io_open(&sasl_vector, SM_TIME_DEFAULT, &info,
SM_IO_WRONLY_B, NULL);
if (newout == NULL)
{
(void) sm_io_close(newin, SM_TIME_DEFAULT);
return -1;
}
sm_io_automode(newin, newout);
sm_io_setinfo(*fin, SM_IO_WHAT_TIMEOUT, &tmo);
sm_io_setinfo(*fout, SM_IO_WHAT_TIMEOUT, &tmo);
*fin = newin;
*fout = newout;
return 0;
}
#endif
#if STARTTLS
# include "sfsasl.h"
# include <openssl/err.h>
/* Per-stream state of a TLS file pointer; stored in fp->f_cookie by tls_open(). */
struct tls_obj
{
SM_FILE_T *fp; /* underlying file pointer */
SSL *con; /* OpenSSL connection used by SSL_read()/SSL_write() */
};
/* Open argument handed to tls_open() through sm_io_open(); copied into a tls_obj. */
struct tls_info
{
SM_FILE_T *fp; /* underlying file pointer */
SSL *con; /* OpenSSL connection */
};
static int tls_getinfo __P((SM_FILE_T *, int, void *));

/*
**  TLS_GETINFO -- answer information requests for a TLS file pointer.
**
**	Parameters:
**		fp -- the TLS file pointer (cookie is a struct tls_obj).
**		what -- which information is requested.
**		valp -- unused.
**
**	Returns:
**		SM_IO_WHAT_FD: fd of the underlying stream, -1 if detached.
**		SM_IO_IS_READABLE: nonzero iff SSL_pending() reports
**			buffered data.
**		anything else: -1.
*/
static int
tls_getinfo(fp, what, valp)
	SM_FILE_T *fp;
	int what;
	void *valp;
{
	struct tls_obj *obj = (struct tls_obj *) fp->f_cookie;

	if (what == SM_IO_WHAT_FD)
		return (obj->fp != NULL) ? obj->fp->f_file : -1;
	if (what == SM_IO_IS_READABLE)
		return SSL_pending(obj->con) > 0;
	return -1;
}
static int tls_open __P((SM_FILE_T *, const void *, int, const void *));
static int
tls_open(fp, info, flags, rpool)
SM_FILE_T *fp;
const void *info;
int flags;
const void *rpool;
{
struct tls_obj *so;
struct tls_info *ti = (struct tls_info *) info;
so = (struct tls_obj *) sm_malloc(sizeof(struct tls_obj));
if (so == NULL)
{
errno = ENOMEM;
return -1;
}
so->fp = ti->fp;
so->con = ti->con;
fp->f_file = sm_io_getinfo(so->fp, SM_IO_WHAT_FD, NULL);
(void) sm_io_setvbuf(so->fp, SM_TIME_DEFAULT, NULL, SM_IO_NOW, 0);
fp->f_cookie = so;
return 0;
}
static int tls_close __P((SM_FILE_T *));

/*
**  TLS_CLOSE -- close a TLS file pointer and its underlying stream.
**
**	Parameters:
**		fp -- the TLS file pointer.
**
**	Returns:
**		0 (always).
*/
static int
tls_close(fp)
	SM_FILE_T *fp;
{
	struct tls_obj *obj = (struct tls_obj *) fp->f_cookie;

	if (obj != NULL)
	{
		if (obj->fp != NULL)
		{
			(void) sm_io_close(obj->fp, SM_TIME_DEFAULT);
			obj->fp = NULL;
		}
		sm_free(obj);
	}
	return 0;
}
/* NOTE(review): MAX_TLS_IOS is not referenced in the visible part of this file; confirm it is still used. */
# define MAX_TLS_IOS 4
/*
**  TLS_RETRY -- wait until a blocked SSL read/write may be retried.
**
**	Parameters:
**		ssl -- SSL connection (not used here).
**		rfd -- read file descriptor of the connection.
**		wfd -- write file descriptor of the connection.
**		tlsstart -- time at which the SSL operation started.
**		timeout -- overall time limit in seconds for the operation.
**		err -- SSL_get_error() result (SSL_ERROR_WANT_READ/WRITE).
**		where -- description used for logging.
**
**	Returns:
**		> 0: the descriptor is ready; retry the SSL operation.
**		0: timeout (time already used up, or select() timed out).
**		< 0: -errno if select() failed; -1 if the descriptor is too
**		large for an fd_set (errno = EINVAL) or err is neither
**		WANT_READ nor WANT_WRITE.
*/
int
tls_retry(ssl, rfd, wfd, tlsstart, timeout, err, where)
	SSL *ssl;
	int rfd;
	int wfd;
	time_t tlsstart;
	int timeout;
	int err;
	const char *where;
{
	int ret;
	time_t left;
	time_t now = curtime();
	struct timeval tv;

	ret = -1;

	/* only wait for whatever is left of the overall timeout */
	left = timeout - (now - tlsstart);
	if (left <= 0)
		return 0;
	tv.tv_sec = left;
	tv.tv_usec = 0;

	if (LogLevel > 14)
	{
		sm_syslog(LOG_INFO, NOQID,
			  "STARTTLS=%s, info: fds=%d/%d, err=%d",
			  where, rfd, wfd, err);
	}

	if (FD_SETSIZE > 0 &&
	    ((err == SSL_ERROR_WANT_READ && rfd >= FD_SETSIZE) ||
	     (err == SSL_ERROR_WANT_WRITE && wfd >= FD_SETSIZE)))
	{
		if (LogLevel > 5)
		{
			sm_syslog(LOG_ERR, NOQID,
				  "STARTTLS=%s, error: fd %d/%d too large",
				  where, rfd, wfd);
			if (LogLevel > 8)
				tlslogerr(where);
		}
		errno = EINVAL;
	}
	else if (err == SSL_ERROR_WANT_READ)
	{
		fd_set ssl_maskr, ssl_maskx;

		FD_ZERO(&ssl_maskr);
		FD_SET(rfd, &ssl_maskr);
		FD_ZERO(&ssl_maskx);
		FD_SET(rfd, &ssl_maskx);
		do
		{
			ret = select(rfd + 1, &ssl_maskr, NULL, &ssl_maskx,
					&tv);
		} while (ret < 0 && errno == EINTR);
		if (ret < 0 && errno > 0)
			ret = -errno;
	}
	else if (err == SSL_ERROR_WANT_WRITE)
	{
		fd_set ssl_maskw, ssl_maskx;

		/*
		**  Watch wfd for exceptions as well: wfd is the descriptor
		**  that was checked against FD_SETSIZE above (rfd was not,
		**  so FD_SET(rfd) could overflow the fd_set), and select()
		**  only examines descriptors below wfd + 1 anyway.
		*/
		FD_ZERO(&ssl_maskw);
		FD_SET(wfd, &ssl_maskw);
		FD_ZERO(&ssl_maskx);
		FD_SET(wfd, &ssl_maskx);
		do
		{
			ret = select(wfd + 1, NULL, &ssl_maskw, &ssl_maskx,
					&tv);
		} while (ret < 0 && errno == EINTR);
		if (ret < 0 && errno > 0)
			ret = -errno;
	}
	return ret;
}
/*
**  errno value tls_read()/tls_write() set when tls_retry() does not
**  allow another attempt: ETIMEDOUT where available, EIO otherwise.
*/
#ifdef ETIMEDOUT
# define SM_ERR_TIMEOUT ETIMEDOUT
#else
# define SM_ERR_TIMEOUT EIO
#endif
/*
**  Read timeout used by tls_read(); a negative value means
**  "use TimeOuts.to_datablock instead".
*/
static int tls_rd_tmo = -1;

/*
**  SET_TLS_RD_TMO -- set the read timeout for TLS connections.
**
**	Parameters:
**		rd_tmo -- new timeout; negative selects the default.
**
**	Returns:
**		none.
*/
void
set_tls_rd_tmo(rd_tmo)
	int rd_tmo;
{
	tls_rd_tmo = rd_tmo;
}
static ssize_t tls_read __P((SM_FILE_T *, char *, size_t));

/*
**  TLS_READ -- read data from a TLS connection.
**
**	Retries SSL_read() as long as OpenSSL reports WANT_READ/WANT_WRITE
**	and tls_retry() reports the descriptor ready within the read
**	timeout (tls_rd_tmo, or TimeOuts.to_datablock if that is negative).
**	Otherwise it gives up with errno = SM_ERR_TIMEOUT.
**
**	Parameters:
**		fp -- the TLS file pointer (cookie is a struct tls_obj).
**		buf -- buffer to fill.
**		size -- size of buf.
**
**	Returns:
**		> 0: number of bytes read.
**		<= 0: end of input or error (see errno); errors are logged
**		depending on LogLevel.
*/
static ssize_t
tls_read(fp, buf, size)
	SM_FILE_T *fp;
	char *buf;
	size_t size;
{
	int r, rfd, wfd, try, ssl_err;
	struct tls_obj *so = (struct tls_obj *) fp->f_cookie;
	time_t tlsstart;
	const char *err;

	try = 99;
	err = NULL;
	tlsstart = curtime();

  retry:
	r = SSL_read(so->con, (char *) buf, size);
	if (r > 0)
		return r;

	err = NULL;
	switch (ssl_err = SSL_get_error(so->con, r))
	{
	  case SSL_ERROR_NONE:
	  case SSL_ERROR_ZERO_RETURN:
		break;

	  case SSL_ERROR_WANT_WRITE:
		err = "read W BLOCK";
		/* FALLTHROUGH */

	  case SSL_ERROR_WANT_READ:
		if (err == NULL)
			err = "read R BLOCK";
		rfd = SSL_get_rfd(so->con);
		wfd = SSL_get_wfd(so->con);
		try = tls_retry(so->con, rfd, wfd, tlsstart,
				(tls_rd_tmo < 0) ? TimeOuts.to_datablock
						 : tls_rd_tmo,
				ssl_err, "read");
		if (try > 0)
			goto retry;
		errno = SM_ERR_TIMEOUT;
		break;

	  case SSL_ERROR_WANT_X509_LOOKUP:
		err = "read X BLOCK";
		break;

	  case SSL_ERROR_SYSCALL:
		if (r == 0 && errno == 0) /* out of protocol EOF found */
			break;
		err = "syscall error";
		break;

	  case SSL_ERROR_SSL:
#if DEAL_WITH_ERROR_SSL
		if (r == 0 && errno == 0) /* out of protocol EOF found */
			break;
#endif
		err = "generic SSL error";
		if (LogLevel > 9)
			tlslogerr("read");
#if DEAL_WITH_ERROR_SSL
		/* avoid repeated calls? */
		if (r == 0)
			r = -1;
#endif
		break;

	  default:
		break;
	}

	if (err != NULL)
	{
		int save_errno;

		save_errno = (errno == 0) ? EIO : errno;
		if (try == 0 && save_errno == SM_ERR_TIMEOUT)
		{
			if (LogLevel > 7)
				sm_syslog(LOG_WARNING, NOQID,
					  "STARTTLS: read error=timeout");
		}
		else if (LogLevel > 8)
			sm_syslog(LOG_WARNING, NOQID,
				  "STARTTLS: read error=%s (%d), errno=%d, get_error=%s, retry=%d, ssl_err=%d",
				  err, r, errno,
				  ERR_error_string(ERR_get_error(), NULL), try,
				  ssl_err);
		else if (LogLevel > 7)
			/* one conversion per argument (errno was missing) */
			sm_syslog(LOG_WARNING, NOQID,
				  "STARTTLS: read error=%s (%d), errno=%d, retry=%d, ssl_err=%d",
				  err, r, errno, try, ssl_err);
		errno = save_errno;
	}
	return r;
}
static ssize_t tls_write __P((SM_FILE_T *, const char *, size_t));

/*
**  TLS_WRITE -- write data to a TLS connection.
**
**	Retries SSL_write() as long as OpenSSL reports WANT_READ/WANT_WRITE
**	and tls_retry() reports the descriptor ready within
**	DATA_PROGRESS_TIMEOUT.  Otherwise it gives up with
**	errno = SM_ERR_TIMEOUT.
**
**	Parameters:
**		fp -- the TLS file pointer (cookie is a struct tls_obj).
**		buf -- data to write.
**		size -- number of bytes in buf.
**
**	Returns:
**		> 0: number of bytes written.
**		<= 0: error (see errno); errors are logged depending on
**		LogLevel.
*/
static ssize_t
tls_write(fp, buf, size)
	SM_FILE_T *fp;
	const char *buf;
	size_t size;
{
	int r, rfd, wfd, try, ssl_err;
	struct tls_obj *so = (struct tls_obj *) fp->f_cookie;
	time_t tlsstart;
	const char *err;

	try = 99;
	err = NULL;
	tlsstart = curtime();

  retry:
	r = SSL_write(so->con, (char *) buf, size);
	if (r > 0)
		return r;

	err = NULL;
	switch (ssl_err = SSL_get_error(so->con, r))
	{
	  case SSL_ERROR_NONE:
	  case SSL_ERROR_ZERO_RETURN:
		break;

	  case SSL_ERROR_WANT_WRITE:
		/* these are write-path messages (previously mislabeled "read") */
		err = "write W BLOCK";
		/* FALLTHROUGH */

	  case SSL_ERROR_WANT_READ:
		if (err == NULL)
			err = "write R BLOCK";
		rfd = SSL_get_rfd(so->con);
		wfd = SSL_get_wfd(so->con);
		try = tls_retry(so->con, rfd, wfd, tlsstart,
				DATA_PROGRESS_TIMEOUT, ssl_err, "write");
		if (try > 0)
			goto retry;
		errno = SM_ERR_TIMEOUT;
		break;

	  case SSL_ERROR_WANT_X509_LOOKUP:
		err = "write X BLOCK";
		break;

	  case SSL_ERROR_SYSCALL:
		if (r == 0 && errno == 0) /* out of protocol EOF found */
			break;
		err = "syscall error";
		break;

	  case SSL_ERROR_SSL:
		err = "generic SSL error";
		if (LogLevel > 9)
			tlslogerr("write");
#if DEAL_WITH_ERROR_SSL
		/* avoid repeated calls? */
		if (r == 0)
			r = -1;
#endif
		break;

	  default:
		break;
	}

	if (err != NULL)
	{
		int save_errno;

		save_errno = (errno == 0) ? EIO : errno;
		if (try == 0 && save_errno == SM_ERR_TIMEOUT)
		{
			if (LogLevel > 7)
				sm_syslog(LOG_WARNING, NOQID,
					  "STARTTLS: write error=timeout");
		}
		else if (LogLevel > 8)
			sm_syslog(LOG_WARNING, NOQID,
				  "STARTTLS: write error=%s (%d), errno=%d, get_error=%s, retry=%d, ssl_err=%d",
				  err, r, errno,
				  ERR_error_string(ERR_get_error(), NULL), try,
				  ssl_err);
		else if (LogLevel > 7)
			sm_syslog(LOG_WARNING, NOQID,
				  "STARTTLS: write error=%s (%d), errno=%d, retry=%d, ssl_err=%d",
				  err, r, errno, try, ssl_err);
		errno = save_errno;
	}
	return r;
}
int
sfdctls(fin, fout, con)
SM_FILE_T **fin;
SM_FILE_T **fout;
SSL *con;
{
SM_FILE_T *tlsin, *tlsout;
SM_FILE_T SM_IO_SET_TYPE(tls_vector, "tls", tls_open, tls_close,
tls_read, tls_write, NULL, tls_getinfo, NULL,
SM_TIME_FOREVER);
struct tls_info info;
SM_ASSERT(con != NULL);
SM_IO_INIT_TYPE(tls_vector, "tls", tls_open, tls_close,
tls_read, tls_write, NULL, tls_getinfo, NULL,
SM_TIME_FOREVER);
info.fp = *fin;
info.con = con;
tlsin = sm_io_open(&tls_vector, SM_TIME_DEFAULT, &info, SM_IO_RDONLY_B,
NULL);
if (tlsin == NULL)
return -1;
info.fp = *fout;
tlsout = sm_io_open(&tls_vector, SM_TIME_DEFAULT, &info, SM_IO_WRONLY_B,
NULL);
if (tlsout == NULL)
{
(void) sm_io_close(tlsin, SM_TIME_DEFAULT);
return -1;
}
sm_io_automode(tlsin, tlsout);
*fin = tlsin;
*fout = tlsout;
return 0;
}
#endif