root/net/vmw_vsock/vmci_transport_notify.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * VMware vSockets Driver
 *
 * Copyright (C) 2009-2013 VMware, Inc. All rights reserved.
 */

#include <linux/types.h>
#include <linux/socket.h>
#include <linux/stddef.h>
#include <net/sock.h>

#include "vmci_transport_notify.h"

/* Access field_name in the packet-based notification state stored in the
 * socket's VMCI transport data, i.e. vmci_trans(vsk)->notify.pkt.field_name.
 * The expansion is an lvalue, so it can be both read and assigned.
 */
#define PKT_FIELD(vsk, field_name) (vmci_trans(vsk)->notify.pkt.field_name)

/* Decide whether a READ notification should be sent to a peer that has
 * reported it is blocked waiting for room to write.
 *
 * Returns false if the peer has no waiting-write request outstanding, or if
 * the free space in the consume queue does not exceed the notify limit;
 * returns true otherwise.  When VSOCK_OPTIMIZATION_WAITING_NOTIFY is not
 * defined the function always returns true.
 */
static bool vmci_transport_notify_waiting_write(struct vsock_sock *vsk)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        u64 notify_limit;

        if (!PKT_FIELD(vsk, peer_waiting_write))
                return false;

#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL
        /* A blocked sender is taken as a sign that it outpaces the receiver.
         * To slow it down, the read notification is delayed by shrinking
         * write_notify_window: the notification is held back until the number
         * of bytes used in the queue drops below that window.  The shrink is
         * applied once per detected wait.
         */
        if (!PKT_FIELD(vsk, peer_waiting_write_detected)) {
                PKT_FIELD(vsk, peer_waiting_write_detected) = true;

                if (PKT_FIELD(vsk, write_notify_window) >= PAGE_SIZE)
                        PKT_FIELD(vsk, write_notify_window) -= PAGE_SIZE;
                else
                        PKT_FIELD(vsk, write_notify_window) =
                            PKT_FIELD(vsk, write_notify_min_window);

                /* The window never goes below the configured minimum. */
                if (PKT_FIELD(vsk, write_notify_window) <
                    PKT_FIELD(vsk, write_notify_min_window))
                        PKT_FIELD(vsk, write_notify_window) =
                            PKT_FIELD(vsk, write_notify_min_window);
        }
        notify_limit = vmci_trans(vsk)->consume_size -
                PKT_FIELD(vsk, write_notify_window);
#else
        notify_limit = 0;
#endif

        /* The wait information supplied by the peer is currently ignored;
         * only the free space is compared against the notify limit.  Making
         * this smarter needs no protocol change and stays compatible with
         * endpoints running other versions of this function.
         *
         * With flow control the test reads: notify if
         * free_space > consume_size - write_notify_window.  Since
         * free_space == consume_size - bytes_ready, that is the same as:
         * notify if write_notify_window > bytes_ready.
         */
        if (vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair) <=
            notify_limit)
                return false;

#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL
        /* We are about to notify the peer: clear the detected flag so that
         * the next wait shrinks the window again.
         */
        PKT_FIELD(vsk, peer_waiting_write_detected) = false;
#endif
        return true;
#else
        return true;
#endif
}

/* Decide whether a WROTE notification should be sent to a peer that has
 * reported it is blocked waiting for data to read.
 *
 * Returns true only if the peer has a waiting-read request outstanding and
 * the produce queue holds data.  When VSOCK_OPTIMIZATION_WAITING_NOTIFY is
 * not defined the function always returns true.
 */
static bool vmci_transport_notify_waiting_read(struct vsock_sock *vsk)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        /* The wait information supplied by the peer is currently ignored; we
         * only check whether there is anything at all for it to read.  Making
         * this smarter needs no protocol change and stays compatible with
         * endpoints running other versions of this function.
         */
        return PKT_FIELD(vsk, peer_waiting_read) &&
               vmci_qpair_produce_buf_ready(vmci_trans(vsk)->qpair) > 0;
#else
        return true;
#endif
}

