root/net/psp/psp_nl.c
// SPDX-License-Identifier: GPL-2.0-only

#include <linux/ethtool.h>
#include <linux/skbuff.h>
#include <linux/xarray.h>
#include <net/genetlink.h>
#include <net/psp.h>
#include <net/sock.h>

#include "psp-nl-gen.h"
#include "psp.h"

/* Netlink helpers */

/*
 * Allocate a reply skb and write the genetlink header described by @info
 * into it. The message is left open; psp_nl_reply_send() finalizes it.
 *
 * Returns the skb, or NULL if either the allocation or the header
 * placement failed (the skb is freed in the latter case).
 */
static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
{
        struct sk_buff *msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);

        if (msg && !genlmsg_iput(msg, info)) {
                nlmsg_free(msg);
                msg = NULL;
        }

        return msg;
}

/*
 * Close the message in @rsp and unicast it back to the requester in @info.
 *
 * The netlink header is taken to be at the very start of the skb data, so
 * this is only valid for skbs that carry exactly one message (as built by
 * psp_nl_reply_new()).
 */
static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
{
        struct nlmsghdr *nlh = (struct nlmsghdr *)rsp->data;

        nlmsg_end(rsp, nlh);
        return genlmsg_reply(rsp, info);
}

/* Device stuff */

static struct psp_dev *
psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
{
        struct psp_dev *psd;
        int err;

        mutex_lock(&psp_devs_lock);
        psd = xa_load(&psp_devs, nla_get_u32(dev_id));
        if (!psd) {
                mutex_unlock(&psp_devs_lock);
                return ERR_PTR(-ENODEV);
        }

        mutex_lock(&psd->lock);
        mutex_unlock(&psp_devs_lock);

        err = psp_dev_check_access(psd, net);
        if (err) {
                mutex_unlock(&psd->lock);
                return ERR_PTR(err);
        }

        return psd;
}

/*
 * Pre-doit style hook (signature matches genl_split_ops->pre_doit).
 * Resolves PSP_A_DEV_ID to a device, locks it and stores it in
 * info->user_ptr[0].
 *
 * On failure user_ptr[0] holds the ERR_PTR and the same errno is returned.
 * A missing PSP_A_DEV_ID yields -EINVAL.
 */
int psp_device_get_locked(const struct genl_split_ops *ops,
                          struct sk_buff *skb, struct genl_info *info)
{
        struct psp_dev *psd;

        if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
                return -EINVAL;

        psd = psp_device_get_and_lock(genl_info_net(info),
                                      info->attrs[PSP_A_DEV_ID]);
        info->user_ptr[0] = psd;

        return PTR_ERR_OR_ZERO(psd);
}

/*
 * Post-doit style hook (signature matches genl_split_ops->post_doit).
 * Releases psd->lock on the device in info->user_ptr[0]. If a socket was
 * stored in info->user_ptr[1], it also drops that socket reference with
 * sockfd_put().
 */
void
psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
                  struct genl_info *info)
{
        struct psp_dev *psd = info->user_ptr[0];
        struct socket *sock = info->user_ptr[1];

        mutex_unlock(&psd->lock);

        if (!sock)
                return;

        sockfd_put(sock);
}

/*
 * Append one message describing @psd to @rsp, using the header parameters
 * from @info. The message carries the device id, the main netdev's
 * ifindex, and the supported and enabled PSP version masks.
 *
 * Returns 0, or -EMSGSIZE with any partial message cancelled.
 */
static int
psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
                const struct genl_info *info)
{
        void *hdr = genlmsg_iput(rsp, info);

        if (!hdr)
                return -EMSGSIZE;

        if (!nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) &&
            !nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) &&
            !nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) &&
            !nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions)) {
                genlmsg_end(rsp, hdr);
                return 0;
        }

        genlmsg_cancel(rsp, hdr);
        return -EMSGSIZE;
}

/*
 * Multicast a device message of type @cmd to PSP_NLGRP_MGMT in the netns
 * of the device's main netdev.
 *
 * Nothing is built when the group has no listeners. Allocation and fill
 * failures are dropped silently, because notifications are best-effort.
 */
