root/drivers/infiniband/hw/erdma/erdma_cm.c
// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

/* Authors: Cheng Xu <chengyou@linux.alibaba.com> */
/*          Kai Shen <kaishen@linux.alibaba.com> */
/* Copyright (c) 2020-2022, Alibaba Group. */

/* Authors: Bernard Metzler <bmt@zurich.ibm.com> */
/*          Fredy Neeser */
/*          Greg Joyce <greg@opengridcomputing.com> */
/* Copyright (c) 2008-2019, IBM Corporation */
/* Copyright (c) 2017, Open Grid Computing, Inc. */

#include <linux/workqueue.h>
#include <trace/events/sock.h>

#include "erdma.h"
#include "erdma_cm.h"
#include "erdma_verbs.h"

static struct workqueue_struct *erdma_cm_wq;

static void erdma_cm_llp_state_change(struct sock *sk);
static void erdma_cm_llp_data_ready(struct sock *sk);
static void erdma_cm_llp_error_report(struct sock *sk);

static void erdma_sk_assign_cm_upcalls(struct sock *sk)
{
        write_lock_bh(&sk->sk_callback_lock);
        sk->sk_state_change = erdma_cm_llp_state_change;
        sk->sk_data_ready = erdma_cm_llp_data_ready;
        sk->sk_error_report = erdma_cm_llp_error_report;
        write_unlock_bh(&sk->sk_callback_lock);
}

static void erdma_sk_save_upcalls(struct sock *sk)
{
        struct erdma_cep *cep = sk_to_cep(sk);

        write_lock_bh(&sk->sk_callback_lock);
        cep->sk_state_change = sk->sk_state_change;
        cep->sk_data_ready = sk->sk_data_ready;
        cep->sk_error_report = sk->sk_error_report;
        write_unlock_bh(&sk->sk_callback_lock);
}

static void erdma_sk_restore_upcalls(struct sock *sk, struct erdma_cep *cep)
{
        sk->sk_state_change = cep->sk_state_change;
        sk->sk_data_ready = cep->sk_data_ready;
        sk->sk_error_report = cep->sk_error_report;
        sk->sk_user_data = NULL;
}

static void erdma_socket_disassoc(struct socket *s)
{
        struct sock *sk = s->sk;
        struct erdma_cep *cep;

        if (sk) {
                write_lock_bh(&sk->sk_callback_lock);
                cep = sk_to_cep(sk);
                if (cep) {
                        erdma_sk_restore_upcalls(sk, cep);
                        erdma_cep_put(cep);
                } else {
                        WARN_ON_ONCE(1);
                }
                write_unlock_bh(&sk->sk_callback_lock);
        } else {
                WARN_ON_ONCE(1);
        }
}

static void erdma_cep_socket_assoc(struct erdma_cep *cep, struct socket *s)
{
        cep->sock = s;
        erdma_cep_get(cep);
        s->sk->sk_user_data = cep;

        erdma_sk_save_upcalls(s->sk);
        erdma_sk_assign_cm_upcalls(s->sk);
}

static void erdma_disassoc_listen_cep(struct erdma_cep *cep)
{
        if (cep->listen_cep) {
                erdma_cep_put(cep->listen_cep);
                cep->listen_cep = NULL;
        }
}

static struct erdma_cep *erdma_cep_alloc(struct erdma_dev *dev)
{
        struct erdma_cep *cep = kzalloc_obj(*cep);
        unsigned long flags;

        if (!cep)
                return NULL;

        INIT_LIST_HEAD(&cep->listenq);
        INIT_LIST_HEAD(&cep->devq);
        INIT_LIST_HEAD(&cep->work_freelist);

        kref_init(&cep->ref);
        cep->state = ERDMA_EPSTATE_IDLE;
        init_waitqueue_head(&cep->waitq);
        spin_lock_init(&cep->lock);
        cep->dev = dev;

        spin_lock_irqsave(&dev->lock, flags);
        list_add_tail(&cep->devq, &dev->cep_list);
        spin_unlock_irqrestore(&dev->lock, flags);

        return cep;
}

static void erdma_cm_free_work(struct erdma_cep *cep)
{
        struct list_head *w, *tmp;
        struct erdma_cm_work *work;

        list_for_each_safe(w, tmp, &cep->work_freelist) {
                work = list_entry(w, struct erdma_cm_work, list);
                list_del(&work->list);
                kfree(work);
        }
}

static void erdma_cancel_mpatimer(struct erdma_cep *cep)
{
        spin_lock_bh(&cep->lock);
        if (cep->mpa_timer) {
                if (cancel_delayed_work(&cep->mpa_timer->work)) {
                        erdma_cep_put(cep);
                        kfree(cep->mpa_timer);
                }
                cep->mpa_timer = NULL;
        }
        spin_unlock_bh(&cep->lock);
}

static void erdma_put_work(struct erdma_cm_work *work)
{
        INIT_LIST_HEAD(&work->list);
        spin_lock_bh(&work->cep->lock);
        list_add(&work->list, &work->cep->work_freelist);
        spin_unlock_bh(&work->cep->lock);
}

static void erdma_cep_set_inuse(struct erdma_cep *cep)
{
        unsigned long flags;

        spin_lock_irqsave(&cep->lock, flags);
        while (cep->in_use) {
                spin_unlock_irqrestore(&cep->lock, flags);
                wait_event_interruptible(cep->waitq, !cep->in_use);
                if (signal_pending(current))
                        flush_signals(current);

                spin_lock_irqsave(&cep->lock, flags);
        }

        cep->in_use = 1;
        spin_unlock_irqrestore(&cep->lock, flags);
}