/* Handle a WAITING_READ packet: record that the peer is waiting for data
 * (together with the wait info carried in pkt->u.wait) and, if a notification
 * is warranted right away, send WROTE back.
 *
 * dst and src are used only for the bottom-half send variant.  The
 * peer_waiting_read flag is cleared only when the send reports success
 * (a return value > 0).
 */
static void
vmci_transport_handle_waiting_read(struct sock *sk,
                                   struct vmci_transport_packet *pkt,
                                   bool bottom_half,
                                   struct sockaddr_vm *dst,
                                   struct sockaddr_vm *src)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        struct vsock_sock *vsk = vsock_sk(sk);
        bool delivered;

        PKT_FIELD(vsk, peer_waiting_read) = true;
        memcpy(&PKT_FIELD(vsk, peer_waiting_read_info), &pkt->u.wait,
               sizeof(PKT_FIELD(vsk, peer_waiting_read_info)));

        if (!vmci_transport_notify_waiting_read(vsk))
                return;

        delivered = bottom_half ?
                vmci_transport_send_wrote_bh(dst, src) > 0 :
                vmci_transport_send_wrote(sk) > 0;
        if (delivered)
                PKT_FIELD(vsk, peer_waiting_read) = false;
#endif
}

/* Handle a WAITING_WRITE packet: record that the peer is waiting for room
 * (together with the wait info carried in pkt->u.wait) and, if a notification
 * is warranted right away, send READ back.
 *
 * dst and src are used only for the bottom-half send variant.  The
 * peer_waiting_write flag is cleared only when the send reports success
 * (a return value > 0).
 */
static void
vmci_transport_handle_waiting_write(struct sock *sk,
                                    struct vmci_transport_packet *pkt,
                                    bool bottom_half,
                                    struct sockaddr_vm *dst,
                                    struct sockaddr_vm *src)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        struct vsock_sock *vsk = vsock_sk(sk);
        bool delivered;

        PKT_FIELD(vsk, peer_waiting_write) = true;
        memcpy(&PKT_FIELD(vsk, peer_waiting_write_info), &pkt->u.wait,
               sizeof(PKT_FIELD(vsk, peer_waiting_write_info)));

        if (!vmci_transport_notify_waiting_write(vsk))
                return;

        delivered = bottom_half ?
                vmci_transport_send_read_bh(dst, src) > 0 :
                vmci_transport_send_read(sk) > 0;
        if (delivered)
                PKT_FIELD(vsk, peer_waiting_write) = false;
#endif
}

/* Handle a READ packet from the peer: forget that we announced a pending
 * write (so a later wait announces it again) and invoke the socket's
 * sk_write_space callback.
 *
 * pkt, bottom_half, dst and src are not used.
 */
static void
vmci_transport_handle_read(struct sock *sk,
                           struct vmci_transport_packet *pkt,
                           bool bottom_half,
                           struct sockaddr_vm *dst, struct sockaddr_vm *src)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        struct vsock_sock *vsk = vsock_sk(sk);

        PKT_FIELD(vsk, sent_waiting_write) = false;
#endif

        sk->sk_write_space(sk);
}

/* Tell the peer that we are waiting for room_needed bytes to read.
 *
 * Nothing is sent if a WAITING_READ is already outstanding.  Otherwise the
 * write notify window is widened by one page (capped at consume_size), the
 * queue position we are waiting for is computed, and a WAITING_READ packet is
 * sent.
 *
 * Returns true if a WAITING_READ is outstanding on return (either already
 * sent earlier or sent successfully now), false if the send failed.  When
 * VSOCK_OPTIMIZATION_WAITING_NOTIFY is not defined the function always
 * returns true.
 */
