root/net/handshake/request.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Handshake request lifetime events
 *
 * Author: Chuck Lever <chuck.lever@oracle.com>
 *
 * Copyright (c) 2023, Oracle and/or its affiliates.
 */

#include <linux/types.h>
#include <linux/socket.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/inet.h>
#include <linux/rhashtable.h>

#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>

#include <kunit/visibility.h>

#include <uapi/linux/handshake.h>
#include "handshake.h"

#include <trace/events/handshake.h>

/*
 * We need both a handshake_req -> sock mapping, and a sock ->
 * handshake_req mapping. Both are one-to-one.
 *
 * To avoid adding another pointer field to struct sock, net/handshake
 * maintains a hash table, indexed by the memory address of @sock, to
 * find the struct handshake_req outstanding for that socket. The
 * reverse direction uses a simple pointer field in the handshake_req
 * struct.
 */

/* File-scope hash table holding every outstanding handshake_req */
static struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp;

/*
 * Table layout: the key is the hr_sk socket pointer stored in each
 * handshake_req, and hr_rhash is the hash linkage embedded in the
 * request. Automatic shrinking is enabled for this table.
 */
static const struct rhashtable_params handshake_rhash_params = {
        .key_len                = sizeof_field(struct handshake_req, hr_sk),
        .key_offset             = offsetof(struct handshake_req, hr_sk),
        .head_offset            = offsetof(struct handshake_req, hr_rhash),
        .automatic_shrinking    = true,
};

/*
 * Initialize the socket-to-request hash table using
 * handshake_rhash_params.
 *
 * Returns the value returned by rhashtable_init(), unmodified.
 */
int handshake_req_hash_init(void)
{
        return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params);
}

/*
 * Tear down the socket-to-request hash table set up by
 * handshake_req_hash_init().
 */
void handshake_req_hash_destroy(void)
{
        rhashtable_destroy(&handshake_rhashtbl);
}

/*
 * Look up the handshake request hashed for @sk.
 *
 * The key is the pointer value of @sk itself, so the address of the
 * local @sk variable is what gets handed to the hash table.
 *
 * Returns whatever rhashtable_lookup_fast() returns; callers in this
 * file treat NULL as "no request outstanding for this socket".
 */
struct handshake_req *handshake_req_hash_lookup(struct sock *sk)
{
        return rhashtable_lookup_fast(&handshake_rhashtbl, &sk,
                                      handshake_rhash_params);
}
EXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup);

/*
 * Insert @req into the socket-to-request hash table.
 *
 * Returns %true when rhashtable_lookup_insert_fast() reports success
 * (zero), and %false for any other result.
 */
static bool handshake_req_hash_add(struct handshake_req *req)
{
        return !rhashtable_lookup_insert_fast(&handshake_rhashtbl,
                                              &req->hr_rhash,
                                              handshake_rhash_params);
}

static void handshake_req_destroy(struct handshake_req *req)
{
        if (req->hr_proto->hp_destroy)
                req->hr_proto->hp_destroy(req);
        rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash,
                               handshake_rhash_params);
        kfree(req);
}

/*
 * Socket destructor installed by handshake_req_submit().
 *
 * Destroys the handshake request hashed for @sk, if there is one, and
 * then chains to the destructor that was in place before the request
 * was submitted. Does nothing when no request is found.
 */
static void handshake_sk_destruct(struct sock *sk)
{
        struct handshake_req *req = handshake_req_hash_lookup(sk);
        void (*odestruct)(struct sock *sk);

        if (!req)
                return;

        trace_handshake_destruct(sock_net(sk), req, sk);

        /* Copy the saved destructor out before @req is freed */
        odestruct = req->hr_odestruct;
        handshake_req_destroy(req);

        if (odestruct)
                odestruct(sk);
}

/**
 * handshake_req_alloc - Allocate a handshake request
 * @proto: security protocol
 * @flags: memory allocation flags
 *
 * @proto is rejected unless its handler class lies strictly between
 * HANDSHAKE_HANDLER_CLASS_NONE and HANDSHAKE_HANDLER_CLASS_MAX and it
 * provides both an hp_accept and an hp_done callback.
 *
 * Returns an initialized handshake_req or NULL.
 */