static void erdma_cep_set_free(struct erdma_cep *cep)
{
        unsigned long flags;

        spin_lock_irqsave(&cep->lock, flags);
        cep->in_use = 0;
        spin_unlock_irqrestore(&cep->lock, flags);

        wake_up(&cep->waitq);
}

static void __erdma_cep_dealloc(struct kref *ref)
{
        struct erdma_cep *cep = container_of(ref, struct erdma_cep, ref);
        struct erdma_dev *dev = cep->dev;
        unsigned long flags;

        WARN_ON(cep->listen_cep);

        kfree(cep->private_data);
        kfree(cep->mpa.pdata);
        spin_lock_bh(&cep->lock);
        if (!list_empty(&cep->work_freelist))
                erdma_cm_free_work(cep);
        spin_unlock_bh(&cep->lock);

        spin_lock_irqsave(&dev->lock, flags);
        list_del(&cep->devq);
        spin_unlock_irqrestore(&dev->lock, flags);
        kfree(cep);
}

static struct erdma_cm_work *erdma_get_work(struct erdma_cep *cep)
{
        struct erdma_cm_work *work = NULL;

        spin_lock_bh(&cep->lock);
        if (!list_empty(&cep->work_freelist)) {
                work = list_entry(cep->work_freelist.next, struct erdma_cm_work,
                                  list);
                list_del_init(&work->list);
        }

        spin_unlock_bh(&cep->lock);
        return work;
}

static int erdma_cm_alloc_work(struct erdma_cep *cep, int num)
{
        struct erdma_cm_work *work;

        while (num--) {
                work = kmalloc_obj(*work);
                if (!work) {
                        if (!(list_empty(&cep->work_freelist)))
                                erdma_cm_free_work(cep);
                        return -ENOMEM;
                }
                work->cep = cep;
                INIT_LIST_HEAD(&work->list);
                list_add(&work->list, &cep->work_freelist);
        }

        return 0;
}

static int erdma_cm_upcall(struct erdma_cep *cep, enum iw_cm_event_type reason,
                           int status)
{
        struct iw_cm_event event;
        struct iw_cm_id *cm_id;

        memset(&event, 0, sizeof(event));
        event.status = status;
        event.event = reason;

        if (reason == IW_CM_EVENT_CONNECT_REQUEST) {
                event.provider_data = cep;
                cm_id = cep->listen_cep->cm_id;

                event.ird = cep->dev->attrs.max_ird;
                event.ord = cep->dev->attrs.max_ord;
        } else {
                cm_id = cep->cm_id;
        }

        if (reason == IW_CM_EVENT_CONNECT_REQUEST ||
            reason == IW_CM_EVENT_CONNECT_REPLY) {
                u16 pd_len = be16_to_cpu(cep->mpa.hdr.params.pd_len);

                if (pd_len && cep->mpa.pdata) {
                        event.private_data_len = pd_len;
                        event.private_data = cep->mpa.pdata;
                }

                getname_local(cep->sock, &event.local_addr);
                getname_peer(cep->sock, &event.remote_addr);
        }

        return cm_id->event_handler(cm_id, &event);
}

void erdma_qp_cm_drop(struct erdma_qp *qp)
{
        struct erdma_cep *cep = qp->cep;

        if (!qp->cep)
                return;

        erdma_cep_set_inuse(cep);

        /* already closed. */
        if (cep->state == ERDMA_EPSTATE_CLOSED)
                goto out;

        if (cep->cm_id) {
                switch (cep->state) {
                case ERDMA_EPSTATE_AWAIT_MPAREP:
                        erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY,
                                        -EINVAL);
                        break;
                case ERDMA_EPSTATE_RDMA_MODE:
                        erdma_cm_upcall(cep, IW_CM_EVENT_CLOSE, 0);
                        break;
                case ERDMA_EPSTATE_IDLE:
                case ERDMA_EPSTATE_LISTENING:
                case ERDMA_EPSTATE_CONNECTING:
                case ERDMA_EPSTATE_AWAIT_MPAREQ:
                case ERDMA_EPSTATE_RECVD_MPAREQ:
                case ERDMA_EPSTATE_CLOSED:
                default:
                        break;
                }
                cep->cm_id->rem_ref(cep->cm_id);
                cep->cm_id = NULL;
                erdma_cep_put(cep);
        }
        cep->state = ERDMA_EPSTATE_CLOSED;

        if (cep->sock) {
                erdma_socket_disassoc(cep->sock);
                sock_release(cep->sock);
                cep->sock = NULL;
        }

        if (cep->qp) {
                cep->qp = NULL;
                erdma_qp_put(qp);
        }
out:
        erdma_cep_set_free(cep);
}

void erdma_cep_put(struct erdma_cep *cep)
{
        WARN_ON(kref_read(&cep->ref) < 1);
        kref_put(&cep->ref, __erdma_cep_dealloc);
}

void erdma_cep_get(struct erdma_cep *cep)
{
        kref_get(&cep->ref);
}

static int erdma_send_mpareqrep(struct erdma_cep *cep, const void *pdata,
                                u8 pd_len)
{
        struct socket *s = cep->sock;
        struct mpa_rr *rr = &cep->mpa.hdr;
        struct kvec iov[3];
        struct msghdr msg;
        int iovec_num = 0;
        int ret;
        int mpa_len;

        memset(&msg, 0, sizeof(msg));

        rr->params.pd_len = cpu_to_be16(pd_len);

        iov[iovec_num].iov_base = rr;
        iov[iovec_num].iov_len = sizeof(*rr);
        iovec_num++;
        mpa_len = sizeof(*rr);

        iov[iovec_num].iov_base = &cep->mpa.ext_data;
        iov[iovec_num].iov_len = sizeof(cep->mpa.ext_data);
        iovec_num++;
        mpa_len += sizeof(cep->mpa.ext_data);

        if (pd_len) {
                iov[iovec_num].iov_base = (char *)pdata;
                iov[iovec_num].iov_len = pd_len;
                mpa_len += pd_len;
                iovec_num++;
        }

        ret = kernel_sendmsg(s, &msg, iov, iovec_num, mpa_len);

        return ret < 0 ? ret : 0;
}