static bool send_waiting_read(struct sock *sk, u64 room_needed)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        struct vmci_transport_waiting_info info;
        struct vsock_sock *vsk = vsock_sk(sk);
        u64 q_tail;
        u64 q_head;
        u64 bytes_to_end;

        if (PKT_FIELD(vsk, sent_waiting_read))
                return true;

        if (PKT_FIELD(vsk, write_notify_window) <
                        vmci_trans(vsk)->consume_size)
                PKT_FIELD(vsk, write_notify_window) =
                    min(PKT_FIELD(vsk, write_notify_window) + PAGE_SIZE,
                        vmci_trans(vsk)->consume_size);

        vmci_qpair_get_consume_indexes(vmci_trans(vsk)->qpair,
                                       &q_tail, &q_head);
        bytes_to_end = vmci_trans(vsk)->consume_size - q_head;

        if (room_needed < bytes_to_end) {
                /* The awaited position lies before the end of the queue. */
                info.generation = PKT_FIELD(vsk, consume_q_generation);
                info.offset = q_head + room_needed;
        } else {
                /* The awaited position is past the end of the queue, so it
                 * belongs to the next generation.
                 */
                info.generation = PKT_FIELD(vsk, consume_q_generation) + 1;
                info.offset = room_needed - bytes_to_end;
        }

        if (vmci_transport_send_waiting_read(sk, &info) <= 0)
                return false;

        PKT_FIELD(vsk, sent_waiting_read) = true;
        return true;
#else
        return true;
#endif
}

/* Tell the peer that we are waiting for room_needed bytes of space to write.
 *
 * Nothing is sent if a WAITING_WRITE is already outstanding.  Otherwise the
 * queue position we are waiting for is computed from the produce indexes and
 * a WAITING_WRITE packet is sent.
 *
 * Returns true if a WAITING_WRITE is outstanding on return (either already
 * sent earlier or sent successfully now), false if the send failed.  When
 * VSOCK_OPTIMIZATION_WAITING_NOTIFY is not defined the function always
 * returns true.
 */
static bool send_waiting_write(struct sock *sk, u64 room_needed)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        struct vmci_transport_waiting_info info;
        struct vsock_sock *vsk = vsock_sk(sk);
        u64 q_tail;
        u64 q_head;
        u64 bytes_to_end;

        if (PKT_FIELD(vsk, sent_waiting_write))
                return true;

        vmci_qpair_get_produce_indexes(vmci_trans(vsk)->qpair,
                                       &q_tail, &q_head);
        bytes_to_end = vmci_trans(vsk)->produce_size - q_tail;

        if (room_needed + 1 < bytes_to_end) {
                /* No wrap: the position is tagged with the previous
                 * generation.
                 */
                info.generation = PKT_FIELD(vsk, produce_q_generation) - 1;
                info.offset = q_tail + room_needed + 1;
        } else {
                /* Wraps around to current generation. */
                info.generation = PKT_FIELD(vsk, produce_q_generation);
                info.offset = room_needed + 1 - bytes_to_end;
        }

        if (vmci_transport_send_waiting_write(sk, &info) <= 0)
                return false;

        PKT_FIELD(vsk, sent_waiting_write) = true;
        return true;
#else
        return true;
#endif
}

/* Send a READ notification to the peer if it is waiting for room to write.
 *
 * The send is retried on failure, up to VMCI_TRANSPORT_MAX_DGRAM_RESENDS
 * attempts, and no (further) attempt is made once RCV_SHUTDOWN is set in
 * vsk->peer_shutdown.
 *
 * Returns the result of the last vmci_transport_send_read() call, or 0 if no
 * send was attempted.
 */
static int vmci_transport_send_read_notification(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        bool sent_read = false;
        unsigned int retries = 0;
        int err = 0;

        if (!vmci_transport_notify_waiting_write(vsk))
                return err;

        /* Notify the peer that we have read, retrying the send on failure up
         * to our maximum value.  XXX For now we just log the failure, but
         * later we should schedule a work item to handle the resend until it
         * succeeds.  That would require keeping track of work items in the
         * vsk and cleaning them up upon socket close.
         */
        while (!(vsk->peer_shutdown & RCV_SHUTDOWN) &&
               !sent_read &&
               retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) {
                err = vmci_transport_send_read(sk);
                if (err >= 0)
                        sent_read = true;

                retries++;
        }

        /* The retry counter also reaches the maximum when the very last
         * attempt succeeds, so it must be combined with sent_read to tell a
         * real failure from a late success.
         */
        if (!sent_read && retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS) {
                pr_err("%p unable to send read notify to peer\n", sk);
                return err;
        }

        /* Kept outside any if/else so the function still parses when the
         * optimization is compiled out.
         */
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        PKT_FIELD(vsk, peer_waiting_write) = false;
#endif

        return err;
}