struct handshake_req *handshake_req_alloc(const struct handshake_proto *proto,
                                          gfp_t flags)
{
        struct handshake_req *req;

        if (!proto)
                return NULL;
        if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE ||
            proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX)
                return NULL;
        if (!(proto->hp_accept && proto->hp_done))
                return NULL;

        /* The trailing hr_priv area is sized from the protocol's hp_privsize */
        req = kzalloc_flex(*req, hr_priv, proto->hp_privsize, flags);
        if (req) {
                INIT_LIST_HEAD(&req->hr_list);
                req->hr_proto = proto;
        }
        return req;
}
EXPORT_SYMBOL(handshake_req_alloc);

/**
 * handshake_req_private - Get per-handshake private data
 * @req: handshake arguments
 *
 * Returns the address of the hr_priv area inside @req, as an untyped
 * pointer. handshake_req_alloc() sizes that area from the protocol's
 * hp_privsize.
 */
void *handshake_req_private(struct handshake_req *req)
{
        return (void *)&req->hr_priv;
}
EXPORT_SYMBOL(handshake_req_private);

/*
 * Append @req to the tail of @hn's pending list and bump the pending
 * count. The caller in this file holds hn->hn_lock.
 *
 * Returns %false, after a one-time warning, if @req is already linked
 * on a list; otherwise %true.
 */
static bool __add_pending_locked(struct handshake_net *hn,
                                 struct handshake_req *req)
{
        if (WARN_ON_ONCE(!list_empty(&req->hr_list)))
                return false;

        list_add_tail(&req->hr_list, &hn->hn_requests);
        hn->hn_pending++;
        return true;
}

/*
 * Unlink @req from @hn's pending list and decrement the pending
 * count. The callers in this file hold hn->hn_lock.
 *
 * list_del_init() is used so that a later list_empty() check on
 * @req->hr_list reports the request as no longer pending.
 */
static void __remove_pending_locked(struct handshake_net *hn,
                                    struct handshake_req *req)
{
        list_del_init(&req->hr_list);
        hn->hn_pending--;
}

/*
 * Take @req off @hn's pending list if it is still queued there.
 *
 * Returns %true if @req was linked on a pending list and has now
 * been removed, otherwise %false.
 *
 * A request that was still on the pending list has not yet been
 * accepted.
 */
static bool remove_pending(struct handshake_net *hn, struct handshake_req *req)
{
        bool was_pending;

        spin_lock(&hn->hn_lock);
        was_pending = !list_empty(&req->hr_list);
        if (was_pending)
                __remove_pending_locked(hn, req);
        spin_unlock(&hn->hn_lock);

        return was_pending;
}

/*
 * Dequeue the oldest pending request on @hn whose protocol handler
 * class equals @class.
 *
 * The pending list is walked from the head under hn->hn_lock; the
 * first match is unlinked and returned. Returns NULL when no pending
 * request matches.
 */
struct handshake_req *handshake_req_next(struct handshake_net *hn, int class)
{
        struct handshake_req *cur, *found = NULL;

        spin_lock(&hn->hn_lock);
        list_for_each_entry(cur, &hn->hn_requests, hr_list) {
                if (cur->hr_proto->hp_handler_class == class) {
                        __remove_pending_locked(hn, cur);
                        found = cur;
                        break;
                }
        }
        spin_unlock(&hn->hn_lock);

        return found;
}
EXPORT_SYMBOL_IF_KUNIT(handshake_req_next);

/**
 * handshake_req_submit - Submit a handshake request
 * @sock: open socket on which to perform the handshake
 * @req: handshake arguments
 * @flags: memory allocation flags
 *
 * Hooks the socket's destructor, hashes @req by socket, queues it on
 * the per-net pending list, and then sends a notification via
 * handshake_genl_notify(). @flags is passed through to that call.
 *
 * Return values:
 *   %0: Request queued
 *   %-EINVAL: Invalid argument
 *   %-EBUSY: A handshake is already under way for this socket
 *   %-ESRCH: No handshake agent is available
 *   %-EAGAIN: Too many pending handshake requests
 *   %-ENOMEM: Failed to allocate memory
 *   %-EMSGSIZE: Failed to construct notification message
 *   %-EOPNOTSUPP: Handshake module not initialized
 *
 * A zero return value from handshake_req_submit() means that
 * exactly one subsequent completion callback is guaranteed.
 *
 * A negative return value from handshake_req_submit() means that
 * no completion callback will be done and that @req has been
 * destroyed.
 */