static inline int ksock_recv(struct socket *sock, char *buf, size_t size,
                             int flags)
{
        struct kvec iov = { buf, size };
        struct msghdr msg = { .msg_name = NULL, .msg_flags = flags };

        return kernel_recvmsg(sock, &msg, &iov, 1, size, flags);
}

static int __recv_mpa_hdr(struct erdma_cep *cep, int hdr_rcvd, char *hdr,
                          int hdr_size, int *rcvd_out)
{
        struct socket *s = cep->sock;
        int rcvd;

        *rcvd_out = 0;
        if (hdr_rcvd < hdr_size) {
                rcvd = ksock_recv(s, hdr + hdr_rcvd, hdr_size - hdr_rcvd,
                                  MSG_DONTWAIT);
                if (rcvd == -EAGAIN)
                        return -EAGAIN;

                if (rcvd <= 0)
                        return -ECONNABORTED;

                hdr_rcvd += rcvd;
                *rcvd_out = rcvd;

                if (hdr_rcvd < hdr_size)
                        return -EAGAIN;
        }

        return 0;
}

static void __mpa_rr_set_revision(__be16 *bits, u8 rev)
{
        *bits = (*bits & ~MPA_RR_MASK_REVISION) |
                (cpu_to_be16(rev) & MPA_RR_MASK_REVISION);
}

static u8 __mpa_rr_revision(__be16 mpa_rr_bits)
{
        __be16 rev = mpa_rr_bits & MPA_RR_MASK_REVISION;

        return (u8)be16_to_cpu(rev);
}

static void __mpa_ext_set_cc(__be32 *bits, u32 cc)
{
        *bits = (*bits & ~MPA_EXT_FLAG_CC) |
                (cpu_to_be32(cc) & MPA_EXT_FLAG_CC);
}

static u8 __mpa_ext_cc(__be32 mpa_ext_bits)
{
        __be32 cc = mpa_ext_bits & MPA_EXT_FLAG_CC;

        return (u8)be32_to_cpu(cc);
}

/*
 * Receive MPA Request/Reply header.
 *
 * Returns 0 if complete MPA Request/Reply haeder including
 * eventual private data was received. Returns -EAGAIN if
 * header was partially received or negative error code otherwise.
 *
 * Context: May be called in process context only
 */
static int erdma_recv_mpa_rr(struct erdma_cep *cep)
{
        struct mpa_rr *hdr = &cep->mpa.hdr;
        struct socket *s = cep->sock;
        u16 pd_len;
        int rcvd, to_rcv, ret, pd_rcvd;

        if (cep->mpa.bytes_rcvd < sizeof(struct mpa_rr)) {
                ret = __recv_mpa_hdr(cep, cep->mpa.bytes_rcvd,
                                     (char *)&cep->mpa.hdr,
                                     sizeof(struct mpa_rr), &rcvd);
                cep->mpa.bytes_rcvd += rcvd;
                if (ret)
                        return ret;
        }

        if (be16_to_cpu(hdr->params.pd_len) > MPA_MAX_PRIVDATA ||
            __mpa_rr_revision(hdr->params.bits) != MPA_REVISION_EXT_1)
                return -EPROTO;

        if (cep->mpa.bytes_rcvd - sizeof(struct mpa_rr) <
            sizeof(struct erdma_mpa_ext)) {
                ret = __recv_mpa_hdr(
                        cep, cep->mpa.bytes_rcvd - sizeof(struct mpa_rr),
                        (char *)&cep->mpa.ext_data,
                        sizeof(struct erdma_mpa_ext), &rcvd);
                cep->mpa.bytes_rcvd += rcvd;
                if (ret)
                        return ret;
        }

        pd_len = be16_to_cpu(hdr->params.pd_len);
        pd_rcvd = cep->mpa.bytes_rcvd - sizeof(struct mpa_rr) -
                  sizeof(struct erdma_mpa_ext);
        to_rcv = pd_len - pd_rcvd;

        if (!to_rcv) {
                /*
                 * We have received the whole MPA Request/Reply message.
                 * Check against peer protocol violation.
                 */
                u32 word;

                ret = __recv_mpa_hdr(cep, 0, (char *)&word, sizeof(word),
                                     &rcvd);
                if (ret == -EAGAIN && rcvd == 0)
                        return 0;

                if (ret)
                        return ret;

                return -EPROTO;
        }

        /*
         * At this point, MPA header has been fully received, and pd_len != 0.
         * So, begin to receive private data.
         */
        if (!cep->mpa.pdata) {
                cep->mpa.pdata = kmalloc(pd_len + 4, GFP_KERNEL);
                if (!cep->mpa.pdata)
                        return -ENOMEM;
        }

        rcvd = ksock_recv(s, cep->mpa.pdata + pd_rcvd, to_rcv + 4,
                          MSG_DONTWAIT);
        if (rcvd < 0)
                return rcvd;

        if (rcvd > to_rcv)
                return -EPROTO;

        cep->mpa.bytes_rcvd += rcvd;

        if (to_rcv == rcvd)
                return 0;

        return -EAGAIN;
}

/*
 * erdma_proc_mpareq()
 *
 * Read MPA Request from socket and signal new connection to IWCM
 * if success. Caller must hold lock on corresponding listening CEP.
 */