/* Handle a WROTE packet from the peer: forget that we announced a pending
 * read (so a later wait announces it again) and signal that data is ready on
 * the socket via vsock_data_ready().
 *
 * pkt, bottom_half, dst and src are not used.
 */
static void
vmci_transport_handle_wrote(struct sock *sk,
                            struct vmci_transport_packet *pkt,
                            bool bottom_half,
                            struct sockaddr_vm *dst, struct sockaddr_vm *src)
{
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        struct vsock_sock *vsk;

        vsk = vsock_sk(sk);
        PKT_FIELD(vsk, sent_waiting_read) = false;
#endif
        vsock_data_ready(sk);
}

/* Reset the packet-based notification state of a socket to its defaults:
 * both notify windows at PAGE_SIZE, every waiting/sent flag cleared, both
 * queue generations at zero and both peer wait-info records zeroed.
 */
static void vmci_transport_notify_pkt_socket_init(struct sock *sk)
{
        struct vsock_sock *vsk;

        vsk = vsock_sk(sk);

        /* Flow-control windows. */
        PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE;
        PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE;

        /* What the peer has told us it is waiting for. */
        PKT_FIELD(vsk, peer_waiting_write_detected) = false;
        PKT_FIELD(vsk, peer_waiting_write) = false;
        PKT_FIELD(vsk, peer_waiting_read) = false;
        memset(&PKT_FIELD(vsk, peer_waiting_write_info), 0,
               sizeof(PKT_FIELD(vsk, peer_waiting_write_info)));
        memset(&PKT_FIELD(vsk, peer_waiting_read_info), 0,
               sizeof(PKT_FIELD(vsk, peer_waiting_read_info)));

        /* What we have told the peer we are waiting for. */
        PKT_FIELD(vsk, sent_waiting_write) = false;
        PKT_FIELD(vsk, sent_waiting_read) = false;

        /* Queue wrap-around counters. */
        PKT_FIELD(vsk, consume_q_generation) = 0;
        PKT_FIELD(vsk, produce_q_generation) = 0;
}

/* Socket destruct hook for the packet-based notify ops.  Intentionally a
 * no-op: this function performs no teardown.
 */
static void vmci_transport_notify_pkt_socket_destruct(struct vsock_sock *vsk)
{
}

/* Poll helper for readability.
 *
 * Sets *data_ready_now to true when at least target bytes are available.
 * Otherwise sets it to false and, if the socket is in TCP_ESTABLISHED state,
 * first asks the peer to notify us when there is something to read.
 *
 * Returns 0 on success, or -1 if the waiting-read request could not be sent
 * (in which case *data_ready_now is left untouched).
 */
static int
vmci_transport_notify_pkt_poll_in(struct sock *sk,
                                  size_t target, bool *data_ready_now)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (vsock_stream_has_data(vsk) >= target) {
                *data_ready_now = true;
                return 0;
        }

        /* Not enough data in the queue to read right now.  Ask for a
         * notification when there is something to read.
         */
        if (sk->sk_state == TCP_ESTABLISHED && !send_waiting_read(sk, 1))
                return -1;

        *data_ready_now = false;
        return 0;
}

/* Poll helper for writability.
 *
 * Sets *space_avail_now to true when the produce queue has free space.  When
 * the free space is exactly zero, asks the peer to notify us once it has read
 * and sets *space_avail_now to false.  When the free-space query returns a
 * negative value, *space_avail_now is left untouched.
 *
 * Returns 0 on success, or -1 if the waiting-write request could not be sent.
 * The target argument is not used.
 */