void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
{
        struct genl_info ntf_info;
        struct sk_buff *msg;

        if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
                                PSP_NLGRP_MGMT))
                return;

        msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
        if (!msg)
                return;

        genl_info_init_ntf(&ntf_info, &psp_nl_family, cmd);
        if (!psp_nl_dev_fill(psd, msg, &ntf_info)) {
                genlmsg_multicast_netns(&psp_nl_family,
                                        dev_net(psd->main_netdev), msg, 0,
                                        PSP_NLGRP_MGMT, GFP_KERNEL);
                return;
        }

        nlmsg_free(msg);
}

/*
 * Doit handler for a device get request: reply with one description of
 * the locked device stored in info->user_ptr[0].
 */
int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
{
        struct psp_dev *psd = info->user_ptr[0];
        struct sk_buff *msg;
        int ret;

        msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
        if (!msg)
                return -ENOMEM;

        ret = psp_nl_dev_fill(psd, msg, info);
        if (!ret)
                return genlmsg_reply(msg, info);

        nlmsg_free(msg);
        return ret;
}

/*
 * Emit @psd into a device dump. If the dumping socket's netns may not
 * access the device, it is skipped (returns 0) rather than failing the
 * dump.
 */
static int
psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
                          struct psp_dev *psd)
{
        if (!psp_dev_check_access(psd, sock_net(rsp->sk)))
                return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));

        return 0;
}

/*
 * Dumpit handler for devices. cb->args[0] serves as the xarray cursor, so
 * the walk starts from the index where a previous pass stopped.
 *
 * psp_devs_lock is held for the whole walk. Each device's own lock is held
 * only while that device is filled. The first non-zero result ends the
 * walk and is returned.
 */
int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
{
        struct psp_dev *psd;
        int err = 0;

        mutex_lock(&psp_devs_lock);
        xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
                int ret;

                mutex_lock(&psd->lock);
                ret = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
                mutex_unlock(&psd->lock);

                if (ret) {
                        err = ret;
                        break;
                }
        }
        mutex_unlock(&psp_devs_lock);

        return err;
}

int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct psp_dev *psd = info->user_ptr[0];
        struct psp_dev_config new_config;
        struct sk_buff *rsp;
        int err;

        memcpy(&new_config, &psd->config, sizeof(new_config));

        if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
                new_config.versions =
                        nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
                if (new_config.versions & ~psd->caps->versions) {
                        NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
                        return -EINVAL;
                }
        } else {
                NL_SET_ERR_MSG(info->extack, "No settings present");
                return -EINVAL;
        }

        rsp = psp_nl_reply_new(info);
        if (!rsp)
                return -ENOMEM;

        if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
                err = psd->ops->set_config(psd, &new_config, info->extack);
                if (err)
                        goto err_free_rsp;

                memcpy(&psd->config, &new_config, sizeof(new_config));
        }

        psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);

        return psp_nl_reply_send(rsp, info);

err_free_rsp:
        nlmsg_free(rsp);
        return err;
}

/*
 * Doit handler for key rotation on the locked device in info->user_ptr[0].
 *
 * Both the unicast reply and a PSP_CMD_KEY_ROTATE_NTF notification carry
 * only PSP_A_DEV_ID. They are allocated and filled before the driver is
 * called, so no allocation or attribute put is needed after the key has
 * been rotated. The notification is multicast to PSP_NLGRP_USE in the
 * netns of the device's main netdev.
 *
 * NOTE(review): if ops->key_rotate() fails, psd->generation keeps the
 * bumped value written below. Confirm whether drivers rely on that or
 * whether it should be restored to prev_gen.
 */
int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct psp_dev *psd = info->user_ptr[0];
        struct genl_info ntf_info;
        struct sk_buff *ntf, *rsp;
        u8 prev_gen;
        int err;

        rsp = psp_nl_reply_new(info);
        if (!rsp)
                return -ENOMEM;

        genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
        ntf = psp_nl_reply_new(&ntf_info);
        if (!ntf) {
                err = -ENOMEM;
                goto err_free_rsp;
        }

        if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
            nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
                err = -EMSGSIZE;
                goto err_free_ntf;
        }

        /* Suggest the next generation number (wrapping within
         * PSP_GEN_VALID_MASK); the driver may overwrite it.
         */
        prev_gen = psd->generation;
        psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;

        err = psd->ops->key_rotate(psd, info->extack);
        if (err)
                goto err_free_ntf;

        /* Warn if the driver left a non-zero generation unchanged, or set
         * bits outside PSP_GEN_VALID_MASK.
         */
        WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
                     psd->generation & ~PSP_GEN_VALID_MASK);

        psp_assocs_key_rotated(psd);
        psd->stats.rotations++;

        /* ntf holds a single message, so its header is at ntf->data. */
        nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
        genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
                                0, PSP_NLGRP_USE, GFP_KERNEL);
        return psp_nl_reply_send(rsp, info);