static int erdma_proc_mpareq(struct erdma_cep *cep)
{
        struct mpa_rr *req;
        int ret;

        ret = erdma_recv_mpa_rr(cep);
        if (ret)
                return ret;

        req = &cep->mpa.hdr;

        if (memcmp(req->key, MPA_KEY_REQ, MPA_KEY_SIZE))
                return -EPROTO;

        memcpy(req->key, MPA_KEY_REP, MPA_KEY_SIZE);

        /* Currently does not support marker and crc. */
        if (req->params.bits & MPA_RR_FLAG_MARKERS ||
            req->params.bits & MPA_RR_FLAG_CRC)
                goto reject_conn;

        cep->state = ERDMA_EPSTATE_RECVD_MPAREQ;

        /* Keep reference until IWCM accepts/rejects */
        erdma_cep_get(cep);
        ret = erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REQUEST, 0);
        if (ret)
                erdma_cep_put(cep);

        return ret;

reject_conn:
        req->params.bits &= ~MPA_RR_FLAG_MARKERS;
        req->params.bits |= MPA_RR_FLAG_REJECT;
        req->params.bits &= ~MPA_RR_FLAG_CRC;

        kfree(cep->mpa.pdata);
        cep->mpa.pdata = NULL;
        erdma_send_mpareqrep(cep, NULL, 0);

        return -EOPNOTSUPP;
}

static int erdma_proc_mpareply(struct erdma_cep *cep)
{
        enum erdma_qpa_mask_iwarp to_modify_attrs = 0;
        struct erdma_mod_qp_params_iwarp params;
        struct erdma_qp *qp = cep->qp;
        struct mpa_rr *rep;
        int ret;

        ret = erdma_recv_mpa_rr(cep);
        if (ret)
                goto out_err;

        erdma_cancel_mpatimer(cep);

        rep = &cep->mpa.hdr;

        if (memcmp(rep->key, MPA_KEY_REP, MPA_KEY_SIZE)) {
                ret = -EPROTO;
                goto out_err;
        }

        if (rep->params.bits & MPA_RR_FLAG_REJECT) {
                erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY, -ECONNRESET);
                return -ECONNRESET;
        }

        /* Currently does not support marker and crc. */
        if ((rep->params.bits & MPA_RR_FLAG_MARKERS) ||
            (rep->params.bits & MPA_RR_FLAG_CRC)) {
                erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY, -ECONNREFUSED);
                return -EINVAL;
        }

        memset(&params, 0, sizeof(params));
        params.state = ERDMA_QPS_IWARP_RTS;
        params.irq_size = cep->ird;
        params.orq_size = cep->ord;

        down_write(&qp->state_lock);
        if (qp->attrs.iwarp.state > ERDMA_QPS_IWARP_RTR) {
                ret = -EINVAL;
                up_write(&qp->state_lock);
                goto out_err;
        }

        to_modify_attrs = ERDMA_QPA_IWARP_STATE | ERDMA_QPA_IWARP_LLP_HANDLE |
                          ERDMA_QPA_IWARP_MPA | ERDMA_QPA_IWARP_IRD |
                          ERDMA_QPA_IWARP_ORD;

        params.qp_type = ERDMA_QP_ACTIVE;
        if (__mpa_ext_cc(cep->mpa.ext_data.bits) != qp->attrs.cc) {
                to_modify_attrs |= ERDMA_QPA_IWARP_CC;
                params.cc = COMPROMISE_CC;
        }

        ret = erdma_modify_qp_state_iwarp(qp, &params, to_modify_attrs);

        up_write(&qp->state_lock);

        if (!ret) {
                ret = erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY, 0);
                if (!ret)
                        cep->state = ERDMA_EPSTATE_RDMA_MODE;

                return 0;
        }

out_err:
        if (ret != -EAGAIN)
                erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY, -EINVAL);

        return ret;
}

static void erdma_accept_newconn(struct erdma_cep *cep)
{
        struct socket *s = cep->sock;
        struct socket *new_s = NULL;
        struct erdma_cep *new_cep = NULL;
        int ret = 0;

        if (cep->state != ERDMA_EPSTATE_LISTENING)
                goto error;

        new_cep = erdma_cep_alloc(cep->dev);
        if (!new_cep)
                goto error;

        /*
         * 4: Allocate a sufficient number of work elements
         * to allow concurrent handling of local + peer close
         * events, MPA header processing + MPA timeout.
         */
        if (erdma_cm_alloc_work(new_cep, 4) != 0)
                goto error;

        /*
         * Copy saved socket callbacks from listening CEP
         * and assign new socket with new CEP
         */
        new_cep->sk_state_change = cep->sk_state_change;
        new_cep->sk_data_ready = cep->sk_data_ready;
        new_cep->sk_error_report = cep->sk_error_report;

        ret = kernel_accept(s, &new_s, O_NONBLOCK);
        if (ret != 0)
                goto error;

        new_cep->sock = new_s;
        erdma_cep_get(new_cep);
        new_s->sk->sk_user_data = new_cep;

        tcp_sock_set_nodelay(new_s->sk);
        new_cep->state = ERDMA_EPSTATE_AWAIT_MPAREQ;

        ret = erdma_cm_queue_work(new_cep, ERDMA_CM_WORK_MPATIMEOUT);
        if (ret)
                goto error;

        new_cep->listen_cep = cep;
        erdma_cep_get(cep);

        if (atomic_read(&new_s->sk->sk_rmem_alloc)) {
                /* MPA REQ already queued */
                erdma_cep_set_inuse(new_cep);
                ret = erdma_proc_mpareq(new_cep);
                if (ret != -EAGAIN) {
                        erdma_cep_put(cep);
                        new_cep->listen_cep = NULL;
                        if (ret) {
                                erdma_cep_set_free(new_cep);
                                goto error;
                        }
                }
                erdma_cep_set_free(new_cep);
        }
        return;

error:
        if (new_cep) {
                new_cep->state = ERDMA_EPSTATE_CLOSED;
                erdma_cancel_mpatimer(new_cep);

                erdma_cep_put(new_cep);
        }

        if (new_s) {
                erdma_socket_disassoc(new_s);
                sock_release(new_s);
        }
}