static int
vmci_transport_notify_pkt_poll_out(struct sock *sk,
                                   size_t target, bool *space_avail_now)
{
        s64 free_space = vsock_stream_has_space(vsock_sk(sk));

        if (free_space > 0) {
                *space_avail_now = true;
                return 0;
        }

        if (free_space < 0)
                return 0;

        /* This is a connected socket but we can't currently send data.  We
         * tell the peer we are waiting only when the queue is completely
         * full, because otherwise we end up in an infinite WAITING_WRITE,
         * READ, WAITING_WRITE, READ, etc. loop.  Failing to send the
         * notification is treated as a socket error and passed back through
         * the mask.
         */
        if (!send_waiting_write(sk, 1))
                return -1;

        *space_avail_now = false;
        return 0;
}

/* Prepare the per-call notify data at the start of a receive.
 *
 * Clears the cached queue indexes.  With flow control enabled it also makes
 * sure the minimum write notify window can hold target + 1 bytes; if that
 * forces the current window to grow, data->notify_on_block is set so that the
 * sender is notified before we block.
 *
 * Always returns 0.
 */
static int
vmci_transport_notify_pkt_recv_init(
                        struct sock *sk,
                        size_t target,
                        struct vmci_transport_recv_notify_data *data)
{
        struct vsock_sock *vsk = vsock_sk(sk);

#ifdef VSOCK_OPTIMIZATION_WAITING_NOTIFY
        data->produce_tail = 0;
        data->consume_head = 0;
#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL
        data->notify_on_block = false;

        /* The minimum window is already large enough for this request. */
        if (PKT_FIELD(vsk, write_notify_min_window) >= target + 1)
                return 0;

        PKT_FIELD(vsk, write_notify_min_window) = target + 1;

        if (PKT_FIELD(vsk, write_notify_window) >=
            PKT_FIELD(vsk, write_notify_min_window))
                return 0;

        /* The current window is smaller than the new minimum, so whether the
         * sender needs a notification has to be reevaluated: if fewer bytes
         * are ready than the new window, a notification must go out before
         * we block.
         */
        PKT_FIELD(vsk, write_notify_window) =
            PKT_FIELD(vsk, write_notify_min_window);
        data->notify_on_block = true;
#endif
#endif

        return 0;
}

/* Called before a receive blocks: tell the peer we are waiting for target
 * bytes and, with flow control enabled, send the read notification requested
 * by recv_init (data->notify_on_block).
 *
 * Returns -EHOSTUNREACH if the waiting-read request could not be sent,
 * otherwise the result of the read notification (0 if none was needed).
 * notify_on_block is cleared only when that notification did not fail.
 */
static int
vmci_transport_notify_pkt_recv_pre_block(
                                struct sock *sk,
                                size_t target,
                                struct vmci_transport_recv_notify_data *data)
{
        int err = 0;

        /* Let the peer know we are waiting for data to read. */
        if (!send_waiting_read(sk, target))
                return -EHOSTUNREACH;

#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL
        if (!data->notify_on_block)
                return 0;

        err = vmci_transport_send_read_notification(sk);
        if (err >= 0)
                data->notify_on_block = false;
#endif

        return err;
}

/* Called right before data is dequeued: snapshot the consume queue indexes
 * into data so that the post-dequeue step can detect a wrap-around.
 *
 * The target argument is not used.  Always returns 0.
 */
static int
vmci_transport_notify_pkt_recv_pre_dequeue(
                                struct sock *sk,
                                size_t target,
                                struct vmci_transport_recv_notify_data *data)
{
        struct vsock_sock *vsk;

        vsk = vsock_sk(sk);

        /* The caller is about to consume up to len bytes from the queue.
         * Since the socket is locked, at least the ready bytes should get
         * copied.
         */
#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        vmci_qpair_get_consume_indexes(vmci_trans(vsk)->qpair,
                                       &data->produce_tail,
                                       &data->consume_head);
#endif

        return 0;
}

/* Called after data has been dequeued.
 *
 * If data was actually removed from the queue (data_read), bump the consume
 * queue generation when the read reached the end of the queue, then send a
 * read notification to the peer if one is due.
 *
 * Returns 0 when nothing was read, otherwise the result of
 * vmci_transport_send_read_notification().  The target argument is not used.
 */
