root/drivers/net/ethernet/netronome/nfp/crypto/tls.c
// SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause)
/* Copyright (C) 2019 Netronome Systems, Inc. */

#include <linux/bitfield.h>
#include <linux/ipv6.h>
#include <linux/skbuff.h>
#include <linux/string.h>
#include <net/inet6_hashtables.h>
#include <net/tls.h>

#include "../ccm.h"
#include "../nfp_net.h"
#include "crypto.h"
#include "fw.h"

/* Mailbox (CCM) message types the firmware must advertise before TLS
 * offload is enabled; all of them are required (see nfp_net_tls_init()).
 */
#define NFP_NET_TLS_CCM_MBOX_OPS_MASK                                   \
        (BIT(NFP_CCM_TYPE_CRYPTO_RESET) | BIT(NFP_CCM_TYPE_CRYPTO_ADD) | \
         BIT(NFP_CCM_TYPE_CRYPTO_DEL) | BIT(NFP_CCM_TYPE_CRYPTO_UPDATE))

/* Crypto opcode bit enabling RX (decrypt) offload */
#define NFP_NET_TLS_OPCODE_MASK_RX                                      \
        BIT(NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_DEC)

/* Crypto opcode bit enabling TX (encrypt) offload */
#define NFP_NET_TLS_OPCODE_MASK_TX                                      \
        BIT(NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_ENC)

/* Any TLS opcode, regardless of direction */
#define NFP_NET_TLS_OPCODE_MASK                                         \
        (NFP_NET_TLS_OPCODE_MASK_TX | NFP_NET_TLS_OPCODE_MASK_RX)

/* Set (@on) or clear the enable bit for crypto @opcode in the control BAR.
 * The enable bits are laid out as consecutive 32-bit words starting at
 * tlv_caps.crypto_enable_off; the containing word is read, modified and
 * written back.
 */
static void nfp_net_crypto_set_op(struct nfp_net *nn, u8 opcode, bool on)
{
        u32 word_off = nn->tlv_caps.crypto_enable_off +
                       round_down(opcode / 8, 4);
        u32 mask = BIT(opcode & 31);
        u32 word = nn_readl(nn, word_off);

        word = on ? (word | mask) : (word & ~mask);
        nn_writel(nn, word_off, word);
}

/* Adjust the connection count for @direction by @add and, when the new
 * count is 0 or 1 (covering the 0 -> 1 and 1 -> 0 transitions), rewrite
 * the matching crypto opcode enable bit.
 * Returns true when the enable bit was written and the caller needs to
 * push a reconfig, false otherwise.
 * Called with the control BAR lock held by nfp_net_tls_conn_cnt_changed().
 */
static bool
__nfp_net_tls_conn_cnt_changed(struct nfp_net *nn, int add,
                               enum tls_offload_ctx_dir direction)
{
        u8 opcode;
        int cnt;

        if (direction != TLS_OFFLOAD_CTX_DIR_TX) {
                nn->ktls_rx_conn_cnt += add;
                cnt = nn->ktls_rx_conn_cnt;
                opcode = NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_DEC;
        } else {
                nn->ktls_tx_conn_cnt += add;
                cnt = nn->ktls_tx_conn_cnt;
                /* mirror "any TX connection exists" into the datapath */
                nn->dp.ktls_tx = !!nn->ktls_tx_conn_cnt;
                opcode = NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_ENC;
        }

        /* Counts above one cannot be an on/off transition */
        if (cnt <= 1) {
                nfp_net_crypto_set_op(nn, opcode, cnt);
                return true;
        }

        return false;
}

static int
nfp_net_tls_conn_cnt_changed(struct nfp_net *nn, int add,
                             enum tls_offload_ctx_dir direction)
{
        int ret = 0;

        /* Use the BAR lock to protect the connection counts */
        nn_ctrl_bar_lock(nn);
        if (__nfp_net_tls_conn_cnt_changed(nn, add, direction)) {
                ret = __nfp_net_reconfig(nn, NFP_NET_CFG_UPDATE_CRYPTO);
                /* Undo the cnt adjustment if failed */
                if (ret)
                        __nfp_net_tls_conn_cnt_changed(nn, -add, direction);
        }
        nn_ctrl_bar_unlock(nn);