err_free_ntf:
        nlmsg_free(ntf);
err_free_rsp:
        nlmsg_free(rsp);
        return err;
}

/* Associations, keys and statistics */

/*
 * Pre-doit style hook for association commands (signature matches
 * genl_split_ops->pre_doit). It resolves the TCP socket named by
 * PSP_A_ASSOC_SOCK_FD and the PSP device to operate on. On success:
 *   info->user_ptr[0] = the device, with psd->lock held
 *   info->user_ptr[1] = the socket, holding the sockfd_lookup() reference
 * Presumably paired with psp_device_unlock(), which releases both.
 *
 * If psp_dev_get_for_sock() finds a device for the socket and the caller's
 * netns may access it, that device is used. In that case any
 * PSP_A_ASSOC_DEV_ID supplied must match it. Otherwise PSP_A_ASSOC_DEV_ID
 * is required and the device is looked up by id.
 *
 * Returns 0 or a negative errno; on error nothing is held.
 */
int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
                                struct sk_buff *skb, struct genl_info *info)
{
        struct socket *socket;
        struct psp_dev *psd;
        struct nlattr *id;
        int fd, err;

        if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
                return -EINVAL;

        fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
        socket = sockfd_lookup(fd, &err);
        if (!socket)
                return err;

        if (!sk_is_tcp(socket->sk)) {
                NL_SET_ERR_MSG_ATTR(info->extack,
                                    info->attrs[PSP_A_ASSOC_SOCK_FD],
                                    "Unsupported socket family and type");
                err = -EOPNOTSUPP;
                goto err_sock_put;
        }

        psd = psp_dev_get_for_sock(socket->sk);
        if (psd) {
                /* A socket device the caller cannot access is ignored;
                 * fall back to looking the device up by id.
                 */
                err = psp_dev_check_access(psd, genl_info_net(info));
                if (err) {
                        psp_dev_put(psd);
                        psd = NULL;
                }
        }

        if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
                err = -EINVAL;
                goto err_sock_put;
        }

        id = info->attrs[PSP_A_ASSOC_DEV_ID];
        if (psd) {
                mutex_lock(&psd->lock);
                if (id && psd->id != nla_get_u32(id)) {
                        mutex_unlock(&psd->lock);
                        NL_SET_ERR_MSG_ATTR(info->extack, id,
                                            "Device id vs socket mismatch");
                        err = -EINVAL;
                        goto err_psd_put;
                }

                /* Drop the psp_dev_get_for_sock() reference. Both branches
                 * now leave psd locked with no extra reference, matching
                 * what psp_device_get_and_lock() returns.
                 */
                psp_dev_put(psd);
        } else {
                psd = psp_device_get_and_lock(genl_info_net(info), id);
                if (IS_ERR(psd)) {
                        err = PTR_ERR(psd);
                        goto err_sock_put;
                }
        }

        info->user_ptr[0] = psd;
        info->user_ptr[1] = socket;

        return 0;

err_psd_put:
        psp_dev_put(psd);
err_sock_put:
        sockfd_put(socket);
        return err;
}

/*
 * Parse the nested key attribute @attr from @info into @key. The nest
 * must contain PSP_A_KEYS_KEY and PSP_A_KEYS_SPI.
 *
 * The key must be exactly @key_sz bytes. The SPI must have non-zero
 * PSP_SPI_KEY_ID bits and is stored big-endian in @key->spi. The caller
 * must already have checked that @attr is present.
 *
 * Returns 0, or a negative errno with extack pointing at the offending
 * attribute.
 */