static int
vmci_transport_notify_pkt_recv_post_dequeue(
                                struct sock *sk,
                                size_t target,
                                ssize_t copied,
                                bool data_read,
                                struct vmci_transport_recv_notify_data *data)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (!data_read)
                return 0;

#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        /* Track queue generation by detecting a wrap-around.  This is safe
         * because the socket lock is held across the two queue pair
         * operations.
         */
        if (copied >=
                vmci_trans(vsk)->consume_size - data->consume_head)
                PKT_FIELD(vsk, consume_q_generation)++;
#endif

        return vmci_transport_send_read_notification(sk);
}

/* Prepare the per-call notify data at the start of a send by clearing the
 * cached queue indexes.  The sk argument is not used.  Always returns 0.
 */
static int
vmci_transport_notify_pkt_send_init(
                        struct sock *sk,
                        struct vmci_transport_send_notify_data *data)
{
#ifdef VSOCK_OPTIMIZATION_WAITING_NOTIFY
        data->produce_tail = 0;
        data->consume_head = 0;
#endif

        return 0;
}

/* Called before a send blocks: tell the peer we are waiting for room to
 * write.  The data argument is not used.
 *
 * Returns 0 on success, or -EHOSTUNREACH if the request could not be sent.
 */
static int
vmci_transport_notify_pkt_send_pre_block(
                                struct sock *sk,
                                struct vmci_transport_send_notify_data *data)
{
        return send_waiting_write(sk, 1) ? 0 : -EHOSTUNREACH;
}

/* Called right before data is enqueued: snapshot the produce queue indexes
 * into data so that the post-enqueue step can detect a wrap-around.
 *
 * Always returns 0.
 */
static int
vmci_transport_notify_pkt_send_pre_enqueue(
                                struct sock *sk,
                                struct vmci_transport_send_notify_data *data)
{
        struct vsock_sock *vsk;

        vsk = vsock_sk(sk);

#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        vmci_qpair_get_produce_indexes(vmci_trans(vsk)->qpair,
                                       &data->produce_tail,
                                       &data->consume_head);
#endif

        return 0;
}

/* Called after data has been enqueued.
 *
 * Bumps the produce queue generation when the write reached the end of the
 * queue, then sends a WROTE notification if the peer is waiting to read.  The
 * send is retried on failure, up to VMCI_TRANSPORT_MAX_DGRAM_RESENDS
 * attempts, and no (further) attempt is made once RCV_SHUTDOWN is set in
 * vsk->peer_shutdown.
 *
 * Returns the result of the last vmci_transport_send_wrote() call, or 0 if no
 * send was attempted.
 */
static int
vmci_transport_notify_pkt_send_post_enqueue(
                                struct sock *sk,
                                ssize_t written,
                                struct vmci_transport_send_notify_data *data)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        bool sent_wrote = false;
        unsigned int retries = 0;
        int err = 0;

#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        /* Detect a wrap-around to maintain queue generation.  Note that this
         * is safe since we hold the socket lock across the two queue pair
         * operations.
         */
        if (written >= vmci_trans(vsk)->produce_size - data->produce_tail)
                PKT_FIELD(vsk, produce_q_generation)++;

#endif

        if (!vmci_transport_notify_waiting_read(vsk))
                return err;

        /* Notify the peer that we have written, retrying the send on failure
         * up to our maximum value.  See the XXX comment in
         * vmci_transport_send_read_notification() for potential improvements.
         */
        while (!(vsk->peer_shutdown & RCV_SHUTDOWN) &&
               !sent_wrote &&
               retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) {
                err = vmci_transport_send_wrote(sk);
                if (err >= 0)
                        sent_wrote = true;

                retries++;
        }

        /* The retry counter also reaches the maximum when the very last
         * attempt succeeds, so it must be combined with sent_wrote to tell a
         * real failure from a late success.
         */
        if (!sent_wrote && retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS) {
                pr_err("%p unable to send wrote notify to peer\n", sk);
                return err;
        }

#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY)
        PKT_FIELD(vsk, peer_waiting_read) = false;
#endif

        return err;
}

