root/sbin/ipf/ipfsync/ipfsyncd.c
/*
 * Copyright (C) 2012 by Darren Reed.
 *
 * See the IPFILTER.LICENCE file for details on licencing.
 */
#include <sys/types.h>
#include <sys/time.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/sockio.h>
#include <sys/errno.h>

#include <netinet/in.h>
#include <net/if.h>

#include <arpa/inet.h>

#include <stdio.h>
#include <stdlib.h>
#include <fcntl.h>
#include <unistd.h>
#include <string.h>
#include <syslog.h>
#include <signal.h>

#include "ipf.h"
#include "opts.h"


#define R_IO_ERROR      -1
#define R_OKAY          0
#define R_MORE          1
#define R_SKIP          2
#if     defined(sun) && !defined(SOLARIS2)
# define        STRERROR(x)     sys_errlist[x]
extern  char    *sys_errlist[];
#else
# define        STRERROR(x)     strerror(x)
#endif


int     main(int, char *[]);
void    usage(char *);
void    printsynchdr(synchdr_t *);
void    printtable(int);
void    printsmcproto(char *);
void    printcommand(int);
int     do_kbuff(int, char *, int *);
int     do_packet(int, char *);
int     buildsocket(char *, struct sockaddr_in *);
void    do_io(void);
void    handleterm(int);

int     terminate = 0;
int     igmpfd = -1;
int     nfd = -1;
int     lfd = -1;
int     opts = 0;

void
usage(progname)
        char *progname;
{
        fprintf(stderr,
                "Usage: %s [-d] [-p port] [-i address] -I <interface>\n",
                progname);
}

void
handleterm(sig)
        int sig;
{
        terminate = sig;
}


/* should be large enough to hold header + any datatype */
#define BUFFERLEN 1400

int
main(argc, argv)
        int argc;
        char *argv[];
{
        struct sockaddr_in sin;
        char *interface;
        char *progname;
        int opt, tries;

        progname = strrchr(argv[0], '/');
        if (progname) {
                progname++;
        } else {
                progname = argv[0];
        }

        opts = 0;
        tries = 0;
        interface = NULL;

        bzero((char *)&sin, sizeof(sin));
        sin.sin_family = AF_INET;
        sin.sin_port = htons(0xaf6c);
        sin.sin_addr.s_addr = htonl(INADDR_UNSPEC_GROUP | 0x697066);

        while ((opt = getopt(argc, argv, "di:I:p:")) != -1)
                switch (opt)
                {
                case 'd' :
                        debuglevel++;
                        break;
                case 'I' :
                        interface = optarg;
                        break;
                case 'i' :
                        sin.sin_addr.s_addr = inet_addr(optarg);
                        break;
                case 'p' :
                        sin.sin_port = htons(atoi(optarg));
                        break;
                }

        if (interface == NULL) {
                usage(progname);
                exit(1);
        }

        if (!debuglevel) {

#ifdef BSD
                daemon(0, 0);
#else
                int fd = open("/dev/null", O_RDWR);

                switch (fork())
                {
                case 0 :
                        break;

                case -1 :
                        fprintf(stderr, "%s: fork() failed: %s\n",
                                argv[0], STRERROR(errno));
                        exit(1);
                        /* NOTREACHED */

                default :
                        exit(0);
                        /* NOTREACHED */
                }

                dup2(fd, 0);
                dup2(fd, 1);
                dup2(fd, 2);
                close(fd);

                setsid();
#endif
        }

        signal(SIGHUP, handleterm);
        signal(SIGINT, handleterm);
        signal(SIGTERM, handleterm);

        openlog(progname, LOG_PID, LOG_SECURITY);

        while (!terminate) {
                if (lfd != -1) {
                        close(lfd);
                        lfd = -1;
                }
                if (nfd != -1) {
                        close(nfd);
                        nfd = -1;
                }
                if (igmpfd != -1) {
                        close(igmpfd);
                        igmpfd = -1;
                }

                if (buildsocket(interface, &sin) == -1)
                        goto tryagain;

                lfd = open(IPSYNC_NAME, O_RDWR);
                if (lfd == -1) {
                        syslog(LOG_ERR, "open(%s):%m", IPSYNC_NAME);
                        debug(1, "open(%s): %s\n", IPSYNC_NAME,
                              STRERROR(errno));
                        goto tryagain;
                }

                tries = -1;
                do_io();
tryagain:
                tries++;
                syslog(LOG_INFO, "retry in %d seconds", 1 << tries);
                debug(1, "wait %d seconds\n", 1 << tries);
                sleep(1 << tries);
        }


        /* terminate */
        if (lfd != -1)
                close(lfd);
        if (nfd != -1)
                close(nfd);

        syslog(LOG_ERR, "signal %d received, exiting...", terminate);
        debug(1, "signal %d received, exiting...", terminate);

        exit(1);
}