static int erdma_newconn_connected(struct erdma_cep *cep)
{
        int ret = 0;

        cep->mpa.hdr.params.bits = 0;
        __mpa_rr_set_revision(&cep->mpa.hdr.params.bits, MPA_REVISION_EXT_1);

        memcpy(cep->mpa.hdr.key, MPA_KEY_REQ, MPA_KEY_SIZE);
        cep->mpa.ext_data.cookie = cpu_to_be32(cep->qp->attrs.iwarp.cookie);
        __mpa_ext_set_cc(&cep->mpa.ext_data.bits, cep->qp->attrs.cc);

        ret = erdma_send_mpareqrep(cep, cep->private_data, cep->pd_len);
        cep->state = ERDMA_EPSTATE_AWAIT_MPAREP;
        cep->mpa.hdr.params.pd_len = 0;

        if (ret >= 0)
                ret = erdma_cm_queue_work(cep, ERDMA_CM_WORK_MPATIMEOUT);

        return ret;
}

static void erdma_cm_work_handler(struct work_struct *w)
{
        struct erdma_cm_work *work;
        struct erdma_cep *cep;
        int release_cep = 0, ret = 0;

        work = container_of(w, struct erdma_cm_work, work.work);
        cep = work->cep;

        erdma_cep_set_inuse(cep);

        switch (work->type) {
        case ERDMA_CM_WORK_CONNECTED:
                erdma_cancel_mpatimer(cep);
                if (cep->state == ERDMA_EPSTATE_CONNECTING) {
                        ret = erdma_newconn_connected(cep);
                        if (ret) {
                                erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY,
                                                -EIO);
                                release_cep = 1;
                        }
                }
                break;
        case ERDMA_CM_WORK_CONNECTTIMEOUT:
                if (cep->state == ERDMA_EPSTATE_CONNECTING) {
                        cep->mpa_timer = NULL;
                        erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY,
                                        -ETIMEDOUT);
                        release_cep = 1;
                }
                break;
        case ERDMA_CM_WORK_ACCEPT:
                erdma_accept_newconn(cep);
                break;
        case ERDMA_CM_WORK_READ_MPAHDR:
                if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREQ) {
                        if (cep->listen_cep) {
                                erdma_cep_set_inuse(cep->listen_cep);

                                if (cep->listen_cep->state ==
                                    ERDMA_EPSTATE_LISTENING)
                                        ret = erdma_proc_mpareq(cep);
                                else
                                        ret = -EFAULT;

                                erdma_cep_set_free(cep->listen_cep);

                                if (ret != -EAGAIN) {
                                        erdma_cep_put(cep->listen_cep);
                                        cep->listen_cep = NULL;
                                        if (ret)
                                                erdma_cep_put(cep);
                                }
                        }
                } else if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREP) {
                        ret = erdma_proc_mpareply(cep);
                }

                if (ret && ret != -EAGAIN)
                        release_cep = 1;
                break;
        case ERDMA_CM_WORK_CLOSE_LLP:
                if (cep->cm_id)
                        erdma_cm_upcall(cep, IW_CM_EVENT_CLOSE, 0);
                release_cep = 1;
                break;
        case ERDMA_CM_WORK_PEER_CLOSE:
                if (cep->cm_id) {
                        if (cep->state == ERDMA_EPSTATE_CONNECTING ||
                            cep->state == ERDMA_EPSTATE_AWAIT_MPAREP) {
                                /*
                                 * MPA reply not received, but connection drop
                                 */
                                erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY,
                                                -ECONNRESET);
                        } else if (cep->state == ERDMA_EPSTATE_RDMA_MODE) {
                                /*
                                 * NOTE: IW_CM_EVENT_DISCONNECT is given just
                                 *       to transition IWCM into CLOSING.
                                 */
                                erdma_cm_upcall(cep, IW_CM_EVENT_DISCONNECT, 0);
                                erdma_cm_upcall(cep, IW_CM_EVENT_CLOSE, 0);
                        }
                } else if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREQ) {
                        /* Socket close before MPA request received. */
                        erdma_disassoc_listen_cep(cep);
                        erdma_cep_put(cep);
                }
                release_cep = 1;
                break;
        case ERDMA_CM_WORK_MPATIMEOUT:
                cep->mpa_timer = NULL;
                if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREP) {
                        /*
                         * MPA request timed out:
                         * Hide any partially received private data and signal
                         * timeout
                         */
                        cep->mpa.hdr.params.pd_len = 0;

                        if (cep->cm_id)
                                erdma_cm_upcall(cep, IW_CM_EVENT_CONNECT_REPLY,
                                                -ETIMEDOUT);
                        release_cep = 1;
                } else if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREQ) {
                        /* No MPA req received after peer TCP stream setup. */
                        erdma_disassoc_listen_cep(cep);

                        erdma_cep_put(cep);
                        release_cep = 1;
                }
                break;
        default:
                WARN(1, "Undefined CM work type: %d\n", work->type);
        }

        if (release_cep) {
                erdma_cancel_mpatimer(cep);
                cep->state = ERDMA_EPSTATE_CLOSED;
                if (cep->qp) {
                        struct erdma_qp *qp = cep->qp;
                        /*
                         * Serialize a potential race with application
                         * closing the QP and calling erdma_qp_cm_drop()
                         */
                        erdma_qp_get(qp);
                        erdma_cep_set_free(cep);

                        erdma_qp_llp_close(qp);
                        erdma_qp_put(qp);

                        erdma_cep_set_inuse(cep);
                        cep->qp = NULL;
                        erdma_qp_put(qp);
                }

                if (cep->sock) {
                        erdma_socket_disassoc(cep->sock);
                        sock_release(cep->sock);
                        cep->sock = NULL;
                }

                if (cep->cm_id) {
                        cep->cm_id->rem_ref(cep->cm_id);
                        cep->cm_id = NULL;
                        if (cep->state != ERDMA_EPSTATE_LISTENING)
                                erdma_cep_put(cep);
                }
        }
        erdma_cep_set_free(cep);
        erdma_put_work(work);
        erdma_cep_put(cep);
}