/* Dispatch an incoming notification packet to its handler.
 *
 * WROTE, READ, WAITING_WRITE and WAITING_READ packets are handled here; any
 * other packet type is left alone.  If pkt_processed is non-NULL it is set to
 * tell the caller whether the packet was one of the handled types.
 */
static void
vmci_transport_notify_pkt_handle_pkt(
                        struct sock *sk,
                        struct vmci_transport_packet *pkt,
                        bool bottom_half,
                        struct sockaddr_vm *dst,
                        struct sockaddr_vm *src, bool *pkt_processed)
{
        bool handled = true;

        switch (pkt->type) {
        case VMCI_TRANSPORT_PACKET_TYPE_WROTE:
                vmci_transport_handle_wrote(sk, pkt, bottom_half, dst, src);
                break;
        case VMCI_TRANSPORT_PACKET_TYPE_READ:
                vmci_transport_handle_read(sk, pkt, bottom_half, dst, src);
                break;
        case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE:
                vmci_transport_handle_waiting_write(sk, pkt, bottom_half,
                                                    dst, src);
                break;
        case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ:
                vmci_transport_handle_waiting_read(sk, pkt, bottom_half,
                                                   dst, src);
                break;
        default:
                /* Not a notification packet we know about. */
                handled = false;
                break;
        }

        if (pkt_processed)
                *pkt_processed = handled;
}

/* Request-processing hook: open the write notify window to the full consume
 * queue size and lower the minimum window to that size if it was larger.
 */
static void vmci_transport_notify_pkt_process_request(struct sock *sk)
{
        struct vsock_sock *vsk;

        vsk = vsock_sk(sk);

        PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size;

        /* The minimum window may not exceed the queue size. */
        if (PKT_FIELD(vsk, write_notify_min_window) >
            vmci_trans(vsk)->consume_size)
                PKT_FIELD(vsk, write_notify_min_window) =
                        vmci_trans(vsk)->consume_size;
}

/* Negotiate-processing hook: open the write notify window to the full consume
 * queue size and lower the minimum window to that size if it was larger.
 */
static void vmci_transport_notify_pkt_process_negotiate(struct sock *sk)
{
        struct vsock_sock *vsk;

        vsk = vsock_sk(sk);

        PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size;

        /* The minimum window may not exceed the queue size. */
        if (PKT_FIELD(vsk, write_notify_min_window) >
            vmci_trans(vsk)->consume_size)
                PKT_FIELD(vsk, write_notify_min_window) =
                        vmci_trans(vsk)->consume_size;
}

/* Socket control packet based operations.
 *
 * Operations table that wires the packet-based notification callbacks defined
 * in this file into the vmci_transport_notify_ops interface.  It is the only
 * non-static symbol defined here.
 */
const struct vmci_transport_notify_ops vmci_transport_notify_pkt_ops = {
        /* Socket lifetime. */
        .socket_init = vmci_transport_notify_pkt_socket_init,
        .socket_destruct = vmci_transport_notify_pkt_socket_destruct,
        /* Poll support. */
        .poll_in = vmci_transport_notify_pkt_poll_in,
        .poll_out = vmci_transport_notify_pkt_poll_out,
        /* Incoming notification packets. */
        .handle_notify_pkt = vmci_transport_notify_pkt_handle_pkt,
        /* Receive path hooks. */
        .recv_init = vmci_transport_notify_pkt_recv_init,
        .recv_pre_block = vmci_transport_notify_pkt_recv_pre_block,
        .recv_pre_dequeue = vmci_transport_notify_pkt_recv_pre_dequeue,
        .recv_post_dequeue = vmci_transport_notify_pkt_recv_post_dequeue,
        /* Send path hooks. */
        .send_init = vmci_transport_notify_pkt_send_init,
        .send_pre_block = vmci_transport_notify_pkt_send_pre_block,
        .send_pre_enqueue = vmci_transport_notify_pkt_send_pre_enqueue,
        .send_post_enqueue = vmci_transport_notify_pkt_send_post_enqueue,
        /* Connection setup hooks. */
        .process_request = vmci_transport_notify_pkt_process_request,
        .process_negotiate = vmci_transport_notify_pkt_process_negotiate,
};