root/net/sunrpc/auth_tls.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
 *
 * The AUTH_TLS credential is used only to probe a remote peer
 * for RPC-over-TLS support.
 */

#include <linux/types.h>
#include <linux/module.h>
#include <linux/sunrpc/clnt.h>

/*
 * Verifier body that tls_validate() requires in the server's reply.
 *
 * The token is a read-only array, not a mutable pointer. Its length is
 * derived from the literal, so the two can never drift apart. The
 * comparison length excludes the trailing NUL, because the XDR opaque on
 * the wire is not NUL-terminated.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;

/*
 * Forward declarations of the file's two static singletons.
 *
 * tls_probe(), tls_create() and tls_lookup_cred() take the addresses of
 * these objects before their initialized definitions further down. Those
 * definitions in turn point at the ops tables that reference the functions
 * above, so one side of the cycle has to be declared first.
 */
static struct rpc_auth tls_auth;
static struct rpc_cred tls_cred;

/*
 * XDR callbacks for the probe procedure. The probe sends no call
 * arguments and reads nothing from the reply body, so the encoder does
 * nothing and the decoder reports success unconditionally.
 */
static void tls_encode_probe(struct rpc_rqst *req, struct xdr_stream *stream,
                             const void *unused)
{
}

static int tls_decode_probe(struct rpc_rqst *req, struct xdr_stream *stream,
                            void *unused)
{
        return 0;
}

/* Procedure descriptor used by tls_probe(); only the XDR hooks are set. */
static const struct rpc_procinfo rpcproc_tls_probe = {
        .p_decode       = tls_decode_probe,
        .p_encode       = tls_encode_probe,
};

/*
 * rpc_tls_probe_call_prepare - last-minute setup before the probe is sent
 * @task: the probe task
 * @calldata: unused
 *
 * Strips RPC_TASK_NO_RETRANS_TIMEOUT from the task's flags, then lets
 * rpc_call_start() begin the normal call state machine.
 */
static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *calldata)
{
        task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
        rpc_call_start(task);
}

/*
 * rpc_tls_probe_call_done - completion hook for the probe
 *
 * Intentionally does nothing: tls_probe() reads the outcome from
 * task->tk_status once rpc_run_task() returns.
 */
static void rpc_tls_probe_call_done(struct rpc_task *task, void *calldata)
{
}

static const struct rpc_call_ops rpc_tls_probe_ops = {
        .rpc_call_done          = rpc_tls_probe_call_done,
        .rpc_call_prepare       = rpc_tls_probe_call_prepare,
};

static int tls_probe(struct rpc_clnt *clnt)
{
        struct rpc_message msg = {
                .rpc_proc       = &rpcproc_tls_probe,
        };
        struct rpc_task_setup task_setup_data = {
                .rpc_client     = clnt,
                .rpc_message    = &msg,
                .rpc_op_cred    = &tls_cred,
                .callback_ops   = &rpc_tls_probe_ops,
                .flags          = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
        };
        struct rpc_task *task;
        int status;

        task = rpc_run_task(&task_setup_data);
        if (IS_ERR(task))
                return PTR_ERR(task);
        status = task->tk_status;
        rpc_put_task(task);
        return status;
}

static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
                                   struct rpc_clnt *clnt)
{
        refcount_inc(&tls_auth.au_count);
        return &tls_auth;
}

static void tls_destroy(struct rpc_auth *auth)
{
}

/*
 * tls_lookup_cred - every lookup resolves to the shared tls_cred
 *
 * The caller's identity (@acred) and @flags are ignored; the one static
 * credential is returned through get_rpccred().
 */
static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
                                        struct auth_cred *acred, int flags)
{
        struct rpc_cred *shared = &tls_cred;

        return get_rpccred(shared);
}

/*
 * tls_destroy_cred - nothing to free; tls_cred has static storage.
 */
static void tls_destroy_cred(struct rpc_cred *cred)
{
}

/*
 * tls_match - any caller matches, since only one AUTH_TLS credential
 * exists. Always returns 1.
 */
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
{
        return 1;
}

/*
 * tls_marshal - emit the call's credential and verifier
 * @task: the RPC task being encoded
 * @xdr: call stream positioned where the credential begins
 *
 * Writes four XDR words: credential = (AUTH_TLS, length 0) followed by
 * verifier = (AUTH_NULL, length 0).
 *
 * Returns 0, or -EMSGSIZE when the stream has no room for the four words.
 */