        return ret;
}

/* Account one more offloaded connection in @direction */
static int
nfp_net_tls_conn_add(struct nfp_net *nn, enum tls_offload_ctx_dir direction)
{
        const int one_more = 1;

        return nfp_net_tls_conn_cnt_changed(nn, one_more, direction);
}

/* Account one fewer offloaded connection in @direction */
static int
nfp_net_tls_conn_remove(struct nfp_net *nn, enum tls_offload_ctx_dir direction)
{
        const int one_less = -1;

        return nfp_net_tls_conn_cnt_changed(nn, one_less, direction);
}

/* Allocate a mailbox message with a @req_sz byte request whose reply is a
 * struct nfp_crypto_reply_simple.
 */
static struct sk_buff *
nfp_net_tls_alloc_simple(struct nfp_net *nn, size_t req_sz, gfp_t flags)
{
        const size_t reply_sz = sizeof(struct nfp_crypto_reply_simple);

        return nfp_ccm_mbox_msg_alloc(nn, req_sz, reply_sz, flags);
}

static int
nfp_net_tls_communicate_simple(struct nfp_net *nn, struct sk_buff *skb,
                               const char *name, enum nfp_ccm_type type)
{
        struct nfp_crypto_reply_simple *reply;
        int err;

        err = __nfp_ccm_mbox_communicate(nn, skb, type,
                                         sizeof(*reply), sizeof(*reply),
                                         type == NFP_CCM_TYPE_CRYPTO_DEL);
        if (err) {
                nn_dp_warn(&nn->dp, "failed to %s TLS: %d\n", name, err);
                return err;
        }

        reply = (void *)skb->data;
        err = -be32_to_cpu(reply->error);
        if (err)
                nn_dp_warn(&nn->dp, "failed to %s TLS, fw replied: %d\n",
                           name, err);
        dev_consume_skb_any(skb);

        return err;
}

/* Ask firmware to delete the connection identified by @fw_handle.
 * Best effort: an allocation failure is silently ignored and the exchange
 * result is not returned (failures are logged by
 * nfp_net_tls_communicate_simple()).
 */
static void nfp_net_tls_del_fw(struct nfp_net *nn, __be32 *fw_handle)
{
        struct nfp_crypto_req_del *del_req;
        struct sk_buff *msg;

        msg = nfp_net_tls_alloc_simple(nn, sizeof(*del_req), GFP_KERNEL);
        if (msg) {
                del_req = (void *)msg->data;
                del_req->ep_id = 0;
                memcpy(del_req->handle, fw_handle, sizeof(del_req->handle));

                nfp_net_tls_communicate_simple(nn, msg, "delete",
                                               NFP_CCM_TYPE_CRYPTO_DEL);
        }
}

/* Encode @ipver together with the "VLAN unused" marker, big endian, into
 * front->ipver_vlan.
 */
static void
nfp_net_tls_set_ipver_vlan(struct nfp_crypto_req_add_front *front, u8 ipver)
{
        u16 ipver_vlan;

        ipver_vlan = FIELD_PREP(NFP_NET_TLS_IPVER, ipver);
        ipver_vlan |= FIELD_PREP(NFP_NET_TLS_VLAN, NFP_NET_TLS_VLAN_UNUSED);

        front->ipver_vlan = cpu_to_be16(ipver_vlan);
}

/* Fill the L3 address area of the request key with a freshly generated
 * 64-bit connection id followed by zero padding. This is used for TX
 * connections in place of real addresses.
 * front->key_len must already include the address bytes, and that address
 * area must be at least 8 bytes long.
 */
static void
nfp_net_tls_assign_conn_id(struct nfp_net *nn,
                           struct nfp_crypto_req_add_front *front)
{
        u64 conn_id = atomic64_inc_return(&nn->ktls_conn_id_gen);
        u32 addr_len = front->key_len - NFP_NET_TLS_NON_ADDR_KEY_LEN;