void
do_io()
{
        char nbuff[BUFFERLEN];
        char buff[BUFFERLEN];
        fd_set mrd, rd;
        int maxfd;
        int inbuf;
        int n1;
        int left;

        FD_ZERO(&mrd);
        FD_SET(lfd, &mrd);
        FD_SET(nfd, &mrd);
        maxfd = nfd;
        if (lfd > maxfd)
                maxfd = lfd;
        debug(2, "nfd %d lfd %d maxfd %d\n", nfd, lfd, maxfd);

        inbuf = 0;
        /*
         * A threaded approach to this loop would have one thread
         * work on reading lfd (only) all the time and another thread
         * working on reading nfd all the time.
         */
        while (!terminate) {
                int n;

                rd = mrd;

                n = select(maxfd + 1, &rd, NULL, NULL, NULL);
                if (n < 0) {
                        switch (errno)
                        {
                        case EINTR :
                                continue;
                        default :
                                syslog(LOG_ERR, "select error: %m");
                                debug(1, "select error: %s\n", STRERROR(errno));
                                return;
                        }
                }

                if (FD_ISSET(lfd, &rd)) {
                        n1 = read(lfd, buff+inbuf, BUFFERLEN-inbuf);

                        debug(3, "read(K):%d\n", n1);

                        if (n1 <= 0) {
                                syslog(LOG_ERR, "read error (k-header): %m");
                                debug(1, "read error (k-header): %s\n",
                                      STRERROR(errno));
                                return;
                        }

                        left = 0;

                        switch (do_kbuff(n1, buff, &left))
                        {
                        case R_IO_ERROR :
                                return;
                        case R_MORE :
                                inbuf += left;
                                break;
                        default :
                                inbuf = 0;
                                break;
                        }
                }

                if (FD_ISSET(nfd, &rd)) {
                        n1 = recv(nfd, nbuff, sizeof(nbuff), 0);

                        debug(3, "read(N):%d\n", n1);

                        if (n1 <= 0) {
                                syslog(LOG_ERR, "read error (n-header): %m");
                                debug(1, "read error (n-header): %s\n",
                                      STRERROR(errno));
                                return;
                        }

                        switch (do_packet(n1, nbuff))
                        {
                        case R_IO_ERROR :
                                return;
                        default :
                                break;
                        }
                }
        }
}


int
buildsocket(nicname, sinp)
        char *nicname;
        struct sockaddr_in *sinp;
{
        struct sockaddr_in *reqip;
        struct ifreq req;
        char opt;

        debug(2, "binding to %s:%s\n", nicname, inet_ntoa(sinp->sin_addr));