int erdma_cm_queue_work(struct erdma_cep *cep, enum erdma_work_type type)
{
        struct erdma_cm_work *work = erdma_get_work(cep);
        unsigned long delay = 0;

        if (!work)
                return -ENOMEM;

        work->type = type;
        work->cep = cep;

        erdma_cep_get(cep);

        INIT_DELAYED_WORK(&work->work, erdma_cm_work_handler);

        if (type == ERDMA_CM_WORK_MPATIMEOUT) {
                cep->mpa_timer = work;

                if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREP)
                        delay = MPAREP_TIMEOUT;
                else
                        delay = MPAREQ_TIMEOUT;
        } else if (type == ERDMA_CM_WORK_CONNECTTIMEOUT) {
                cep->mpa_timer = work;

                delay = CONNECT_TIMEOUT;
        }

        queue_delayed_work(erdma_cm_wq, &work->work, delay);

        return 0;
}

static void erdma_cm_llp_data_ready(struct sock *sk)
{
        struct erdma_cep *cep;

        trace_sk_data_ready(sk);

        read_lock(&sk->sk_callback_lock);

        cep = sk_to_cep(sk);
        if (!cep)
                goto out;

        if (cep->state == ERDMA_EPSTATE_AWAIT_MPAREQ ||
            cep->state == ERDMA_EPSTATE_AWAIT_MPAREP)
                erdma_cm_queue_work(cep, ERDMA_CM_WORK_READ_MPAHDR);

out:
        read_unlock(&sk->sk_callback_lock);
}

static void erdma_cm_llp_error_report(struct sock *sk)
{
        struct erdma_cep *cep = sk_to_cep(sk);

        if (cep)
                cep->sk_error_report(sk);
}

static void erdma_cm_llp_state_change(struct sock *sk)
{
        struct erdma_cep *cep;
        void (*orig_state_change)(struct sock *sk);

        read_lock(&sk->sk_callback_lock);

        cep = sk_to_cep(sk);
        if (!cep) {
                read_unlock(&sk->sk_callback_lock);
                return;
        }
        orig_state_change = cep->sk_state_change;

        switch (sk->sk_state) {
        case TCP_ESTABLISHED:
                if (cep->state == ERDMA_EPSTATE_CONNECTING)
                        erdma_cm_queue_work(cep, ERDMA_CM_WORK_CONNECTED);
                else
                        erdma_cm_queue_work(cep, ERDMA_CM_WORK_ACCEPT);
                break;
        case TCP_CLOSE:
        case TCP_CLOSE_WAIT:
                if (cep->state != ERDMA_EPSTATE_LISTENING)
                        erdma_cm_queue_work(cep, ERDMA_CM_WORK_PEER_CLOSE);
                break;
        default:
                break;
        }
        read_unlock(&sk->sk_callback_lock);
        orig_state_change(sk);
}

static int kernel_bindconnect(struct socket *s, struct sockaddr *laddr,
                              int laddrlen, struct sockaddr *raddr,
                              int raddrlen, int flags)
{
        int ret;

        sock_set_reuseaddr(s->sk);
        ret = s->ops->bind(s, (struct sockaddr_unsized *)laddr, laddrlen);
        if (ret)
                return ret;
        ret = s->ops->connect(s, (struct sockaddr_unsized *)raddr, raddrlen, flags);
        return ret < 0 ? ret : 0;
}