int handshake_req_submit(struct socket *sock, struct handshake_req *req,
                         gfp_t flags)
{
        struct handshake_net *hn;
        struct net *net;
        int ret;

        /*
         * Nothing has been hooked or hashed yet, so @req is simply freed.
         * NOTE(review): these early exits use kfree() and therefore do
         * not invoke the protocol's hp_destroy callback - confirm that
         * is intended.
         */
        if (!sock || !req || !sock->file) {
                kfree(req);
                return -EINVAL;
        }

        req->hr_sk = sock->sk;
        if (!req->hr_sk) {
                kfree(req);
                return -EINVAL;
        }

        /* Interpose on socket destruction; remember the previous hook */
        req->hr_odestruct = req->hr_sk->sk_destruct;
        req->hr_sk->sk_destruct = handshake_sk_destruct;

        ret = -EOPNOTSUPP;
        net = sock_net(req->hr_sk);
        hn = handshake_pernet(net);
        if (!hn)
                goto out_err;

        /* Lockless check of the pending count against the per-net limit */
        ret = -EAGAIN;
        if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max)
                goto out_err;

        spin_lock(&hn->hn_lock);
        ret = -EOPNOTSUPP;
        if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags))
                goto out_unlock;
        /* Both failures below are reported as -EBUSY */
        ret = -EBUSY;
        if (!handshake_req_hash_add(req))
                goto out_unlock;
        if (!__add_pending_locked(hn, req))
                goto out_unlock;
        spin_unlock(&hn->hn_lock);

        ret = handshake_genl_notify(net, req->hr_proto, flags);
        if (ret) {
                trace_handshake_notify_err(net, req, req->hr_sk, ret);
                /*
                 * Fail the submission only if @req was still on the
                 * pending list. If it had already been taken off the
                 * list, fall through and report success.
                 */
                if (remove_pending(hn, req))
                        goto out_err;
        }

        /* Prevent socket release while a handshake request is pending */
        sock_hold(req->hr_sk);

        trace_handshake_submit(net, req, req->hr_sk);
        return 0;

out_unlock:
        spin_unlock(&hn->hn_lock);
out_err:
        /* Restore original destructor so socket teardown still runs on failure */
        req->hr_sk->sk_destruct = req->hr_odestruct;
        trace_handshake_submit_err(net, req, req->hr_sk, ret);
        /* Runs hp_destroy if set, unhashes @req, and frees it */
        handshake_req_destroy(req);
        return ret;
}
EXPORT_SYMBOL(handshake_req_submit);

/*
 * Report completion of @req to its protocol.
 *
 * Only the first caller to set HANDSHAKE_F_REQ_COMPLETED invokes the
 * protocol's hp_done callback with @status and @info and drops a
 * reference on the socket; any later call is a no-op.
 */
void handshake_complete(struct handshake_req *req, unsigned int status,
                        struct genl_info *info)
{
        struct sock *sk = req->hr_sk;
        struct net *net = sock_net(sk);

        /* Someone else already completed or canceled this request */
        if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags))
                return;

        trace_handshake_complete(net, req, sk, status);
        req->hr_proto->hp_done(req, status, info);

        /* Handshake request is no longer pending */
        sock_put(sk);
}
EXPORT_SYMBOL_IF_KUNIT(handshake_complete);

/**
 * handshake_req_cancel - Cancel an in-progress handshake
 * @sk: socket on which there is an ongoing handshake
 *
 * Cancellation and completion race with each other. The return
 * value tells the caller which of the two won.
 *
 * Return values:
 *   %true - Uncompleted handshake request was canceled
 *   %false - Handshake request already completed or not found
 */
bool handshake_req_cancel(struct sock *sk)
{
        struct net *net = sock_net(sk);
        struct handshake_req *req;
        struct handshake_net *hn;

        req = handshake_req_hash_lookup(sk);
        if (!req) {
                trace_handshake_cancel_none(net, req, sk);
                return false;
        }

        /* A request that hasn't been accepted yet is still queued: unlink it */
        hn = handshake_pernet(net);
        if (hn)
                remove_pending(hn, req);

        /*
         * Whether or not the request was still queued, the first party
         * to set HANDSHAKE_F_REQ_COMPLETED wins the race.
         */
        if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
                /* Request already completed */
                trace_handshake_cancel_busy(net, req, sk);
                return false;
        }

        trace_handshake_cancel(net, req, sk);

        /* Handshake request is no longer pending */
        sock_put(sk);
        return true;
}
EXPORT_SYMBOL(handshake_req_cancel);