static int
psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
                 unsigned int key_sz)
{
        struct nlattr *nest = info->attrs[attr];
        struct nlattr *tb[PSP_A_KEYS_SPI + 1];
        u32 spi;
        int err;

        err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
                               psp_keys_nl_policy, info->extack);
        if (err)
                return err;

        if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
            NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
                return -EINVAL;

        if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
                NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
                                    "incorrect key length");
                return -EINVAL;
        }

        spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
        if (!(spi & PSP_SPI_KEY_ID)) {
                /* Report the SPI attribute, which is what failed. */
                NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_SPI],
                                    "invalid SPI: lower 31b must be non-zero");
                return -EINVAL;
        }

        key->spi = cpu_to_be32(spi);
        memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);

        return 0;
}

/*
 * Emit @key as nested attribute @attr. The nest holds PSP_A_KEYS_SPI
 * (converted to host order) and PSP_A_KEYS_KEY, which is
 * psp_key_size(@version) bytes long.
 *
 * Returns 0, or -EMSGSIZE with nothing left behind in @skb.
 */
static int
psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
               const struct psp_key_parsed *key)
{
        int key_sz = psp_key_size(version);
        struct nlattr *nest;

        nest = nla_nest_start(skb, attr);
        if (!nest)
                return -EMSGSIZE;

        if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
            nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
                nla_nest_cancel(skb, nest);
                return -EMSGSIZE;
        }

        nla_nest_end(skb, nest);

        return 0;
}

/*
 * Doit handler that creates an RX association. The device is in
 * info->user_ptr[0] (locked) and the socket in info->user_ptr[1].
 *
 * PSP_A_ASSOC_VERSION is required and must be supported by the device.
 * The driver allocates an SPI and key through rx_spi_alloc(). Both are
 * returned to userspace together with the device id, and the association
 * is attached to the socket. The local association reference is dropped
 * on every path.
 */
int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct psp_dev *psd = info->user_ptr[0];
        struct socket *socket = info->user_ptr[1];
        struct psp_key_parsed key;
        struct nlattr *ver_attr;
        struct psp_assoc *pas;
        struct sk_buff *rsp;
        u32 version;
        int ret;

        if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
                return -EINVAL;

        ver_attr = info->attrs[PSP_A_ASSOC_VERSION];
        version = nla_get_u32(ver_attr);
        if (!(psd->caps->versions & (1 << version))) {
                NL_SET_BAD_ATTR(info->extack, ver_attr);
                return -EOPNOTSUPP;
        }

        rsp = psp_nl_reply_new(info);
        if (!rsp)
                return -ENOMEM;

        pas = psp_assoc_create(psd);
        if (!pas) {
                nlmsg_free(rsp);
                return -ENOMEM;
        }
        pas->version = version;

        ret = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
        if (ret)
                goto out_drop;

        if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
            psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
                ret = -EMSGSIZE;
                goto out_drop;
        }

        ret = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
        if (ret) {
                NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
                goto out_drop;
        }

        psp_assoc_put(pas);
        return psp_nl_reply_send(rsp, info);

out_drop:
        psp_assoc_put(pas);
        nlmsg_free(rsp);
        return ret;
}

/*
 * Doit handler that installs a TX association on the socket in
 * info->user_ptr[1], using the locked device in info->user_ptr[0].
 *
 * PSP_A_ASSOC_VERSION and PSP_A_ASSOC_TX_KEY are both required. The
 * version must be supported by the device and have a non-zero key size.
 * The key is parsed and validated by psp_nl_parse_key() before the reply
 * is allocated.
 */
int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct psp_dev *psd = info->user_ptr[0];
        struct socket *socket = info->user_ptr[1];
        struct psp_key_parsed key;
        struct nlattr *ver_attr;
        struct sk_buff *rsp;
        unsigned int key_sz;
        u32 version;
        int ret;

        if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
            GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
                return -EINVAL;

        ver_attr = info->attrs[PSP_A_ASSOC_VERSION];
        version = nla_get_u32(ver_attr);
        if (!(psd->caps->versions & (1 << version))) {
                NL_SET_BAD_ATTR(info->extack, ver_attr);
                return -EOPNOTSUPP;
        }

        key_sz = psp_key_size(version);
        if (!key_sz)
                return -EINVAL;

        ret = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
        if (ret < 0)
                return ret;

        rsp = psp_nl_reply_new(info);
        if (!rsp)
                return -ENOMEM;

        ret = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
                                    info->extack);
        if (!ret)
                return psp_nl_reply_send(rsp, info);

        nlmsg_free(rsp);
        return ret;
}