        memcpy(front->l3_addrs, &conn_id, sizeof(conn_id));
        memset(front->l3_addrs + sizeof(conn_id), 0,
               addr_len - sizeof(conn_id));
}

/* Fill the IPv4 part of an add request and return its trailing section.
 * RX: the key carries the socket addresses with source and destination
 * swapped (the peer address becomes src_ip).
 * TX: a generated connection id replaces the addresses.
 */
static struct nfp_crypto_req_add_back *
nfp_net_tls_set_ipv4(struct nfp_net *nn, struct nfp_crypto_req_add_v4 *req,
                     struct sock *sk, int direction)
{
        struct inet_sock *inet = inet_sk(sk);

        /* the key grows by the two IPv4 addresses */
        req->front.key_len += 2 * sizeof(__be32);

        if (direction != TLS_OFFLOAD_CTX_DIR_TX) {
                req->src_ip = inet->inet_daddr;
                req->dst_ip = inet->inet_saddr;
        } else {
                nfp_net_tls_assign_conn_id(nn,
                        container_of(&req->front,
                                     struct nfp_crypto_req_add_front, __hdr));
        }

        return &req->back;
}

/* Fill the IPv6 part of an add request and return its trailing section.
 * RX: the key carries the socket addresses with source and destination
 * swapped (the peer address becomes src_ip).
 * TX: a generated connection id replaces the addresses.
 * Without CONFIG_IPV6 the request is left untouched.
 */
static struct nfp_crypto_req_add_back *
nfp_net_tls_set_ipv6(struct nfp_net *nn, struct nfp_crypto_req_add_v6 *req,
                     struct sock *sk, int direction)
{
#if IS_ENABLED(CONFIG_IPV6)
        /* the key grows by the two IPv6 addresses */
        req->front.key_len += 2 * sizeof(struct in6_addr);

        if (direction != TLS_OFFLOAD_CTX_DIR_TX) {
                struct ipv6_pinfo *np = inet6_sk(sk);

                memcpy(req->src_ip, &sk->sk_v6_daddr, sizeof(req->src_ip));
                memcpy(req->dst_ip, &np->saddr, sizeof(req->dst_ip));
        } else {
                nfp_net_tls_assign_conn_id(nn,
                        container_of(&req->front,
                                     struct nfp_crypto_req_add_front, __hdr));
        }
#endif
        return &req->back;
}

/* Fill the L4 (TCP) part of an add request. RX uses the socket ports with
 * source and destination swapped; TX requests carry zero ports.
 */
static void
nfp_net_tls_set_l4(struct nfp_crypto_req_add_front *front,
                   struct nfp_crypto_req_add_back *back, struct sock *sk,
                   int direction)
{
        struct inet_sock *inet = inet_sk(sk);

        front->l4_proto = IPPROTO_TCP;

        if (direction != TLS_OFFLOAD_CTX_DIR_TX) {
                back->src_port = inet->inet_dport;
                back->dst_port = inet->inet_sport;
                return;
        }

        back->src_port = 0;
        back->dst_port = 0;
}

/* Map an offload direction to its TLS 1.2 AES-GCM-128 crypto opcode.
 * An unknown direction triggers a one-time warning and yields 0.
 */
static u8 nfp_tls_1_2_dir_to_opcode(enum tls_offload_ctx_dir direction)
{
        if (direction == TLS_OFFLOAD_CTX_DIR_TX)
                return NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_ENC;
        if (direction == TLS_OFFLOAD_CTX_DIR_RX)
                return NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_DEC;

        WARN_ON_ONCE(1);
        return 0;
}

/* Check whether the firmware advertises the crypto opcode needed for
 * @cipher_type in @direction. Only TLS_CIPHER_AES_GCM_128 is recognized;
 * any direction other than TX is treated as RX.
 */
static bool
nfp_net_cipher_supported(struct nfp_net *nn, u16 cipher_type,
                         enum tls_offload_ctx_dir direction)
{
        u8 op;

        if (cipher_type != TLS_CIPHER_AES_GCM_128)
                return false;