int erdma_connect(struct iw_cm_id *id, struct iw_cm_conn_param *params)
{
        struct erdma_dev *dev = to_edev(id->device);
        struct erdma_qp *qp;
        struct erdma_cep *cep = NULL;
        struct socket *s = NULL;
        struct sockaddr *laddr = (struct sockaddr *)&id->m_local_addr;
        struct sockaddr *raddr = (struct sockaddr *)&id->m_remote_addr;
        u16 pd_len = params->private_data_len;
        int ret;

        if (pd_len > MPA_MAX_PRIVDATA)
                return -EINVAL;

        if (params->ird > dev->attrs.max_ird ||
            params->ord > dev->attrs.max_ord)
                return -EINVAL;

        if (laddr->sa_family != AF_INET || raddr->sa_family != AF_INET)
                return -EAFNOSUPPORT;

        qp = find_qp_by_qpn(dev, params->qpn);
        if (!qp)
                return -ENOENT;
        erdma_qp_get(qp);

        ret = sock_create(AF_INET, SOCK_STREAM, IPPROTO_TCP, &s);
        if (ret < 0)
                goto error_put_qp;

        cep = erdma_cep_alloc(dev);
        if (!cep) {
                ret = -ENOMEM;
                goto error_release_sock;
        }

        erdma_cep_set_inuse(cep);

        /* Associate QP with CEP */
        erdma_cep_get(cep);
        qp->cep = cep;
        cep->qp = qp;

        /* Associate cm_id with CEP */
        id->add_ref(id);
        cep->cm_id = id;

        /*
         * 6: Allocate a sufficient number of work elements
         * to allow concurrent handling of local + peer close
         * events, MPA header processing + MPA timeout, connected event
         * and connect timeout.
         */
        ret = erdma_cm_alloc_work(cep, 6);
        if (ret != 0) {
                ret = -ENOMEM;
                goto error_release_cep;
        }

        cep->ird = params->ird;
        cep->ord = params->ord;
        cep->state = ERDMA_EPSTATE_CONNECTING;

        erdma_cep_socket_assoc(cep, s);

        if (pd_len) {
                cep->pd_len = pd_len;
                cep->private_data = kmalloc(pd_len, GFP_KERNEL);
                if (!cep->private_data) {
                        ret = -ENOMEM;
                        goto error_disassoc;
                }

                memcpy(cep->private_data, params->private_data,
                       params->private_data_len);
        }

        ret = kernel_bindconnect(s, laddr, sizeof(*laddr), raddr,
                                 sizeof(*raddr), O_NONBLOCK);
        if (ret != -EINPROGRESS && ret != 0) {
                goto error_disassoc;
        } else if (ret == 0) {
                ret = erdma_cm_queue_work(cep, ERDMA_CM_WORK_CONNECTED);
                if (ret)
                        goto error_disassoc;
        } else {
                ret = erdma_cm_queue_work(cep, ERDMA_CM_WORK_CONNECTTIMEOUT);
                if (ret)
                        goto error_disassoc;
        }

        erdma_cep_set_free(cep);
        return 0;

error_disassoc:
        kfree(cep->private_data);
        cep->private_data = NULL;
        cep->pd_len = 0;

        erdma_socket_disassoc(s);

error_release_cep:
        /* disassoc with cm_id */
        cep->cm_id = NULL;
        id->rem_ref(id);

        /* disassoc with qp */
        qp->cep = NULL;
        erdma_cep_put(cep);
        cep->qp = NULL;

        cep->state = ERDMA_EPSTATE_CLOSED;

        erdma_cep_set_free(cep);

        /* release the cep. */
        erdma_cep_put(cep);

error_release_sock:
        if (s)
                sock_release(s);
error_put_qp:
        erdma_qp_put(qp);

        return ret;
}

int erdma_accept(struct iw_cm_id *id, struct iw_cm_conn_param *params)
{
        struct erdma_cep *cep = (struct erdma_cep *)id->provider_data;
        struct erdma_mod_qp_params_iwarp mod_qp_params;
        enum erdma_qpa_mask_iwarp to_modify_attrs = 0;
        struct erdma_dev *dev = to_edev(id->device);
        struct erdma_qp *qp;
        int ret;

        erdma_cep_set_inuse(cep);
        erdma_cep_put(cep);

        /* Free lingering inbound private data */
        if (cep->mpa.hdr.params.pd_len) {
                cep->mpa.hdr.params.pd_len = 0;
                kfree(cep->mpa.pdata);
                cep->mpa.pdata = NULL;
        }
        erdma_cancel_mpatimer(cep);

        if (cep->state != ERDMA_EPSTATE_RECVD_MPAREQ) {
                erdma_cep_set_free(cep);
                erdma_cep_put(cep);

                return -ECONNRESET;
        }

        qp = find_qp_by_qpn(dev, params->qpn);
        if (!qp)
                return -ENOENT;
        erdma_qp_get(qp);

        down_write(&qp->state_lock);
        if (qp->attrs.iwarp.state > ERDMA_QPS_IWARP_RTR) {
                ret = -EINVAL;
                up_write(&qp->state_lock);
                goto error;
        }

        if (params->ord > dev->attrs.max_ord ||
            params->ird > dev->attrs.max_ord) {
                ret = -EINVAL;
                up_write(&qp->state_lock);
                goto error;
        }

        if (params->private_data_len > MPA_MAX_PRIVDATA) {
                ret = -EINVAL;
                up_write(&qp->state_lock);
                goto error;
        }

        cep->ird = params->ird;
        cep->ord = params->ord;

        cep->cm_id = id;
        id->add_ref(id);

        memset(&mod_qp_params, 0, sizeof(mod_qp_params));

        mod_qp_params.irq_size = params->ird;
        mod_qp_params.orq_size = params->ord;
        mod_qp_params.state = ERDMA_QPS_IWARP_RTS;

        /* Associate QP with CEP */
        erdma_cep_get(cep);
        qp->cep = cep;
        cep->qp = qp;

        cep->state = ERDMA_EPSTATE_RDMA_MODE;

        mod_qp_params.qp_type = ERDMA_QP_PASSIVE;
        mod_qp_params.pd_len = params->private_data_len;

        to_modify_attrs = ERDMA_QPA_IWARP_STATE | ERDMA_QPA_IWARP_ORD |
                          ERDMA_QPA_IWARP_LLP_HANDLE | ERDMA_QPA_IWARP_IRD |
                          ERDMA_QPA_IWARP_MPA;

        if (qp->attrs.cc != __mpa_ext_cc(cep->mpa.ext_data.bits)) {
                to_modify_attrs |= ERDMA_QPA_IWARP_CC;
                mod_qp_params.cc = COMPROMISE_CC;
        }

        /* move to rts */
        ret = erdma_modify_qp_state_iwarp(qp, &mod_qp_params, to_modify_attrs);

        up_write(&qp->state_lock);

        if (ret)
                goto error;

        cep->mpa.ext_data.bits = 0;
        __mpa_ext_set_cc(&cep->mpa.ext_data.bits, qp->attrs.cc);
        cep->mpa.ext_data.cookie = cpu_to_be32(cep->qp->attrs.iwarp.cookie);

        ret = erdma_send_mpareqrep(cep, params->private_data,
                                   params->private_data_len);
        if (!ret) {
                ret = erdma_cm_upcall(cep, IW_CM_EVENT_ESTABLISHED, 0);
                if (ret)
                        goto error;

                erdma_cep_set_free(cep);

                return 0;
        }

error:
        erdma_socket_disassoc(cep->sock);
        sock_release(cep->sock);
        cep->sock = NULL;

        cep->state = ERDMA_EPSTATE_CLOSED;

        if (cep->cm_id) {
                cep->cm_id->rem_ref(id);
                cep->cm_id = NULL;
        }

        if (qp->cep) {
                erdma_cep_put(cep);
                qp->cep = NULL;
        }

        cep->qp = NULL;
        erdma_qp_put(qp);

        erdma_cep_set_free(cep);
        erdma_cep_put(cep);

        return ret;
}