/*
 * Append one stats message for @psd to @rsp, using the header parameters
 * from @info.
 *
 * The stats struct is pre-filled with 0xff bytes before the driver's
 * get_stats() runs. Any required counter still equal to
 * ETHTOOL_STAT_NOT_SET afterwards is treated as a driver bug (warned
 * once), and the function returns -EOPNOTSUPP. This assumes
 * ETHTOOL_STAT_NOT_SET is all-ones.
 *
 * Returns 0, -EOPNOTSUPP, or -EMSGSIZE (partial message cancelled).
 */
static int
psp_nl_stats_fill(struct psp_dev *psd, struct sk_buff *rsp,
                  const struct genl_info *info)
{
        unsigned int n_required = sizeof(struct psp_dev_stats) / sizeof(u64);
        struct psp_dev_stats stats;
        unsigned int idx;
        void *hdr;

        memset(&stats, 0xff, sizeof(stats));
        psd->ops->get_stats(psd, &stats);

        for (idx = 0; idx < n_required; idx++) {
                if (WARN_ON_ONCE(stats.required[idx] == ETHTOOL_STAT_NOT_SET))
                        return -EOPNOTSUPP;
        }

        hdr = genlmsg_iput(rsp, info);
        if (!hdr)
                return -EMSGSIZE;

        if (nla_put_u32(rsp, PSP_A_STATS_DEV_ID, psd->id) ||
            nla_put_uint(rsp, PSP_A_STATS_KEY_ROTATIONS,
                         psd->stats.rotations) ||
            nla_put_uint(rsp, PSP_A_STATS_STALE_EVENTS, psd->stats.stales) ||
            nla_put_uint(rsp, PSP_A_STATS_RX_PACKETS, stats.rx_packets) ||
            nla_put_uint(rsp, PSP_A_STATS_RX_BYTES, stats.rx_bytes) ||
            nla_put_uint(rsp, PSP_A_STATS_RX_AUTH_FAIL, stats.rx_auth_fail) ||
            nla_put_uint(rsp, PSP_A_STATS_RX_ERROR, stats.rx_error) ||
            nla_put_uint(rsp, PSP_A_STATS_RX_BAD, stats.rx_bad) ||
            nla_put_uint(rsp, PSP_A_STATS_TX_PACKETS, stats.tx_packets) ||
            nla_put_uint(rsp, PSP_A_STATS_TX_BYTES, stats.tx_bytes) ||
            nla_put_uint(rsp, PSP_A_STATS_TX_ERROR, stats.tx_error)) {
                genlmsg_cancel(rsp, hdr);
                return -EMSGSIZE;
        }

        genlmsg_end(rsp, hdr);
        return 0;
}

/*
 * Doit handler for a stats request: reply with one stats message for the
 * locked device in info->user_ptr[0].
 */
int psp_nl_get_stats_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct sk_buff *msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
        struct psp_dev *psd = info->user_ptr[0];
        int ret;

        if (!msg)
                return -ENOMEM;

        ret = psp_nl_stats_fill(psd, msg, info);
        if (ret) {
                nlmsg_free(msg);
                return ret;
        }

        return genlmsg_reply(msg, info);
}

/*
 * Emit stats for @psd into a stats dump. If the dumping socket's netns may
 * not access the device, it is skipped (returns 0) rather than failing the
 * dump.
 */
static int
psp_nl_stats_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
                            struct psp_dev *psd)
{
        if (!psp_dev_check_access(psd, sock_net(rsp->sk)))
                return psp_nl_stats_fill(psd, rsp, genl_info_dump(cb));

        return 0;
}

/*
 * Dumpit handler for per-device stats. cb->args[0] serves as the xarray
 * cursor, so the walk starts from the index where a previous pass stopped.
 *
 * psp_devs_lock is held for the whole walk. Each device's own lock is held
 * only while its stats are filled. The first non-zero result ends the walk
 * and is returned.
 */
int psp_nl_get_stats_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
{
        struct psp_dev *psd;
        int err = 0;

        mutex_lock(&psp_devs_lock);
        xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
                int ret;

                mutex_lock(&psd->lock);
                ret = psp_nl_stats_get_dumpit_one(rsp, cb, psd);
                mutex_unlock(&psd->lock);

                if (ret) {
                        err = ret;
                        break;
                }
        }
        mutex_unlock(&psp_devs_lock);

        return err;
}