static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
{
        __be32 *words = xdr_reserve_space(xdr, 4 * XDR_UNIT);

        if (!words)
                return -EMSGSIZE;

        words[0] = rpc_auth_tls;        /* credential flavor */
        words[1] = xdr_zero;            /* credential body length */
        words[2] = rpc_auth_null;       /* verifier flavor */
        words[3] = xdr_zero;            /* verifier body length */
        return 0;
}

static int tls_refresh(struct rpc_task *task)
{
        set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
        return 0;
}

/*
 * tls_validate - check the reply verifier for the STARTTLS token
 * @task: the RPC task whose reply is being decoded
 * @xdr: reply stream; the next items read are the verifier's flavor
 *       word and its opaque body
 *
 * Returns:
 *   0                 flavor is AUTH_NULL and the body is exactly
 *                     starttls_len bytes equal to starttls_token
 *   -EIO              the flavor word cannot be read or is not AUTH_NULL
 *   -EPROTONOSUPPORT  the body cannot be decoded, has the wrong length,
 *                     or does not match the token
 */
static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
{
        __be32 *flavor = xdr_inline_decode(xdr, XDR_UNIT);
        void *body;

        if (!flavor || *flavor != rpc_auth_null)
                return -EIO;

        /* memcmp() is reached only once the length is known to match. */
        if (xdr_stream_decode_opaque_inline(xdr, &body, starttls_len) != starttls_len ||
            memcmp(body, starttls_token, starttls_len) != 0)
                return -EPROTONOSUPPORT;

        return 0;
}

/*
 * Auth-flavor operations for RPC_AUTH_TLS. Creation and credential lookup
 * always yield the static tls_auth/tls_cred singletons, and ->ping runs
 * tls_probe().
 *
 * NOTE(review): .au_name is "NULL", presumably copied from the AUTH_NULL
 * ops. Confirm whether anything (logs, tracing, user-visible output)
 * depends on this string before renaming it.
 */
const struct rpc_authops authtls_ops = {
        .owner          = THIS_MODULE,
        .au_name        = "NULL",
        .au_flavor      = RPC_AUTH_TLS,
        .create         = tls_create,
        .lookup_cred    = tls_lookup_cred,
        .ping           = tls_probe,
        .destroy        = tls_destroy,
};

/*
 * The one rpc_auth shared by every AUTH_TLS user. Its reference count
 * starts at 1, tls_create() adds to it, and tls_destroy() frees nothing.
 * The size fields reuse the AUTH_NULL (NUL_*) constants.
 */
static struct rpc_auth tls_auth = {
        .au_flavor      = RPC_AUTH_TLS,
        .au_ops         = &authtls_ops,
        .au_count       = REFCOUNT_INIT(1),
        .au_cslack      = NUL_CALLSLACK,
        .au_rslack      = NUL_REPLYSLACK,
        .au_verfsize    = NUL_REPLYSLACK,
        .au_ralign      = NUL_REPLYSLACK,
};

/*
 * Credential operations for tls_cred. Marshalling and reply validation
 * are AUTH_TLS-specific. Request wrapping and response unwrapping use the
 * generic rpcauth_*_encode/decode helpers.
 */
static const struct rpc_credops tls_credops = {
        .cr_name        = "AUTH_TLS",
        .crmatch        = tls_match,
        .crrefresh      = tls_refresh,
        .crmarshal      = tls_marshal,
        .crvalidate     = tls_validate,
        .crwrap_req     = rpcauth_wrap_req_encode,
        .crunwrap_resp  = rpcauth_unwrap_resp_decode,
        .crdestroy      = tls_destroy_cred,
};

/*
 * The single AUTH_TLS credential returned by tls_lookup_cred(). It starts
 * out already flagged RPCAUTH_CRED_UPTODATE.
 *
 * NOTE(review): cr_count starts at 2, presumably so that balanced
 * get/put pairs can never drop it to zero; confirm against the
 * put_rpccred() path.
 */
static struct rpc_cred tls_cred = {
        .cr_ops         = &tls_credops,
        .cr_auth        = &tls_auth,
        .cr_flags       = 1UL << RPCAUTH_CRED_UPTODATE,
        .cr_count       = REFCOUNT_INIT(2),
        .cr_lru         = LIST_HEAD_INIT(tls_cred.cr_lru),
};