int erdma_reject(struct iw_cm_id *id, const void *pdata, u8 plen)
{
        struct erdma_cep *cep = (struct erdma_cep *)id->provider_data;

        erdma_cep_set_inuse(cep);
        erdma_cep_put(cep);

        erdma_cancel_mpatimer(cep);

        if (cep->state != ERDMA_EPSTATE_RECVD_MPAREQ) {
                erdma_cep_set_free(cep);
                erdma_cep_put(cep);

                return -ECONNRESET;
        }

        if (__mpa_rr_revision(cep->mpa.hdr.params.bits) == MPA_REVISION_EXT_1) {
                cep->mpa.hdr.params.bits |= MPA_RR_FLAG_REJECT; /* reject */
                erdma_send_mpareqrep(cep, pdata, plen);
        }

        erdma_socket_disassoc(cep->sock);
        sock_release(cep->sock);
        cep->sock = NULL;

        cep->state = ERDMA_EPSTATE_CLOSED;

        erdma_cep_set_free(cep);
        erdma_cep_put(cep);

        return 0;
}

int erdma_create_listen(struct iw_cm_id *id, int backlog)
{
        struct socket *s;
        struct erdma_cep *cep = NULL;
        int ret = 0;
        struct erdma_dev *dev = to_edev(id->device);
        int addr_family = id->local_addr.ss_family;
        struct sockaddr_in *laddr = &to_sockaddr_in(id->local_addr);

        if (addr_family != AF_INET)
                return -EAFNOSUPPORT;

        ret = sock_create(addr_family, SOCK_STREAM, IPPROTO_TCP, &s);
        if (ret < 0)
                return ret;

        sock_set_reuseaddr(s->sk);

        /* For wildcard addr, limit binding to current device only */
        if (ipv4_is_zeronet(laddr->sin_addr.s_addr))
                s->sk->sk_bound_dev_if = dev->netdev->ifindex;

        ret = s->ops->bind(s, (struct sockaddr_unsized *)laddr,
                           sizeof(struct sockaddr_in));
        if (ret)
                goto error;

        cep = erdma_cep_alloc(dev);
        if (!cep) {
                ret = -ENOMEM;
                goto error;
        }
        erdma_cep_socket_assoc(cep, s);

        ret = erdma_cm_alloc_work(cep, backlog);
        if (ret)
                goto error;

        ret = s->ops->listen(s, backlog);
        if (ret)
                goto error;

        cep->cm_id = id;
        id->add_ref(id);

        if (!id->provider_data) {
                id->provider_data =
                        kmalloc_obj(struct list_head);
                if (!id->provider_data) {
                        ret = -ENOMEM;
                        goto error;
                }
                INIT_LIST_HEAD((struct list_head *)id->provider_data);
        }

        list_add_tail(&cep->listenq, (struct list_head *)id->provider_data);
        cep->state = ERDMA_EPSTATE_LISTENING;

        return 0;

error:
        if (cep) {
                erdma_cep_set_inuse(cep);

                if (cep->cm_id) {
                        cep->cm_id->rem_ref(cep->cm_id);
                        cep->cm_id = NULL;
                }
                cep->sock = NULL;
                erdma_socket_disassoc(s);
                cep->state = ERDMA_EPSTATE_CLOSED;

                erdma_cep_set_free(cep);
                erdma_cep_put(cep);
        }
        sock_release(s);

        return ret;
}

static void erdma_drop_listeners(struct iw_cm_id *id)
{
        struct list_head *p, *tmp;
        /*
         * In case of a wildcard rdma_listen on a multi-homed device,
         * a listener's IWCM id is associated with more than one listening CEP.
         */
        list_for_each_safe(p, tmp, (struct list_head *)id->provider_data) {
                struct erdma_cep *cep =
                        list_entry(p, struct erdma_cep, listenq);

                list_del(p);

                erdma_cep_set_inuse(cep);

                if (cep->cm_id) {
                        cep->cm_id->rem_ref(cep->cm_id);
                        cep->cm_id = NULL;
                }
                if (cep->sock) {
                        erdma_socket_disassoc(cep->sock);
                        sock_release(cep->sock);
                        cep->sock = NULL;
                }
                cep->state = ERDMA_EPSTATE_CLOSED;
                erdma_cep_set_free(cep);
                erdma_cep_put(cep);
        }
}

int erdma_destroy_listen(struct iw_cm_id *id)
{
        if (!id->provider_data)
                return 0;

        erdma_drop_listeners(id);
        kfree(id->provider_data);
        id->provider_data = NULL;

        return 0;
}

int erdma_cm_init(void)
{
        erdma_cm_wq = create_singlethread_workqueue("erdma_cm_wq");
        if (!erdma_cm_wq)
                return -ENOMEM;

        return 0;
}

void erdma_cm_exit(void)
{
        if (erdma_cm_wq)
                destroy_workqueue(erdma_cm_wq);
}