        if (IN_MULTICAST(ntohl(sinp->sin_addr.s_addr))) {
                struct in_addr addr;
                struct ip_mreq mreq;

                igmpfd = socket(AF_INET, SOCK_RAW, IPPROTO_IGMP);
                if (igmpfd == -1) {
                        syslog(LOG_ERR, "socket:%m");
                        debug(1, "socket:%s\n", STRERROR(errno));
                        return -1;
                }

                bzero((char *)&req, sizeof(req));
                strncpy(req.ifr_name, nicname, sizeof(req.ifr_name));
                req.ifr_name[sizeof(req.ifr_name) - 1] = '\0';
                if (ioctl(igmpfd, SIOCGIFADDR, &req) == -1) {
                        syslog(LOG_ERR, "ioctl(SIOCGIFADDR):%m");
                        debug(1, "ioctl(SIOCGIFADDR):%s\n", STRERROR(errno));
                        close(igmpfd);
                        igmpfd = -1;
                        return -1;
                }
                reqip = (struct sockaddr_in *)&req.ifr_addr;

                addr = reqip->sin_addr;
                if (setsockopt(igmpfd, IPPROTO_IP, IP_MULTICAST_IF,
                               (char *)&addr, sizeof(addr)) == -1) {
                        syslog(LOG_ERR, "setsockopt(IP_MULTICAST_IF(%s)):%m",
                               inet_ntoa(addr));
                        debug(1, "setsockopt(IP_MULTICAST_IF(%s)):%s\n",
                              inet_ntoa(addr), STRERROR(errno));
                        close(igmpfd);
                        igmpfd = -1;
                        return -1;
                }

                opt = 0;
                if (setsockopt(igmpfd, IPPROTO_IP, IP_MULTICAST_LOOP,
                               (char *)&opt, sizeof(opt)) == -1) {
                        syslog(LOG_ERR, "setsockopt(IP_MULTICAST_LOOP=0):%m");
                        debug(1, "setsockopt(IP_MULTICAST_LOOP=0):%s\n",
                              STRERROR(errno));
                        close(igmpfd);
                        igmpfd = -1;
                        return -1;
                }

                opt = 63;
                if (setsockopt(igmpfd, IPPROTO_IP, IP_MULTICAST_TTL,
                               (char *)&opt, sizeof(opt)) == -1) {
                        syslog(LOG_ERR, "setsockopt(IP_MULTICAST_TTL=%d):%m",
                               opt);
                        debug(1, "setsockopt(IP_MULTICAST_TTL=%d):%s\n", opt,
                              STRERROR(errno));
                        close(igmpfd);
                        igmpfd = -1;
                        return -1;
                }

                mreq.imr_multiaddr.s_addr = sinp->sin_addr.s_addr;
                mreq.imr_interface.s_addr = reqip->sin_addr.s_addr;

                if (setsockopt(igmpfd, IPPROTO_IP, IP_ADD_MEMBERSHIP,
                               (char *)&mreq, sizeof(mreq)) == -1) {
                        char buffer[80];

                        snprintf(buffer, sizeof(buffer), "%s,", inet_ntoa(sinp->sin_addr));
                        strcat(buffer, inet_ntoa(reqip->sin_addr));

                        syslog(LOG_ERR,
                               "setsockpt(IP_ADD_MEMBERSHIP,%s):%m", buffer);
                        debug(1, "setsockpt(IP_ADD_MEMBERSHIP,%s):%s\n",
                              buffer, STRERROR(errno));
                        close(igmpfd);
                        igmpfd = -1;
                        return -1;
                }
        }
        nfd = socket(AF_INET, SOCK_DGRAM, 0);
        if (nfd == -1) {
                syslog(LOG_ERR, "socket:%m");
                if (igmpfd != -1) {
                        close(igmpfd);
                        igmpfd = -1;
                }
                return -1;
        }
        bzero((char *)&req, sizeof(req));
        strncpy(req.ifr_name, nicname, sizeof(req.ifr_name));
        req.ifr_name[sizeof(req.ifr_name) - 1] = '\0';
        if (ioctl(nfd, SIOCGIFADDR, &req) == -1) {
                syslog(LOG_ERR, "ioctl(SIOCGIFADDR):%m");
                debug(1, "ioctl(SIOCGIFADDR):%s\n", STRERROR(errno));
                close(igmpfd);
                igmpfd = -1;
                return -1;
        }

        if (bind(nfd, (struct sockaddr *)&req.ifr_addr,
                 sizeof(req.ifr_addr)) == -1) {
                syslog(LOG_ERR, "bind:%m");
                debug(1, "bind:%s\n", STRERROR(errno));
                close(nfd);
                if (igmpfd != -1) {
                        close(igmpfd);
                        igmpfd = -1;
                }
                nfd = -1;
                return -1;
        }

        if (connect(nfd, (struct sockaddr *)sinp, sizeof(*sinp)) == -1) {
                syslog(LOG_ERR, "connect:%m");
                debug(1, "connect:%s\n", STRERROR(errno));
                close(nfd);
                if (igmpfd != -1) {
                        close(igmpfd);
                        igmpfd = -1;
                }
                nfd = -1;
                return -1;
        }
        syslog(LOG_INFO, "Sending data to %s", inet_ntoa(sinp->sin_addr));
        debug(3, "Sending data to %s\n", inet_ntoa(sinp->sin_addr));

        return nfd;
}


int
do_packet(pklen, buff)
        int pklen;
        char *buff;
{
        synchdr_t *sh;
        u_32_t magic;
        int len;
        int n2;
        int n3;

        while (pklen > 0) {
                if (pklen < sizeof(*sh)) {
                        syslog(LOG_ERR, "packet length too short:%d", pklen);
                        debug(2, "packet length too short:%d\n", pklen);
                        return R_SKIP;
                }

                sh = (synchdr_t *)buff;
                len = ntohl(sh->sm_len);
                magic = ntohl(sh->sm_magic);

                if (magic != SYNHDRMAGIC) {
                        syslog(LOG_ERR, "invalid header magic %x", magic);
                        debug(2, "invalid header magic %x\n", magic);
                        return R_SKIP;
                }

                if (pklen < len + sizeof(*sh)) {
                        syslog(LOG_ERR, "packet length too short:%d", pklen);
                        debug(2, "packet length too short:%d\n", pklen);
                        return R_SKIP;
                }

                if (debuglevel > 3) {
                        printsynchdr(sh);
                        printcommand(sh->sm_cmd);
                        printtable(sh->sm_table);
                        printsmcproto(buff);
                }

                n2 = sizeof(*sh) + len;

                do {
                        n3 = write(lfd, buff, n2);
                        if (n3 <= 0) {
                                syslog(LOG_ERR, "write error: %m");
                                debug(1, "write error: %s\n", STRERROR(errno));
                                return R_IO_ERROR;
                        }

                        n2 -= n3;
                        buff += n3;
                        pklen -= n3;
                } while (n3 != 0);
        }

        return R_OKAY;
}