        op = direction == TLS_OFFLOAD_CTX_DIR_TX ?
                NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_ENC :
                NFP_NET_CRYPTO_OP_TLS_1_2_AES_GCM_128_DEC;

        return nn->tlv_caps.crypto_ops & BIT(op);
}

/* tlsdev_ops .tls_dev_add: install a TLS 1.2 AES-GCM-128 connection for
 * @direction in the firmware.
 *
 * Builds a CRYPTO_ADD mailbox request keyed on the connection and sends it:
 * for RX the key holds the IPv4 or IPv6 addresses and TCP ports, for TX a
 * driver-generated id with zero ports. It then wipes the key material from
 * the skb and stores the firmware handle from the reply in the socket's
 * driver context.
 *
 * Returns 0 on success or a negative errno: -EOPNOTSUPP for an unsupported
 * cipher/direction or socket family, -ENOMEM, -EINVAL if the firmware
 * returned a NULL handle, the mailbox error, or the negated firmware error.
 * On failure the connection count increment is undone.
 */
static int
nfp_net_tls_add(struct net_device *netdev, struct sock *sk,
                enum tls_offload_ctx_dir direction,
                struct tls_crypto_info *crypto_info,
                u32 start_offload_tcp_sn)
{
        struct tls12_crypto_info_aes_gcm_128 *tls_ci;
        struct nfp_net *nn = netdev_priv(netdev);
        struct nfp_crypto_req_add_front *front;
        struct nfp_net_tls_offload_ctx *ntls;
        struct nfp_crypto_req_add_back *back;
        struct nfp_crypto_reply_add *reply;
        struct sk_buff *skb;
        size_t req_sz;
        void *req;
        bool ipv6;
        int err;

        /* The driver context must fit in the space the TLS core reserves
         * for it; RX only needs the part up to rx_end.
         */
        BUILD_BUG_ON(sizeof(struct nfp_net_tls_offload_ctx) >
                     TLS_DRIVER_STATE_SIZE_TX);
        BUILD_BUG_ON(offsetof(struct nfp_net_tls_offload_ctx, rx_end) >
                     TLS_DRIVER_STATE_SIZE_RX);

        if (!nfp_net_cipher_supported(nn, crypto_info->cipher_type, direction))
                return -EOPNOTSUPP;

        /* Choose the request layout. An AF_INET6 socket connected to a
         * v4-mapped address (and not v6-only) uses the IPv4 layout.
         */
        switch (sk->sk_family) {
#if IS_ENABLED(CONFIG_IPV6)
        case AF_INET6:
                if (ipv6_only_sock(sk) ||
                    ipv6_addr_type(&sk->sk_v6_daddr) != IPV6_ADDR_MAPPED) {
                        req_sz = sizeof(struct nfp_crypto_req_add_v6);
                        ipv6 = true;
                        break;
                }
                fallthrough;
#endif
        case AF_INET:
                req_sz = sizeof(struct nfp_crypto_req_add_v4);
                ipv6 = false;
                break;
        default:
                return -EOPNOTSUPP;
        }

        /* Account the connection first; this may enable the opcode */
        err = nfp_net_tls_conn_add(nn, direction);
        if (err)
                return err;

        skb = nfp_ccm_mbox_msg_alloc(nn, req_sz, sizeof(*reply), GFP_KERNEL);
        if (!skb) {
                err = -ENOMEM;
                goto err_conn_remove;
        }

        /* key_len starts with the non-address part only; the
         * nfp_net_tls_set_ipv*() helpers add the address bytes.
         */
        front = (void *)skb->data;
        front->ep_id = 0;
        front->key_len = NFP_NET_TLS_NON_ADDR_KEY_LEN;
        front->opcode = nfp_tls_1_2_dir_to_opcode(direction);
        memset(front->resv, 0, sizeof(front->resv));

        nfp_net_tls_set_ipver_vlan(front, ipv6 ? 6 : 4);

        req = (void *)skb->data;
        if (ipv6)
                back = nfp_net_tls_set_ipv6(nn, req, sk, direction);
        else
                back = nfp_net_tls_set_ipv4(nn, req, sk, direction);

        nfp_net_tls_set_l4(front, back, sk, direction);

        back->counter = 0;
        back->tcp_seq = cpu_to_be32(start_offload_tcp_sn);

        /* Copy the key and zero the rest of back->key. The index into
         * back->key is in 4-byte units (hence "/ 4"); the memset length is
         * in bytes.
         */
        tls_ci = (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
        memcpy(back->key, tls_ci->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
        memset(&back->key[TLS_CIPHER_AES_GCM_128_KEY_SIZE / 4], 0,
               sizeof(back->key) - TLS_CIPHER_AES_GCM_128_KEY_SIZE);
        memcpy(back->iv, tls_ci->iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
        memcpy(&back->salt, tls_ci->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
        memcpy(back->rec_no, tls_ci->rec_seq, sizeof(tls_ci->rec_seq));

        /* Get an extra ref on the skb so we can wipe the key after */
        skb_get(skb);

        err = nfp_ccm_mbox_communicate(nn, skb, NFP_CCM_TYPE_CRYPTO_ADD,
                                       sizeof(*reply), sizeof(*reply));
        reply = (void *)skb->data;

        /* We depend on CCM MBOX code not reallocating skb we sent
         * so we can clear the key material out of the memory.
         */
        if (!WARN_ON_ONCE((u8 *)back < skb->head ||
                          (u8 *)back > skb_end_pointer(skb)) &&
            !WARN_ON_ONCE((u8 *)&reply[1] > (u8 *)back))
                memzero_explicit(back, sizeof(*back));
        dev_consume_skb_any(skb); /* the extra ref from skb_get() above */

        if (err) {
                nn_dp_warn(&nn->dp, "failed to add TLS: %d (%d)\n",
                           err, direction == TLS_OFFLOAD_CTX_DIR_TX);
                /* communicate frees skb on error */
                goto err_conn_remove;
        }

        /* From here the remaining skb reference is released on every path */
        err = -be32_to_cpu(reply->error);
        if (err) {
                if (err == -ENOSPC) {
                        if (!atomic_fetch_inc(&nn->ktls_no_space))
                                nn_info(nn, "HW TLS table full\n");
                } else {
                        nn_dp_warn(&nn->dp,
                                   "failed to add TLS, FW replied: %d\n", err);
                }
                goto err_free_skb;
        }

        if (!reply->handle[0] && !reply->handle[1]) {
                nn_dp_warn(&nn->dp, "FW returned NULL handle\n");
                err = -EINVAL;
                goto err_fw_remove;
        }

        ntls = tls_driver_ctx(sk, direction);
        memcpy(ntls->fw_handle, reply->handle, sizeof(ntls->fw_handle));
        if (direction == TLS_OFFLOAD_CTX_DIR_TX)
                ntls->next_seq = start_offload_tcp_sn;
        dev_consume_skb_any(skb);

        if (direction == TLS_OFFLOAD_CTX_DIR_TX)
                return 0;

        /* RX: without firmware-side resync support (tls_resync_ss), select
         * the core's "next hint" resync type for this socket.
         */
        if (!nn->tlv_caps.tls_resync_ss)
                tls_offload_rx_resync_set_type(sk, TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT);

        return 0;

err_fw_remove:
        /* NOTE(review): only reached with an all-zero handle; reply still
         * points into the live skb, which is freed below.
         */
        nfp_net_tls_del_fw(nn, reply->handle);
err_free_skb:
        dev_consume_skb_any(skb);
err_conn_remove:
        nfp_net_tls_conn_remove(nn, direction);
        return err;
}

/* tlsdev_ops .tls_dev_del: drop the connection from the per-direction
 * count, then ask firmware to delete the state behind the stored handle.
 */
static void
nfp_net_tls_del(struct net_device *netdev, struct tls_context *tls_ctx,
                enum tls_offload_ctx_dir direction)
{
        struct nfp_net_tls_offload_ctx *ntls = __tls_driver_ctx(tls_ctx,
                                                                direction);
        struct nfp_net *nn = netdev_priv(netdev);

        nfp_net_tls_conn_remove(nn, direction);
        nfp_net_tls_del_fw(nn, ntls->fw_handle);
}

/* tlsdev_ops .tls_dev_resync: tell firmware that record @rcd_sn starts at
 * TCP sequence @seq.
 * TX: allocated with GFP_KERNEL and sent as CRYPTO_UPDATE, waiting for the
 *     reply; ntls->next_seq is updated only on success.
 * RX: allocated with GFP_ATOMIC and posted without waiting; the RESYNC
 *     message type is used when firmware advertises tls_resync_ss.
 * Returns 0, -ENOMEM, or the TX exchange error.
 */
static int
nfp_net_tls_resync(struct net_device *netdev, struct sock *sk, u32 seq,
                   u8 *rcd_sn, enum tls_offload_ctx_dir direction)
{
        bool is_tx = direction == TLS_OFFLOAD_CTX_DIR_TX;
        struct nfp_net *nn = netdev_priv(netdev);
        struct nfp_net_tls_offload_ctx *ntls;
        struct nfp_crypto_req_update *upd;
        struct sk_buff *msg;
        int ret;

        msg = nfp_net_tls_alloc_simple(nn, sizeof(*upd),
                                       is_tx ? GFP_KERNEL : GFP_ATOMIC);
        if (!msg)
                return -ENOMEM;

        ntls = tls_driver_ctx(sk, direction);

        upd = (void *)msg->data;
        upd->ep_id = 0;
        upd->opcode = nfp_tls_1_2_dir_to_opcode(direction);
        memset(upd->resv, 0, sizeof(upd->resv));
        memcpy(upd->handle, ntls->fw_handle, sizeof(ntls->fw_handle));
        upd->tcp_seq = cpu_to_be32(seq);
        memcpy(upd->rec_no, rcd_sn, sizeof(upd->rec_no));

        if (!is_tx) {
                enum nfp_ccm_type rx_type = nn->tlv_caps.tls_resync_ss ?
                        NFP_CCM_TYPE_CRYPTO_RESYNC :
                        NFP_CCM_TYPE_CRYPTO_UPDATE;

                nfp_ccm_mbox_post(nn, msg, rx_type,
                                  sizeof(struct nfp_crypto_reply_simple));
                atomic_inc(&nn->ktls_rx_resync_sent);
                return 0;
        }

        ret = nfp_net_tls_communicate_simple(nn, msg, "sync",
                                             NFP_CCM_TYPE_CRYPTO_UPDATE);
        if (ret)
                return ret;

        ntls->next_seq = seq;
        return 0;
}

/* TLS device offload callbacks installed on the netdev by nfp_net_tls_init() */
static const struct tlsdev_ops nfp_net_tls_ops = {
        .tls_dev_add    = nfp_net_tls_add,
        .tls_dev_del    = nfp_net_tls_del,
        .tls_dev_resync = nfp_net_tls_resync,
};

int nfp_net_tls_rx_resync_req(struct net_device *netdev,
                              struct nfp_net_tls_resync_req *req,
                              void *pkt, unsigned int pkt_len)
{
        struct nfp_net *nn = netdev_priv(netdev);
        struct nfp_net_tls_offload_ctx *ntls;
        struct net *net = dev_net(netdev);
        struct ipv6hdr *ipv6h;
        struct tcphdr *th;
        struct iphdr *iph;
        struct sock *sk;
        __be32 tcp_seq;
        int err;

        iph = pkt + req->l3_offset;
        ipv6h = pkt + req->l3_offset;
        th = pkt + req->l4_offset;

        if ((u8 *)&th[1] > (u8 *)pkt + pkt_len) {
                netdev_warn_once(netdev, "invalid TLS RX resync request (l3_off: %hhu l4_off: %hhu pkt_len: %u)\n",
                                 req->l3_offset, req->l4_offset, pkt_len);
                err = -EINVAL;
                goto err_cnt_ign;
        }

        switch (ipv6h->version) {
        case 4:
                sk = inet_lookup_established(net, iph->saddr, th->source,
                                             iph->daddr, th->dest,
                                             netdev->ifindex);
                break;
#if IS_ENABLED(CONFIG_IPV6)
        case 6:
                sk = __inet6_lookup_established(net, &ipv6h->saddr, th->source,
                                                &ipv6h->daddr, ntohs(th->dest),
                                                netdev->ifindex, 0);
                break;
#endif
        default:
                netdev_warn_once(netdev, "invalid TLS RX resync request (l3_off: %hhu l4_off: %hhu ipver: %u)\n",
                                 req->l3_offset, req->l4_offset, iph->version);
                err = -EINVAL;
                goto err_cnt_ign;
        }

        err = 0;
        if (!sk)
                goto err_cnt_ign;
        if (!tls_is_sk_rx_device_offloaded(sk) ||
            sk->sk_shutdown & RCV_SHUTDOWN)
                goto err_put_sock;

        ntls = tls_driver_ctx(sk, TLS_OFFLOAD_CTX_DIR_RX);
        /* some FW versions can't report the handle and report 0s */
        if (memchr_inv(&req->fw_handle, 0, sizeof(req->fw_handle)) &&
            memcmp(&req->fw_handle, &ntls->fw_handle, sizeof(ntls->fw_handle)))
                goto err_put_sock;

        /* copy to ensure alignment */
        memcpy(&tcp_seq, &req->tcp_seq, sizeof(tcp_seq));
        tls_offload_rx_resync_request(sk, tcp_seq);
        atomic_inc(&nn->ktls_rx_resync_req);

        sock_gen_put(sk);
        return 0;

err_put_sock:
        sock_gen_put(sk);
err_cnt_ign:
        atomic_inc(&nn->ktls_rx_resync_ign);
        return err;
}

static int nfp_net_tls_reset(struct nfp_net *nn)
{
        struct nfp_crypto_req_reset *req;
        struct sk_buff *skb;

        skb = nfp_net_tls_alloc_simple(nn, sizeof(*req), GFP_KERNEL);
        if (!skb)
                return -ENOMEM;

        req = (void *)skb->data;
        req->ep_id = 0;

        return nfp_net_tls_communicate_simple(nn, skb, "reset",
                                              NFP_CCM_TYPE_CRYPTO_RESET);
}

int nfp_net_tls_init(struct nfp_net *nn)
{
        struct net_device *netdev = nn->dp.netdev;
        int err;

        if (!(nn->tlv_caps.crypto_ops & NFP_NET_TLS_OPCODE_MASK))
                return 0;

        if ((nn->tlv_caps.mbox_cmsg_types & NFP_NET_TLS_CCM_MBOX_OPS_MASK) !=
            NFP_NET_TLS_CCM_MBOX_OPS_MASK)
                return 0;

        if (!nfp_ccm_mbox_fits(nn, sizeof(struct nfp_crypto_req_add_v6))) {
                nn_warn(nn, "disabling TLS offload - mbox too small: %d\n",
                        nn->tlv_caps.mbox_len);
                return 0;
        }

        err = nfp_net_tls_reset(nn);
        if (err)
                return err;

        nn_ctrl_bar_lock(nn);
        nn_writel(nn, nn->tlv_caps.crypto_enable_off, 0);
        err = __nfp_net_reconfig(nn, NFP_NET_CFG_UPDATE_CRYPTO);
        nn_ctrl_bar_unlock(nn);
        if (err)
                return err;

        if (nn->tlv_caps.crypto_ops & NFP_NET_TLS_OPCODE_MASK_RX) {
                netdev->hw_features |= NETIF_F_HW_TLS_RX;
                netdev->features |= NETIF_F_HW_TLS_RX;
        }
        if (nn->tlv_caps.crypto_ops & NFP_NET_TLS_OPCODE_MASK_TX) {
                netdev->hw_features |= NETIF_F_HW_TLS_TX;
                netdev->features |= NETIF_F_HW_TLS_TX;
        }

        netdev->tlsdev_ops = &nfp_net_tls_ops;

        return 0;
}