int
do_kbuff(inbuf, buf, left)
        int inbuf, *left;
        char *buf;
{
        synchdr_t *sh;
        u_32_t magic;
        int complete;
        int sendlen;
        int error;
        int bytes;
        int len;
        int n2;
        int n3;

        sendlen = 0;
        bytes = inbuf;
        error = R_OKAY;
        sh = (synchdr_t *)buf;

        for (complete = 0; bytes > 0; complete++) {
                len = ntohl(sh->sm_len);
                magic = ntohl(sh->sm_magic);

                if (magic != SYNHDRMAGIC) {
                        syslog(LOG_ERR,
                               "read invalid header magic 0x%x, flushing",
                               magic);
                        debug(2, "read invalid header magic 0x%x, flushing\n",
                               magic);
                        n2 = SMC_RLOG;
                        (void) ioctl(lfd, SIOCIPFFL, &n2);
                        break;
                }

                if (debuglevel > 3) {
                        printsynchdr(sh);
                        printcommand(sh->sm_cmd);
                        printtable(sh->sm_table);
                        putchar('\n');
                }

                if (bytes < sizeof(*sh) + len) {
                        debug(3, "Not enough bytes %d < %d\n", bytes,
                              sizeof(*sh) + len);
                        error = R_MORE;
                        break;
                }

                if (debuglevel > 3) {
                        printsmcproto(buf);
                }

                sendlen += len + sizeof(*sh);
                sh = (synchdr_t *)(buf + sendlen);
                bytes -= sendlen;
        }

        if (complete) {
                n3 = send(nfd, buf, sendlen, 0);
                if (n3 <= 0) {
                        syslog(LOG_ERR, "write error: %m");
                        debug(1, "write error: %s\n", STRERROR(errno));
                        return R_IO_ERROR;
                }
                debug(3, "send on %d len %d = %d\n", nfd, sendlen, n3);
                error = R_OKAY;
        }

        /* move buffer to the front,we might need to make
         * this more efficient, by using a rolling pointer
         * over the buffer and only copying it, when
         * we are reaching the end
         */
        if (bytes > 0) {
                bcopy(buf + bytes, buf, bytes);
                error = R_MORE;
        }
        debug(4, "complete %d bytes %d error %d\n", complete, bytes, error);

        *left = bytes;

        return error;
}


void
printcommand(cmd)
        int cmd;
{

        switch (cmd)
        {
        case SMC_CREATE :
                printf(" cmd:CREATE");
                break;
        case SMC_UPDATE :
                printf(" cmd:UPDATE");
                break;
        default :
                printf(" cmd:Unknown(%d)", cmd);
                break;
        }
}


void
printtable(table)
        int table;
{
        switch (table)
        {
        case SMC_NAT :
                printf(" table:NAT");
                break;
        case SMC_STATE :
                printf(" table:STATE");
                break;
        default :
                printf(" table:Unknown(%d)", table);
                break;
        }
}


void
printsmcproto(buff)
        char *buff;
{
        syncupdent_t *su;
        synchdr_t *sh;

        sh = (synchdr_t *)buff;

        if (sh->sm_cmd == SMC_CREATE) {
                ;

        } else if (sh->sm_cmd == SMC_UPDATE) {
                su = (syncupdent_t *)buff;
                if (sh->sm_p == IPPROTO_TCP) {
                        printf(" TCP Update: age %lu state %d/%d\n",
                                su->sup_tcp.stu_age,
                                su->sup_tcp.stu_state[0],
                                su->sup_tcp.stu_state[1]);
                }
        } else {
                printf("Unknown command\n");
        }
}


void
printsynchdr(sh)
        synchdr_t *sh;
{

        printf("v:%d p:%d num:%d len:%d magic:%x", sh->sm_v, sh->sm_p,
               ntohl(sh->sm_num), ntohl(sh->sm_len), ntohl(sh->sm_magic));
}