root/drivers/infiniband/ulp/rtrs/rtrs-srv.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * RDMA Transport Layer
 *
 * Copyright (c) 2014 - 2018 ProfitBricks GmbH. All rights reserved.
 * Copyright (c) 2018 - 2019 1&1 IONOS Cloud GmbH. All rights reserved.
 * Copyright (c) 2019 - 2020 1&1 IONOS SE. All rights reserved.
 */

#undef pr_fmt
#define pr_fmt(fmt) KBUILD_MODNAME " L" __stringify(__LINE__) ": " fmt

#include <linux/module.h>

#include "rtrs-srv.h"
#include "rtrs-log.h"
#include <rdma/ib_cm.h>
#include <rdma/ib_verbs.h>
#include "rtrs-srv-trace.h"

MODULE_DESCRIPTION("RDMA Transport Server");
MODULE_LICENSE("GPL");

/* Must be power of 2, see mask from mr->page_size in ib_sg_to_pages() */
#define DEFAULT_MAX_CHUNK_SIZE (128 << 10)
#define DEFAULT_SESS_QUEUE_DEPTH 512
#define MAX_HDR_SIZE PAGE_SIZE

static const struct rtrs_rdma_dev_pd_ops dev_pd_ops;
static struct rtrs_rdma_dev_pd dev_pd = {
        .ops = &dev_pd_ops
};
const struct class rtrs_dev_class = {
        .name = "rtrs-server",
};
static struct rtrs_srv_ib_ctx ib_ctx;

static int __read_mostly max_chunk_size = DEFAULT_MAX_CHUNK_SIZE;
static int __read_mostly sess_queue_depth = DEFAULT_SESS_QUEUE_DEPTH;

static bool always_invalidate = true;
module_param(always_invalidate, bool, 0444);
MODULE_PARM_DESC(always_invalidate,
                 "Invalidate memory registration for contiguous memory regions before accessing.");

module_param_named(max_chunk_size, max_chunk_size, int, 0444);
MODULE_PARM_DESC(max_chunk_size,
                 "Max size for each IO request, when change the unit is in byte (default: "
                 __stringify(DEFAULT_MAX_CHUNK_SIZE) "KB)");

module_param_named(sess_queue_depth, sess_queue_depth, int, 0444);
MODULE_PARM_DESC(sess_queue_depth,
                 "Number of buffers for pending I/O requests to allocate per session. Maximum: "
                 __stringify(MAX_SESS_QUEUE_DEPTH) " (default: "
                 __stringify(DEFAULT_SESS_QUEUE_DEPTH) ")");

static cpumask_t cq_affinity_mask = { CPU_BITS_ALL };

static struct workqueue_struct *rtrs_wq;

static inline struct rtrs_srv_con *to_srv_con(struct rtrs_con *c)
{
        return container_of(c, struct rtrs_srv_con, c);
}

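/*
 * Advance the path state machine under state_lock.  Only the transitions
 * CONNECTING -> CONNECTED, CONNECTING/CONNECTED -> CLOSING and
 * CLOSING -> CLOSED are allowed; returns true if the state was changed.
 */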
static bool rtrs_srv_change_state(struct rtrs_srv_path *srv_path,
                                  enum rtrs_srv_state new_state)
{
        enum rtrs_srv_state old_state;
        bool changed = false;
        unsigned long flags;

        spin_lock_irqsave(&srv_path->state_lock, flags);
        old_state = srv_path->state;
        switch (new_state) {
        case RTRS_SRV_CONNECTED:
                if (old_state == RTRS_SRV_CONNECTING)
                        changed = true;
                break;
        case RTRS_SRV_CLOSING:
                if (old_state == RTRS_SRV_CONNECTING ||
                    old_state == RTRS_SRV_CONNECTED)
                        changed = true;
                break;
        case RTRS_SRV_CLOSED:
                if (old_state == RTRS_SRV_CLOSING)
                        changed = true;
                break;
        default:
                break;
        }
        if (changed)
                srv_path->state = new_state;
        spin_unlock_irqrestore(&srv_path->state_lock, flags);

        return changed;
}

static void free_id(struct rtrs_srv_op *id)
{
        if (!id)
                return;
        kfree(id);
}

static void rtrs_srv_free_ops_ids(struct rtrs_srv_path *srv_path)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        int i;

        if (srv_path->ops_ids) {
                for (i = 0; i < srv->queue_depth; i++)
                        free_id(srv_path->ops_ids[i]);
                kfree(srv_path->ops_ids);
                srv_path->ops_ids = NULL;
        }
}

static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc);

static struct ib_cqe io_comp_cqe = {
        .done = rtrs_srv_rdma_done
};

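/*
 * Called when the last inflight reference is dropped after percpu_ref_kill():
 * tears down the percpu ref and wakes up rtrs_srv_close_work(), which waits
 * on complete_done.
 */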
static inline void rtrs_srv_inflight_ref_release(struct percpu_ref *ref)
{
        struct rtrs_srv_path *srv_path = container_of(ref,
                                                      struct rtrs_srv_path,
                                                      ids_inflight_ref);

        percpu_ref_exit(&srv_path->ids_inflight_ref);
        complete(&srv_path->complete_done);
}

static int rtrs_srv_alloc_ops_ids(struct rtrs_srv_path *srv_path)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_srv_op *id;
        int i, ret;

        srv_path->ops_ids = kzalloc_objs(*srv_path->ops_ids, srv->queue_depth);
        if (!srv_path->ops_ids)
                goto err;

        for (i = 0; i < srv->queue_depth; ++i) {
                id = kzalloc_obj(*id);
                if (!id)
                        goto err;

                srv_path->ops_ids[i] = id;
        }

        ret = percpu_ref_init(&srv_path->ids_inflight_ref,
                              rtrs_srv_inflight_ref_release, 0, GFP_KERNEL);
        if (ret) {
                pr_err("Percpu reference init failed\n");
                goto err;
        }
        init_completion(&srv_path->complete_done);

        return 0;

err:
        rtrs_srv_free_ops_ids(srv_path);
        return -ENOMEM;
}

static inline void rtrs_srv_get_ops_ids(struct rtrs_srv_path *srv_path)
{
        percpu_ref_get(&srv_path->ids_inflight_ref);
}

static inline void rtrs_srv_put_ops_ids(struct rtrs_srv_path *srv_path)
{
        percpu_ref_put(&srv_path->ids_inflight_ref);
}

static void rtrs_srv_reg_mr_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = to_srv_con(wc->qp->qp_context);
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);

        if (wc->status != IB_WC_SUCCESS) {
                rtrs_err_rl(s, "REG MR failed: %s\n",
                          ib_wc_status_msg(wc->status));
                close_path(srv_path);
                return;
        }
}

static struct ib_cqe local_reg_cqe = {
        .done = rtrs_srv_reg_mr_done
};

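/*
 * Send the result of a READ request back to the client: a single RDMA WRITE
 * into the client-advertised buffer, chained with an immediate-data message
 * and, depending on need_inval/always_invalidate, an rkey invalidation
 * and/or MR re-registration.
 */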
static int rdma_write_sg(struct rtrs_srv_op *id)
{
        struct rtrs_path *s = id->con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        dma_addr_t dma_addr = srv_path->dma_addr[id->msg_id];
        struct rtrs_srv_mr *srv_mr;
        struct ib_send_wr inv_wr;
        struct ib_rdma_wr imm_wr;
        struct ib_rdma_wr *wr = NULL;
        enum ib_send_flags flags;
        size_t sg_cnt;
        int err, offset;
        bool need_inval;
        struct ib_reg_wr rwr;
        struct ib_sge *plist;
        struct ib_sge list;

        sg_cnt = le16_to_cpu(id->rd_msg->sg_cnt);
        need_inval = le16_to_cpu(id->rd_msg->flags) & RTRS_MSG_NEED_INVAL_F;
        if (sg_cnt != 1)
                return -EINVAL;

        offset = 0;

        wr              = &id->tx_wr;
        plist           = &id->tx_sg;
        plist->addr     = dma_addr + offset;
        plist->length   = le32_to_cpu(id->rd_msg->desc[0].len);

        /*
         * The WR will fail with a length error
         * if this is 0
         */
        if (plist->length == 0) {
                rtrs_err(s, "Invalid RDMA-Write sg list length 0\n");
                return -EINVAL;
        }

        plist->lkey = srv_path->s.dev->ib_pd->local_dma_lkey;
        offset += plist->length;

        wr->wr.sg_list  = plist;
        wr->wr.num_sge  = 1;
        wr->remote_addr = le64_to_cpu(id->rd_msg->desc[0].addr);
        wr->rkey        = le32_to_cpu(id->rd_msg->desc[0].key);

        wr->wr.opcode = IB_WR_RDMA_WRITE;
        wr->wr.wr_cqe   = &io_comp_cqe;
        wr->wr.ex.imm_data = 0;
        wr->wr.send_flags  = 0;

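        /*
         * Chain the work requests in submission order:
         * RDMA_WRITE [-> REG_MR] [-> SEND_WITH_INV] -> IMM message.
         */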
        if (need_inval && always_invalidate) {
                wr->wr.next = &rwr.wr;
                rwr.wr.next = &inv_wr;
                inv_wr.next = &imm_wr.wr;
        } else if (always_invalidate) {
                wr->wr.next = &rwr.wr;
                rwr.wr.next = &imm_wr.wr;
        } else if (need_inval) {
                wr->wr.next = &inv_wr;
                inv_wr.next = &imm_wr.wr;
        } else {
                wr->wr.next = &imm_wr.wr;
        }
        /*
         * From time to time we have to post signaled sends,
         * or send queue will fill up and only QP reset can help.
         */
        flags = (atomic_inc_return(&id->con->c.wr_cnt) % s->signal_interval) ?
                0 : IB_SEND_SIGNALED;

        if (need_inval) {
                inv_wr.sg_list = NULL;
                inv_wr.num_sge = 0;
                inv_wr.opcode = IB_WR_SEND_WITH_INV;
                inv_wr.wr_cqe   = &io_comp_cqe;
                inv_wr.send_flags = 0;
                inv_wr.ex.invalidate_rkey = wr->rkey;
        }

        imm_wr.wr.next = NULL;
        if (always_invalidate) {
                struct rtrs_msg_rkey_rsp *msg;

                srv_mr = &srv_path->mrs[id->msg_id];
                rwr.wr.opcode = IB_WR_REG_MR;
                rwr.wr.wr_cqe = &local_reg_cqe;
                rwr.wr.num_sge = 0;
                rwr.mr = srv_mr->mr;
                rwr.wr.send_flags = 0;
                rwr.key = srv_mr->mr->rkey;
                rwr.access = (IB_ACCESS_LOCAL_WRITE |
                              IB_ACCESS_REMOTE_WRITE);
                msg = srv_mr->iu->buf;
                msg->buf_id = cpu_to_le16(id->msg_id);
                msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
                msg->rkey = cpu_to_le32(srv_mr->mr->rkey);

                list.addr   = srv_mr->iu->dma_addr;
                list.length = sizeof(*msg);
                list.lkey   = srv_path->s.dev->ib_pd->local_dma_lkey;
                imm_wr.wr.sg_list = &list;
                imm_wr.wr.num_sge = 1;
                imm_wr.wr.opcode = IB_WR_SEND_WITH_IMM;
                ib_dma_sync_single_for_device(srv_path->s.dev->ib_dev,
                                              srv_mr->iu->dma_addr,
                                              srv_mr->iu->size, DMA_TO_DEVICE);
        } else {
                imm_wr.wr.sg_list = NULL;
                imm_wr.wr.num_sge = 0;
                imm_wr.wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
        }
        imm_wr.wr.send_flags = flags;
        imm_wr.wr.ex.imm_data = cpu_to_be32(rtrs_to_io_rsp_imm(id->msg_id,
                                                             0, need_inval));

        imm_wr.wr.wr_cqe   = &io_comp_cqe;
        ib_dma_sync_single_for_device(srv_path->s.dev->ib_dev, dma_addr,
                                      offset, DMA_BIDIRECTIONAL);

        err = ib_post_send(id->con->c.qp, &id->tx_wr.wr, NULL);
        if (err)
                rtrs_err(s,
                          "Posting RDMA-Write-Request to QP failed, err: %pe\n",
                          ERR_PTR(err));

        return err;
}

/**
 * send_io_resp_imm() - respond to client with empty IMM on failed READ/WRITE
 *                      requests or on successful WRITE request.
 * @con:        the connection to send back result
 * @id:         the id associated with the IO
 * @errno:      the error number of the IO.
 *
 * Return 0 on success, errno otherwise.
 */
static int send_io_resp_imm(struct rtrs_srv_con *con, struct rtrs_srv_op *id,
                            int errno)
{
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct ib_send_wr inv_wr, *wr = NULL;
        struct ib_rdma_wr imm_wr;
        struct ib_reg_wr rwr;
        struct rtrs_srv_mr *srv_mr;
        bool need_inval = false;
        enum ib_send_flags flags;
        struct ib_sge list;
        u32 imm;
        int err;

        if (id->dir == READ) {
                struct rtrs_msg_rdma_read *rd_msg = id->rd_msg;
                size_t sg_cnt;

                need_inval = le16_to_cpu(rd_msg->flags) &
                                RTRS_MSG_NEED_INVAL_F;
                sg_cnt = le16_to_cpu(rd_msg->sg_cnt);

                if (need_inval) {
                        if (sg_cnt) {
                                inv_wr.wr_cqe   = &io_comp_cqe;
                                inv_wr.sg_list = NULL;
                                inv_wr.num_sge = 0;
                                inv_wr.opcode = IB_WR_SEND_WITH_INV;
                                inv_wr.send_flags = 0;
                                /* Only one key is actually used */
                                inv_wr.ex.invalidate_rkey =
                                        le32_to_cpu(rd_msg->desc[0].key);
                        } else {
                                WARN_ON_ONCE(1);
                                need_inval = false;
                        }
                }
        }

        trace_send_io_resp_imm(id, need_inval, always_invalidate, errno);

        if (need_inval && always_invalidate) {
                wr = &inv_wr;
                inv_wr.next = &rwr.wr;
                rwr.wr.next = &imm_wr.wr;
        } else if (always_invalidate) {
                wr = &rwr.wr;
                rwr.wr.next = &imm_wr.wr;
        } else if (need_inval) {
                wr = &inv_wr;
                inv_wr.next = &imm_wr.wr;
        } else {
                wr = &imm_wr.wr;
        }
        /*
         * From time to time we have to post signaled sends,
         * or send queue will fill up and only QP reset can help.
         */
        flags = (atomic_inc_return(&con->c.wr_cnt) % s->signal_interval) ?
                0 : IB_SEND_SIGNALED;
        imm = rtrs_to_io_rsp_imm(id->msg_id, errno, need_inval);
        imm_wr.wr.next = NULL;
        if (always_invalidate) {
                struct rtrs_msg_rkey_rsp *msg;

                srv_mr = &srv_path->mrs[id->msg_id];
                rwr.wr.next = &imm_wr.wr;
                rwr.wr.opcode = IB_WR_REG_MR;
                rwr.wr.wr_cqe = &local_reg_cqe;
                rwr.wr.num_sge = 0;
                rwr.wr.send_flags = 0;
                rwr.mr = srv_mr->mr;
                rwr.key = srv_mr->mr->rkey;
                rwr.access = (IB_ACCESS_LOCAL_WRITE |
                              IB_ACCESS_REMOTE_WRITE);
                msg = srv_mr->iu->buf;
                msg->buf_id = cpu_to_le16(id->msg_id);
                msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
                msg->rkey = cpu_to_le32(srv_mr->mr->rkey);

                list.addr   = srv_mr->iu->dma_addr;
                list.length = sizeof(*msg);
                list.lkey   = srv_path->s.dev->ib_pd->local_dma_lkey;
                imm_wr.wr.sg_list = &list;
                imm_wr.wr.num_sge = 1;
                imm_wr.wr.opcode = IB_WR_SEND_WITH_IMM;
                ib_dma_sync_single_for_device(srv_path->s.dev->ib_dev,
                                              srv_mr->iu->dma_addr,
                                              srv_mr->iu->size, DMA_TO_DEVICE);
        } else {
                imm_wr.wr.sg_list = NULL;
                imm_wr.wr.num_sge = 0;
                imm_wr.wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
        }
        imm_wr.wr.send_flags = flags;
        imm_wr.wr.wr_cqe   = &io_comp_cqe;

        imm_wr.wr.ex.imm_data = cpu_to_be32(imm);

        err = ib_post_send(id->con->c.qp, wr, NULL);
        if (err)
                rtrs_err_rl(s, "Posting RDMA-Reply to QP failed, err: %pe\n",
                            ERR_PTR(err));

        return err;
}

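/*
 * Move the path to CLOSING and queue the teardown work.  Safe to call
 * multiple times: only the caller that actually changes the state queues
 * close_work.
 */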
void close_path(struct rtrs_srv_path *srv_path)
{
        if (rtrs_srv_change_state(srv_path, RTRS_SRV_CLOSING))
                queue_work(rtrs_wq, &srv_path->close_work);
        WARN_ON(srv_path->state != RTRS_SRV_CLOSING);
}

static inline const char *rtrs_srv_state_str(enum rtrs_srv_state state)
{
        switch (state) {
        case RTRS_SRV_CONNECTING:
                return "RTRS_SRV_CONNECTING";
        case RTRS_SRV_CONNECTED:
                return "RTRS_SRV_CONNECTED";
        case RTRS_SRV_CLOSING:
                return "RTRS_SRV_CLOSING";
        case RTRS_SRV_CLOSED:
                return "RTRS_SRV_CLOSED";
        default:
                return "UNKNOWN";
        }
}

/**
 * rtrs_srv_resp_rdma() - Finish an RDMA request
 *
 * @id:         Internal RTRS operation identifier
 * @status:     Response Code sent to the other side for this operation.
 *              0 = success, <0 error
 * Context: any
 *
 * Finish an RDMA operation. A message is sent to the client and the
 * corresponding memory areas will be released.
 */
bool rtrs_srv_resp_rdma(struct rtrs_srv_op *id, int status)
{
        struct rtrs_srv_path *srv_path;
        struct rtrs_srv_con *con;
        struct rtrs_path *s;
        int err;

        if (WARN_ON(!id))
                return true;

        con = id->con;
        s = con->c.path;
        srv_path = to_srv_path(s);

        id->status = status;

        if (srv_path->state != RTRS_SRV_CONNECTED) {
                rtrs_err_rl(s,
                            "Sending I/O response failed,  server path %s is disconnected, path state %s\n",
                            kobject_name(&srv_path->kobj),
                            rtrs_srv_state_str(srv_path->state));
                goto out;
        }
        if (always_invalidate) {
                struct rtrs_srv_mr *mr = &srv_path->mrs[id->msg_id];

                ib_update_fast_reg_key(mr->mr, ib_inc_rkey(mr->mr->rkey));
        }
        if (atomic_sub_return(1, &con->c.sq_wr_avail) < 0) {
                rtrs_err(s, "IB send queue full: srv_path=%s cid=%d\n",
                         kobject_name(&srv_path->kobj),
                         con->c.cid);
                atomic_add(1, &con->c.sq_wr_avail);
                spin_lock(&con->rsp_wr_wait_lock);
                list_add_tail(&id->wait_list, &con->rsp_wr_wait_list);
                spin_unlock(&con->rsp_wr_wait_lock);
                return false;
        }

        if (status || id->dir == WRITE || !id->rd_msg->sg_cnt)
                err = send_io_resp_imm(con, id, status);
        else
                err = rdma_write_sg(id);

        if (err) {
                rtrs_err_rl(s, "IO response failed: %pe: srv_path=%s\n",
                            ERR_PTR(err), kobject_name(&srv_path->kobj));
                close_path(srv_path);
        }
out:
        rtrs_srv_put_ops_ids(srv_path);
        return true;
}
EXPORT_SYMBOL(rtrs_srv_resp_rdma);

/**
 * rtrs_srv_set_sess_priv() - Set private pointer in rtrs_srv.
 * @srv:        Session pointer
 * @priv:       The private pointer that is associated with the session.
 */
void rtrs_srv_set_sess_priv(struct rtrs_srv_sess *srv, void *priv)
{
        srv->priv = priv;
}
EXPORT_SYMBOL(rtrs_srv_set_sess_priv);

static void unmap_cont_bufs(struct rtrs_srv_path *srv_path)
{
        int i;

        for (i = 0; i < srv_path->mrs_num; i++) {
                struct rtrs_srv_mr *srv_mr;

                srv_mr = &srv_path->mrs[i];

                if (always_invalidate)
                        rtrs_iu_free(srv_mr->iu, srv_path->s.dev->ib_dev, 1);

                ib_dereg_mr(srv_mr->mr);
                ib_dma_unmap_sg(srv_path->s.dev->ib_dev, srv_mr->sgt.sgl,
                                srv_mr->sgt.nents, DMA_BIDIRECTIONAL);
                sg_free_table(&srv_mr->sgt);
        }
        kfree(srv_path->mrs);
}

static int map_cont_bufs(struct rtrs_srv_path *srv_path)
{
        struct ib_device *ib_dev = srv_path->s.dev->ib_dev;
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_path *ss = &srv_path->s;
        int i, err, mrs_num;
        unsigned int chunk_bits;
        enum ib_mr_type mr_type;
        int chunks_per_mr = 1;
        struct sg_table *sgt;
        struct ib_mr *mr;

        /*
         * Here we map queue_depth chunks to MRs.  First we have to
         * figure out how many chunks we can map per MR.
         */
        if (always_invalidate) {
                /*
                 * In order to invalidate each chunk of memory separately,
                 * we need one memory region per chunk.
                 */
                mrs_num = srv->queue_depth;
        } else {
                chunks_per_mr =
                        srv_path->s.dev->ib_dev->attrs.max_fast_reg_page_list_len;
                mrs_num = DIV_ROUND_UP(srv->queue_depth, chunks_per_mr);
                chunks_per_mr = DIV_ROUND_UP(srv->queue_depth, mrs_num);
        }

        srv_path->mrs = kzalloc_objs(*srv_path->mrs, mrs_num);
        if (!srv_path->mrs)
                return -ENOMEM;

        for (srv_path->mrs_num = 0; srv_path->mrs_num < mrs_num;
             srv_path->mrs_num++) {
                struct rtrs_srv_mr *srv_mr = &srv_path->mrs[srv_path->mrs_num];
                struct scatterlist *s;
                int nr, nr_sgt, chunks, ind;

                sgt = &srv_mr->sgt;
                chunks = chunks_per_mr * srv_path->mrs_num;
                if (!always_invalidate)
                        chunks_per_mr = min_t(int, chunks_per_mr,
                                              srv->queue_depth - chunks);

                err = sg_alloc_table(sgt, chunks_per_mr, GFP_KERNEL);
                if (err)
                        goto err;

                for_each_sg(sgt->sgl, s, chunks_per_mr, i)
                        sg_set_page(s, srv->chunks[chunks + i],
                                    max_chunk_size, 0);

                nr_sgt = ib_dma_map_sg(srv_path->s.dev->ib_dev, sgt->sgl,
                                   sgt->nents, DMA_BIDIRECTIONAL);
                if (!nr_sgt) {
                        err = -EINVAL;
                        goto free_sg;
                }

                if (ib_dev->attrs.kernel_cap_flags & IBK_SG_GAPS_REG)
                        mr_type = IB_MR_TYPE_SG_GAPS;
                else
                        mr_type = IB_MR_TYPE_MEM_REG;

                mr = ib_alloc_mr(srv_path->s.dev->ib_pd, mr_type, nr_sgt);
                if (IS_ERR(mr)) {
                        err = PTR_ERR(mr);
                        goto unmap_sg;
                }
                nr = ib_map_mr_sg(mr, sgt->sgl, nr_sgt,
                                  NULL, max_chunk_size);
                if (nr < nr_sgt) {
                        err = nr < 0 ? nr : -EINVAL;
                        goto dereg_mr;
                }

                if (always_invalidate) {
                        srv_mr->iu = rtrs_iu_alloc(1,
                                        sizeof(struct rtrs_msg_rkey_rsp),
                                        GFP_KERNEL, srv_path->s.dev->ib_dev,
                                        DMA_TO_DEVICE, rtrs_srv_rdma_done);
                        if (!srv_mr->iu) {
                                err = -ENOMEM;
                                rtrs_err(ss, "rtrs_iu_alloc(), err: %pe\n", ERR_PTR(err));
                                goto dereg_mr;
                        }
                }

                /*
                 * Cache DMA addresses by traversing sg entries.  If
                 * regions were merged, an inner loop is required to
                 * populate the DMA address array by traversing larger
                 * regions.
                 */
                ind = chunks;
                for_each_sg(sgt->sgl, s, nr_sgt, i) {
                        unsigned int dma_len = sg_dma_len(s);
                        u64 dma_addr = sg_dma_address(s);
                        u64 dma_addr_end = dma_addr + dma_len;

                        do {
                                srv_path->dma_addr[ind++] = dma_addr;
                                dma_addr += max_chunk_size;
                        } while (dma_addr < dma_addr_end);
                }

                ib_update_fast_reg_key(mr, ib_inc_rkey(mr->rkey));
                srv_mr->mr = mr;
        }

        chunk_bits = ilog2(srv->queue_depth - 1) + 1;
        srv_path->mem_bits = (MAX_IMM_PAYL_BITS - chunk_bits);

        return 0;

dereg_mr:
        ib_dereg_mr(mr);
unmap_sg:
        ib_dma_unmap_sg(srv_path->s.dev->ib_dev, sgt->sgl,
                        sgt->nents, DMA_BIDIRECTIONAL);
free_sg:
        sg_free_table(sgt);
err:
        unmap_cont_bufs(srv_path);

        return err;
}

static void rtrs_srv_hb_err_handler(struct rtrs_con *c)
{
        struct rtrs_srv_con *con = container_of(c, typeof(*con), c);
        struct rtrs_srv_path *srv_path = to_srv_path(con->c.path);

        rtrs_err(con->c.path, "HB err handler for path=%s\n", kobject_name(&srv_path->kobj));
        close_path(srv_path);
}

static void rtrs_srv_init_hb(struct rtrs_srv_path *srv_path)
{
        rtrs_init_hb(&srv_path->s, &io_comp_cqe,
                      RTRS_HB_INTERVAL_MS,
                      RTRS_HB_MISSED_MAX,
                      rtrs_srv_hb_err_handler,
                      rtrs_wq);
}

static void rtrs_srv_start_hb(struct rtrs_srv_path *srv_path)
{
        rtrs_start_hb(&srv_path->s);
}

static void rtrs_srv_stop_hb(struct rtrs_srv_path *srv_path)
{
        rtrs_stop_hb(&srv_path->s);
}

static void rtrs_srv_info_rsp_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = to_srv_con(wc->qp->qp_context);
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_iu *iu;

        iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
        rtrs_iu_free(iu, srv_path->s.dev->ib_dev, 1);

        if (wc->status != IB_WC_SUCCESS) {
                rtrs_err(s, "Sess info response send failed: %s\n",
                          ib_wc_status_msg(wc->status));
                close_path(srv_path);
                return;
        }
        WARN_ON(wc->opcode != IB_WC_SEND);
}

static int rtrs_srv_path_up(struct rtrs_srv_path *srv_path)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;
        int up, ret = 0;

        mutex_lock(&srv->paths_ev_mutex);
        up = ++srv->paths_up;
        if (up == 1)
                ret = ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_CONNECTED, NULL);
        mutex_unlock(&srv->paths_ev_mutex);

        /* Mark session as established */
        if (!ret)
                srv_path->established = true;

        return ret;
}

static void rtrs_srv_path_down(struct rtrs_srv_path *srv_path)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;

        if (!srv_path->established)
                return;

        srv_path->established = false;
        mutex_lock(&srv->paths_ev_mutex);
        WARN_ON(!srv->paths_up);
        if (--srv->paths_up == 0)
                ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_DISCONNECTED, srv->priv);
        mutex_unlock(&srv->paths_ev_mutex);
}

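/*
 * Check whether @pathname is already taken by a session with a different
 * uuid: a client re-adding a path to its own session (same uuid) is fine.
 */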
static bool exist_pathname(struct rtrs_srv_ctx *ctx,
                           const char *pathname, const uuid_t *path_uuid)
{
        struct rtrs_srv_sess *srv;
        struct rtrs_srv_path *srv_path;
        bool found = false;

        mutex_lock(&ctx->srv_mutex);
        list_for_each_entry(srv, &ctx->srv_list, ctx_list) {
                mutex_lock(&srv->paths_mutex);

                /*
                 * Skip the client's own session (same uuid): the client may
                 * legitimately re-add a path under this name.
                 */
                if (uuid_equal(&srv->paths_uuid, path_uuid)) {
                        mutex_unlock(&srv->paths_mutex);
                        continue;
                }

                list_for_each_entry(srv_path, &srv->paths_list, s.entry) {
                        if (strlen(srv_path->s.sessname) == strlen(pathname) &&
                            !strcmp(srv_path->s.sessname, pathname)) {
                                found = true;
                                break;
                        }
                }
                mutex_unlock(&srv->paths_mutex);
                if (found)
                        break;
        }
        mutex_unlock(&ctx->srv_mutex);
        return found;
}

static int post_recv_path(struct rtrs_srv_path *srv_path);
static int rtrs_rdma_do_reject(struct rdma_cm_id *cm_id, int errno);

static int process_info_req(struct rtrs_srv_con *con,
                            struct rtrs_msg_info_req *msg)
{
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct ib_send_wr *reg_wr = NULL;
        struct rtrs_msg_info_rsp *rsp;
        struct rtrs_iu *tx_iu;
        struct ib_reg_wr *rwr;
        int mri, err;
        size_t tx_sz;

        err = post_recv_path(srv_path);
        if (err) {
                rtrs_err(s, "post_recv_path(), err: %pe\n", ERR_PTR(err));
                return err;
        }

        if (strchr(msg->pathname, '/') || strchr(msg->pathname, '.')) {
                rtrs_err(s, "pathname cannot contain / and .\n");
                return -EINVAL;
        }

        if (exist_pathname(srv_path->srv->ctx,
                           msg->pathname, &srv_path->srv->paths_uuid)) {
                rtrs_err(s, "pathname is duplicated: %s\n", msg->pathname);
                return -EPERM;
        }
        strscpy(srv_path->s.sessname, msg->pathname,
                sizeof(srv_path->s.sessname));

        rwr = kzalloc_objs(*rwr, srv_path->mrs_num);
        if (!rwr)
                return -ENOMEM;

        tx_sz  = sizeof(*rsp);
        tx_sz += sizeof(rsp->desc[0]) * srv_path->mrs_num;
        tx_iu = rtrs_iu_alloc(1, tx_sz, GFP_KERNEL, srv_path->s.dev->ib_dev,
                               DMA_TO_DEVICE, rtrs_srv_info_rsp_done);
        if (!tx_iu) {
                err = -ENOMEM;
                goto rwr_free;
        }

        rsp = tx_iu->buf;
        rsp->type = cpu_to_le16(RTRS_MSG_INFO_RSP);
        rsp->sg_cnt = cpu_to_le16(srv_path->mrs_num);

        for (mri = 0; mri < srv_path->mrs_num; mri++) {
                struct ib_mr *mr = srv_path->mrs[mri].mr;

                rsp->desc[mri].addr = cpu_to_le64(mr->iova);
                rsp->desc[mri].key  = cpu_to_le32(mr->rkey);
                rsp->desc[mri].len  = cpu_to_le32(mr->length);

                /*
                 * Fill in reg MR request and chain them *backwards*
                 */
                rwr[mri].wr.next = mri ? &rwr[mri - 1].wr : NULL;
                rwr[mri].wr.opcode = IB_WR_REG_MR;
                rwr[mri].wr.wr_cqe = &local_reg_cqe;
                rwr[mri].wr.num_sge = 0;
                rwr[mri].wr.send_flags = 0;
                rwr[mri].mr = mr;
                rwr[mri].key = mr->rkey;
                rwr[mri].access = (IB_ACCESS_LOCAL_WRITE |
                                   IB_ACCESS_REMOTE_WRITE);
                reg_wr = &rwr[mri].wr;
        }

        err = rtrs_srv_create_path_files(srv_path);
        if (err)
                goto iu_free;
        kobject_get(&srv_path->kobj);
        get_device(&srv_path->srv->dev);
        if (!rtrs_srv_change_state(srv_path, RTRS_SRV_CONNECTED)) {
                rtrs_err(s, "rtrs_srv_change_state() failed\n");
                /* The path is no longer in CONNECTING, e.g. it is being closed */
                err = -EINVAL;
                goto iu_free;
        }

        rtrs_srv_start_hb(srv_path);

        /*
         * We do not account for the number of established connections here;
         * we rely on the client, which sends the info request only after all
         * connections have been successfully established.  Thus, simply
         * notify the listener with the proper event if we are the first path.
         */
        err = rtrs_srv_path_up(srv_path);
        if (err) {
                rtrs_err(s, "rtrs_srv_path_up(), err: %pe\n", ERR_PTR(err));
                goto iu_free;
        }

        ib_dma_sync_single_for_device(srv_path->s.dev->ib_dev,
                                      tx_iu->dma_addr,
                                      tx_iu->size, DMA_TO_DEVICE);

        /*
         * The path has reached CONNECTED state, so disable zombie-connection
         * closing: the timeout check only applies to paths that are stuck in
         * CONNECTING.
         */
        srv_path->connection_timeout = 0;

        /* Send info response */
        err = rtrs_iu_post_send(&con->c, tx_iu, tx_sz, reg_wr);
        if (err) {
                rtrs_err(s, "rtrs_iu_post_send(), err: %pe\n", ERR_PTR(err));
iu_free:
                rtrs_iu_free(tx_iu, srv_path->s.dev->ib_dev, 1);
        }
rwr_free:
        kfree(rwr);

        return err;
}

static void rtrs_srv_info_req_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = to_srv_con(wc->qp->qp_context);
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_msg_info_req *msg;
        struct rtrs_iu *iu;
        int err;

        WARN_ON(con->c.cid);

        iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
        if (wc->status != IB_WC_SUCCESS) {
                rtrs_err(s, "Sess info request receive failed: %s\n",
                          ib_wc_status_msg(wc->status));
                goto close;
        }
        WARN_ON(wc->opcode != IB_WC_RECV);

        if (wc->byte_len < sizeof(*msg)) {
                rtrs_err(s, "Sess info request is malformed: size %d\n",
                          wc->byte_len);
                goto close;
        }
        ib_dma_sync_single_for_cpu(srv_path->s.dev->ib_dev, iu->dma_addr,
                                   iu->size, DMA_FROM_DEVICE);
        msg = iu->buf;
        if (le16_to_cpu(msg->type) != RTRS_MSG_INFO_REQ) {
                rtrs_err(s, "Sess info request is malformed: type %d\n",
                          le16_to_cpu(msg->type));
                goto close;
        }
        err = process_info_req(con, msg);
        if (err)
                goto close;

        rtrs_iu_free(iu, srv_path->s.dev->ib_dev, 1);
        return;
close:
        rtrs_iu_free(iu, srv_path->s.dev->ib_dev, 1);
        close_path(srv_path);
}

static int post_recv_info_req(struct rtrs_srv_con *con)
{
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_iu *rx_iu;
        int err;

        rx_iu = rtrs_iu_alloc(1, sizeof(struct rtrs_msg_info_req),
                               GFP_KERNEL, srv_path->s.dev->ib_dev,
                               DMA_FROM_DEVICE, rtrs_srv_info_req_done);
        if (!rx_iu)
                return -ENOMEM;
        /* Prepare for getting info response */
        err = rtrs_iu_post_recv(&con->c, rx_iu);
        if (err) {
                rtrs_err(s, "rtrs_iu_post_recv(), err: %pe\n", ERR_PTR(err));
                rtrs_iu_free(rx_iu, srv_path->s.dev->ib_dev, 1);
                return err;
        }

        return 0;
}

static int post_recv_io(struct rtrs_srv_con *con, size_t q_size)
{
        int i, err;

        for (i = 0; i < q_size; i++) {
                err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
                if (err)
                        return err;
        }

        return 0;
}

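/*
 * Post the initial receive buffers on every connection of the path:
 * SERVICE_CON_QUEUE_DEPTH for the service connection (cid 0) and
 * queue_depth for each IO connection.
 */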
static int post_recv_path(struct rtrs_srv_path *srv_path)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_path *s = &srv_path->s;
        size_t q_size;
        int err, cid;

        for (cid = 0; cid < srv_path->s.con_num; cid++) {
                if (cid == 0)
                        q_size = SERVICE_CON_QUEUE_DEPTH;
                else
                        q_size = srv->queue_depth;
                if (srv_path->state != RTRS_SRV_CONNECTING) {
                        rtrs_err(s, "Path state invalid. state %s\n",
                                 rtrs_srv_state_str(srv_path->state));
                        return -EIO;
                }

                if (!srv_path->s.con[cid]) {
                        rtrs_err(s, "Conn not set for %d\n", cid);
                        return -EIO;
                }

                err = post_recv_io(to_srv_con(srv_path->s.con[cid]), q_size);
                if (err) {
                        rtrs_err(s, "post_recv_io(), err: %pe\n", ERR_PTR(err));
                        return err;
                }
        }

        return 0;
}

static void process_read(struct rtrs_srv_con *con,
                         struct rtrs_msg_rdma_read *msg,
                         u32 buf_id, u32 off)
{
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;
        struct rtrs_srv_op *id;

        size_t usr_len, data_len;
        void *data;
        int ret;

        if (srv_path->state != RTRS_SRV_CONNECTED) {
                rtrs_err_rl(s,
                             "Processing read request failed,  session is disconnected, sess state %s\n",
                             rtrs_srv_state_str(srv_path->state));
                return;
        }
        if (le16_to_cpu(msg->sg_cnt) > 1) {
                rtrs_err_rl(s,
                            "Processing read request failed, invalid message\n");
                return;
        }
        rtrs_srv_get_ops_ids(srv_path);
        rtrs_srv_update_rdma_stats(srv_path->stats, off, READ);
        id = srv_path->ops_ids[buf_id];
        id->con         = con;
        id->dir         = READ;
        id->msg_id      = buf_id;
        id->rd_msg      = msg;
        usr_len = le16_to_cpu(msg->usr_len);
        data_len = off - usr_len;
        data = page_address(srv->chunks[buf_id]);
        ret = ctx->ops.rdma_ev(srv->priv, id, data, data_len,
                           data + data_len, usr_len);

        if (ret) {
                rtrs_err_rl(s,
                             "Processing read request failed, user module cb reported for msg_id %d, err: %pe\n",
                             buf_id, ERR_PTR(ret));
                goto send_err_msg;
        }

        return;

send_err_msg:
        ret = send_io_resp_imm(con, id, ret);
        if (ret < 0) {
                rtrs_err_rl(s,
                             "Sending err msg for failed RDMA-Write-Req failed, msg_id %d, err: %pe\n",
                             buf_id, ERR_PTR(ret));
                close_path(srv_path);
        }
        rtrs_srv_put_ops_ids(srv_path);
}

static void process_write(struct rtrs_srv_con *con,
                          struct rtrs_msg_rdma_write *req,
                          u32 buf_id, u32 off)
{
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;
        struct rtrs_srv_op *id;

        size_t data_len, usr_len;
        void *data;
        int ret;

        if (srv_path->state != RTRS_SRV_CONNECTED) {
                rtrs_err_rl(s,
                             "Processing write request failed,  session is disconnected, sess state %s\n",
                             rtrs_srv_state_str(srv_path->state));
                return;
        }
        rtrs_srv_get_ops_ids(srv_path);
        rtrs_srv_update_rdma_stats(srv_path->stats, off, WRITE);
        id = srv_path->ops_ids[buf_id];
        id->con    = con;
        id->dir    = WRITE;
        id->msg_id = buf_id;

        usr_len = le16_to_cpu(req->usr_len);
        data_len = off - usr_len;
        data = page_address(srv->chunks[buf_id]);
        ret = ctx->ops.rdma_ev(srv->priv, id, data, data_len,
                               data + data_len, usr_len);
        if (ret) {
                rtrs_err_rl(s,
                             "Processing write request failed, user module callback reports err: %pe\n",
                             ERR_PTR(ret));
                goto send_err_msg;
        }

        return;

send_err_msg:
        ret = send_io_resp_imm(con, id, ret);
        if (ret < 0) {
                rtrs_err_rl(s,
                             "Processing write request failed, sending I/O response failed, msg_id %d, err: %pe\n",
                             buf_id, ERR_PTR(ret));
                close_path(srv_path);
        }
        rtrs_srv_put_ops_ids(srv_path);
}

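/*
 * Demultiplex an incoming IO message by its header type and hand it to the
 * READ or WRITE handler; an unknown type closes the path.
 */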
static void process_io_req(struct rtrs_srv_con *con, void *msg,
                           u32 id, u32 off)
{
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_msg_rdma_hdr *hdr;
        unsigned int type;

        ib_dma_sync_single_for_cpu(srv_path->s.dev->ib_dev,
                                   srv_path->dma_addr[id],
                                   max_chunk_size, DMA_BIDIRECTIONAL);
        hdr = msg;
        type = le16_to_cpu(hdr->type);

        switch (type) {
        case RTRS_MSG_WRITE:
                process_write(con, msg, id, off);
                break;
        case RTRS_MSG_READ:
                process_read(con, msg, id, off);
                break;
        default:
                rtrs_err(s,
                          "Processing I/O request failed, unknown message type received: 0x%02x\n",
                          type);
                goto err;
        }

        return;

err:
        close_path(srv_path);
}

static void rtrs_srv_inv_rkey_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_mr *mr =
                container_of(wc->wr_cqe, typeof(*mr), inv_cqe);
        struct rtrs_srv_con *con = to_srv_con(wc->qp->qp_context);
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_srv_sess *srv = srv_path->srv;
        u32 msg_id, off;
        void *data;

        if (wc->status != IB_WC_SUCCESS) {
                rtrs_err(s, "Failed IB_WR_LOCAL_INV: %s\n",
                          ib_wc_status_msg(wc->status));
                close_path(srv_path);
        }
        msg_id = mr->msg_id;
        off = mr->msg_off;
        data = page_address(srv->chunks[msg_id]) + off;
        process_io_req(con, data, msg_id, off);
}

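/*
 * Post a local invalidate of the chunk MR; the actual IO processing
 * continues in rtrs_srv_inv_rkey_done() once the invalidation completes.
 */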
static int rtrs_srv_inv_rkey(struct rtrs_srv_con *con,
                              struct rtrs_srv_mr *mr)
{
        struct ib_send_wr wr = {
                .opcode             = IB_WR_LOCAL_INV,
                .wr_cqe             = &mr->inv_cqe,
                .send_flags         = IB_SEND_SIGNALED,
                .ex.invalidate_rkey = mr->mr->rkey,
        };
        mr->inv_cqe.done = rtrs_srv_inv_rkey_done;

        return ib_post_send(con->c.qp, &wr, NULL);
}

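/*
 * Retry responses that were deferred in rtrs_srv_resp_rdma() because the
 * send queue was full.  Runs from the CQ handler once send completions
 * have freed up SQ slots.
 */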
static void rtrs_rdma_process_wr_wait_list(struct rtrs_srv_con *con)
{
        spin_lock(&con->rsp_wr_wait_lock);
        while (!list_empty(&con->rsp_wr_wait_list)) {
                struct rtrs_srv_op *id;
                int ret;

                id = list_entry(con->rsp_wr_wait_list.next,
                                struct rtrs_srv_op, wait_list);
                list_del(&id->wait_list);

                spin_unlock(&con->rsp_wr_wait_lock);
                ret = rtrs_srv_resp_rdma(id, id->status);
                spin_lock(&con->rsp_wr_wait_lock);

                if (!ret) {
                        list_add(&id->wait_list, &con->rsp_wr_wait_list);
                        break;
                }
        }
        spin_unlock(&con->rsp_wr_wait_lock);
}

static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = to_srv_con(wc->qp->qp_context);
        struct rtrs_path *s = con->c.path;
        struct rtrs_srv_path *srv_path = to_srv_path(s);
        struct rtrs_srv_sess *srv = srv_path->srv;
        u32 imm_type, imm_payload;
        int err;

        if (wc->status != IB_WC_SUCCESS) {
                if (wc->status != IB_WC_WR_FLUSH_ERR) {
                        rtrs_err(s,
                                  "%s (wr_cqe: %p, type: %d, vendor_err: 0x%x, len: %u)\n",
                                  ib_wc_status_msg(wc->status), wc->wr_cqe,
                                  wc->opcode, wc->vendor_err, wc->byte_len);
                        close_path(srv_path);
                }
                return;
        }

        switch (wc->opcode) {
        case IB_WC_RECV_RDMA_WITH_IMM:
                /*
                 * post_recv() RDMA write completions of IO reqs (read/write)
                 * and hb
                 */
                if (WARN_ON(wc->wr_cqe != &io_comp_cqe))
                        return;
                srv_path->s.hb_missed_cnt = 0;
                err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
                if (err) {
                        rtrs_err(s, "rtrs_post_recv(), err: %pe\n",
                                 ERR_PTR(err));
                        close_path(srv_path);
                        break;
                }
                rtrs_from_imm(be32_to_cpu(wc->ex.imm_data),
                               &imm_type, &imm_payload);
                if (imm_type == RTRS_IO_REQ_IMM) {
                        u32 msg_id, off;
                        void *data;

                        msg_id = imm_payload >> srv_path->mem_bits;
                        off = imm_payload & ((1 << srv_path->mem_bits) - 1);
                        if (msg_id >= srv->queue_depth || off >= max_chunk_size) {
                                rtrs_err(s, "Wrong msg_id %u, off %u\n",
                                          msg_id, off);
                                close_path(srv_path);
                                return;
                        }
                        if (always_invalidate) {
                                struct rtrs_srv_mr *mr = &srv_path->mrs[msg_id];

                                mr->msg_off = off;
                                mr->msg_id = msg_id;
                                err = rtrs_srv_inv_rkey(con, mr);
                                if (err) {
                                        rtrs_err(s, "rtrs_post_recv(), err: %pe\n",
                                                 ERR_PTR(err));
                                        close_path(srv_path);
                                        break;
                                }
                        } else {
                                data = page_address(srv->chunks[msg_id]) + off;
                                process_io_req(con, data, msg_id, off);
                        }
                } else if (imm_type == RTRS_HB_MSG_IMM) {
                        WARN_ON(con->c.cid);
                        rtrs_send_hb_ack(&srv_path->s);
                } else if (imm_type == RTRS_HB_ACK_IMM) {
                        WARN_ON(con->c.cid);
                        srv_path->s.hb_missed_cnt = 0;
                } else {
                        rtrs_wrn(s, "Unknown IMM type %u\n", imm_type);
                }
                break;
        case IB_WC_RDMA_WRITE:
        case IB_WC_SEND:
                /*
                 * post_send() RDMA write completions of IO reqs (read/write)
                 * and hb.
                 */
                atomic_add(s->signal_interval, &con->c.sq_wr_avail);

                if (!list_empty_careful(&con->rsp_wr_wait_list))
                        rtrs_rdma_process_wr_wait_list(con);

                break;
        default:
                rtrs_wrn(s, "Unexpected WC type: %d\n", wc->opcode);
                return;
        }
}

/**
 * rtrs_srv_get_path_name() - Get the name of a connected path.
 * @srv:        Session
 * @pathname:   Pathname buffer
 * @len:        Length of pathname buffer
 */
int rtrs_srv_get_path_name(struct rtrs_srv_sess *srv, char *pathname,
                           size_t len)
{
        struct rtrs_srv_path *srv_path;
        int err = -ENOTCONN;

        mutex_lock(&srv->paths_mutex);
        list_for_each_entry(srv_path, &srv->paths_list, s.entry) {
                if (srv_path->state != RTRS_SRV_CONNECTED)
                        continue;
                strscpy(pathname, srv_path->s.sessname,
                        min_t(size_t, sizeof(srv_path->s.sessname), len));
                err = 0;
                break;
        }
        mutex_unlock(&srv->paths_mutex);

        return err;
}
EXPORT_SYMBOL(rtrs_srv_get_path_name);

/**
 * rtrs_srv_get_queue_depth() - Get rtrs_srv qdepth.
 * @srv:        Session
 */
int rtrs_srv_get_queue_depth(struct rtrs_srv_sess *srv)
{
        return srv->queue_depth;
}
EXPORT_SYMBOL(rtrs_srv_get_queue_depth);

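/*
 * Walk the CPUs in cq_affinity_mask round-robin, wrapping around when we
 * run past nr_cpu_ids or the device's number of completion vectors.
 */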
static int find_next_bit_ring(struct rtrs_srv_path *srv_path)
{
        struct ib_device *ib_dev = srv_path->s.dev->ib_dev;
        int v;

        v = cpumask_next(srv_path->cur_cq_vector, &cq_affinity_mask);
        if (v >= nr_cpu_ids || v >= ib_dev->num_comp_vectors)
                v = cpumask_first(&cq_affinity_mask);
        return v;
}

static int rtrs_srv_get_next_cq_vector(struct rtrs_srv_path *srv_path)
{
        srv_path->cur_cq_vector = find_next_bit_ring(srv_path);

        return srv_path->cur_cq_vector;
}

static void rtrs_srv_dev_release(struct device *dev)
{
        struct rtrs_srv_sess *srv = container_of(dev, struct rtrs_srv_sess,
                                                 dev);

        kfree(srv);
}

static void free_srv(struct rtrs_srv_sess *srv)
{
        int i;

        WARN_ON(refcount_read(&srv->refcount));
        for (i = 0; i < srv->queue_depth; i++)
                __free_pages(srv->chunks[i], get_order(max_chunk_size));
        kfree(srv->chunks);
        mutex_destroy(&srv->paths_mutex);
        mutex_destroy(&srv->paths_ev_mutex);
        /* last put to release the srv structure */
        put_device(&srv->dev);
}

static struct rtrs_srv_sess *get_or_create_srv(struct rtrs_srv_ctx *ctx,
                                          const uuid_t *paths_uuid,
                                          bool first_conn)
{
        struct rtrs_srv_sess *srv;
        int i;

        mutex_lock(&ctx->srv_mutex);
        list_for_each_entry(srv, &ctx->srv_list, ctx_list) {
                if (uuid_equal(&srv->paths_uuid, paths_uuid) &&
                    refcount_inc_not_zero(&srv->refcount)) {
                        mutex_unlock(&ctx->srv_mutex);
                        return srv;
                }
        }
        mutex_unlock(&ctx->srv_mutex);
        /*
         * If this request is not the first connection request from the
         * client for this session then fail and return error.
         */
        if (!first_conn) {
                pr_err_ratelimited("Error: Not the first connection request for this session\n");
                return ERR_PTR(-ENXIO);
        }

        /* need to allocate a new srv */
        srv = kzalloc_obj(*srv);
        if  (!srv)
                return ERR_PTR(-ENOMEM);

        INIT_LIST_HEAD(&srv->paths_list);
        mutex_init(&srv->paths_mutex);
        mutex_init(&srv->paths_ev_mutex);
        uuid_copy(&srv->paths_uuid, paths_uuid);
        srv->queue_depth = sess_queue_depth;
        srv->ctx = ctx;
        device_initialize(&srv->dev);
        srv->dev.release = rtrs_srv_dev_release;

        srv->chunks = kzalloc_objs(*srv->chunks, srv->queue_depth);
        if (!srv->chunks)
                goto err_free_srv;

        for (i = 0; i < srv->queue_depth; i++) {
                srv->chunks[i] = alloc_pages(GFP_KERNEL,
                                             get_order(max_chunk_size));
                if (!srv->chunks[i])
                        goto err_free_chunks;
        }
        refcount_set(&srv->refcount, 1);
        mutex_lock(&ctx->srv_mutex);
        list_add(&srv->ctx_list, &ctx->srv_list);
        mutex_unlock(&ctx->srv_mutex);

        return srv;

err_free_chunks:
        while (i--)
                __free_pages(srv->chunks[i], get_order(max_chunk_size));
        kfree(srv->chunks);

err_free_srv:
        put_device(&srv->dev);
        return ERR_PTR(-ENOMEM);
}

static void put_srv(struct rtrs_srv_sess *srv)
{
        if (refcount_dec_and_test(&srv->refcount)) {
                struct rtrs_srv_ctx *ctx = srv->ctx;

                WARN_ON(srv->dev.kobj.state_in_sysfs);

                mutex_lock(&ctx->srv_mutex);
                list_del(&srv->ctx_list);
                mutex_unlock(&ctx->srv_mutex);
                free_srv(srv);
        }
}

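/* Caller must hold srv->paths_mutex (hence the __ prefix). */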
static void __add_path_to_srv(struct rtrs_srv_sess *srv,
                              struct rtrs_srv_path *srv_path)
{
        list_add_tail(&srv_path->s.entry, &srv->paths_list);
        srv->paths_num++;
        WARN_ON(srv->paths_num >= MAX_PATHS_NUM);
}

static void del_path_from_srv(struct rtrs_srv_path *srv_path)
{
        struct rtrs_srv_sess *srv = srv_path->srv;

        if (WARN_ON(!srv))
                return;

        mutex_lock(&srv->paths_mutex);
        list_del(&srv_path->s.entry);
        WARN_ON(!srv->paths_num);
        srv->paths_num--;
        mutex_unlock(&srv->paths_mutex);
}

/*
 * Return 0 if the addresses compare equal (callers test with
 * !sockaddr_cmp()), non-zero if they differ, and -ENOENT for an
 * unknown address family.
 */
static int sockaddr_cmp(const struct sockaddr *a, const struct sockaddr *b)
{
        switch (a->sa_family) {
        case AF_IB:
                return memcmp(&((struct sockaddr_ib *)a)->sib_addr,
                              &((struct sockaddr_ib *)b)->sib_addr,
                              sizeof(struct ib_addr)) &&
                        (b->sa_family == AF_IB);
        case AF_INET:
                return memcmp(&((struct sockaddr_in *)a)->sin_addr,
                              &((struct sockaddr_in *)b)->sin_addr,
                              sizeof(struct in_addr)) &&
                        (b->sa_family == AF_INET);
        case AF_INET6:
                return memcmp(&((struct sockaddr_in6 *)a)->sin6_addr,
                              &((struct sockaddr_in6 *)b)->sin6_addr,
                              sizeof(struct in6_addr)) &&
                        (b->sa_family == AF_INET6);
        default:
                return -ENOENT;
        }
}

/* Close connections which have been stuck in CONNECTING for more than 30 seconds */
#define RTRS_MAX_CONN_TIMEOUT 30000 /* ms */

static void rtrs_srv_check_close_path(struct rtrs_srv_path *srv_path)
{
        struct rtrs_path *s = &srv_path->s;

        if (srv_path->state == RTRS_SRV_CONNECTING && srv_path->connection_timeout &&
           (jiffies_to_msecs(jiffies - srv_path->connection_timeout) > RTRS_MAX_CONN_TIMEOUT)) {
                rtrs_err(s, "Closing zombie path\n");
                close_path(srv_path);
        }
}

static bool __is_path_w_addr_exists(struct rtrs_srv_sess *srv,
                                    struct rdma_addr *addr)
{
        struct rtrs_srv_path *srv_path;

        list_for_each_entry(srv_path, &srv->paths_list, s.entry) {
                if (!sockaddr_cmp((struct sockaddr *)&srv_path->s.dst_addr,
                                  (struct sockaddr *)&addr->dst_addr) &&
                    !sockaddr_cmp((struct sockaddr *)&srv_path->s.src_addr,
                                  (struct sockaddr *)&addr->src_addr)) {
                        rtrs_err((&srv_path->s),
                                 "Path (%s) with same addr exists (lifetime %u)\n",
                                 rtrs_srv_state_str(srv_path->state),
                                 (jiffies_to_msecs(jiffies - srv_path->connection_timeout)));
                        rtrs_srv_check_close_path(srv_path);
                        return true;
                }
        }

        return false;
}

static void free_path(struct rtrs_srv_path *srv_path)
{
        if (srv_path->kobj.state_in_sysfs) {
                kobject_del(&srv_path->kobj);
                kobject_put(&srv_path->kobj);
        } else {
                free_percpu(srv_path->stats->rdma_stats);
                kfree(srv_path->stats);
                kfree(srv_path);
        }
}

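/*
 * Full path teardown, run from rtrs_wq: stop the heartbeat, disconnect and
 * drain all QPs, wait for inflight operations to finish, drop sysfs files,
 * notify the upper layer, free IO buffers and connections, and finally
 * release the path itself.
 */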
static void rtrs_srv_close_work(struct work_struct *work)
{
        struct rtrs_srv_path *srv_path;
        struct rtrs_srv_con *con;
        int i;

        srv_path = container_of(work, typeof(*srv_path), close_work);

        rtrs_srv_stop_hb(srv_path);

        for (i = 0; i < srv_path->s.con_num; i++) {
                if (!srv_path->s.con[i])
                        continue;
                con = to_srv_con(srv_path->s.con[i]);
                rdma_disconnect(con->c.cm_id);
                ib_drain_qp(con->c.qp);
        }

        /*
         * Degrade ref count to the usual model with a single shared
         * atomic_t counter
         */
        percpu_ref_kill(&srv_path->ids_inflight_ref);

        /* Wait for all inflight IO operations to complete */
        wait_for_completion(&srv_path->complete_done);

        rtrs_srv_destroy_path_files(srv_path);

        /* Notify upper layer if we are the last path */
        rtrs_srv_path_down(srv_path);

        unmap_cont_bufs(srv_path);
        rtrs_srv_free_ops_ids(srv_path);

        for (i = 0; i < srv_path->s.con_num; i++) {
                if (!srv_path->s.con[i])
                        continue;
                con = to_srv_con(srv_path->s.con[i]);
                rtrs_cq_qp_destroy(&con->c);
                rdma_destroy_id(con->c.cm_id);
                kfree(con);
        }
        rtrs_ib_dev_put(srv_path->s.dev);

        del_path_from_srv(srv_path);
        put_srv(srv_path->srv);
        srv_path->srv = NULL;
        rtrs_srv_change_state(srv_path, RTRS_SRV_CLOSED);

        kfree(srv_path->dma_addr);
        kfree(srv_path->s.con);
        free_path(srv_path);
}

static int rtrs_rdma_do_accept(struct rtrs_srv_path *srv_path,
                               struct rdma_cm_id *cm_id)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_msg_conn_rsp msg;
        struct rdma_conn_param param;
        int err;

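        /*
         * The connection response below travels back to the client in the
         * private data of the RDMA CM REP message.
         */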
        param = (struct rdma_conn_param) {
                .rnr_retry_count = 7,
                .private_data = &msg,
                .private_data_len = sizeof(msg),
        };

        msg = (struct rtrs_msg_conn_rsp) {
                .magic = cpu_to_le16(RTRS_MAGIC),
                .version = cpu_to_le16(RTRS_PROTO_VER),
                .queue_depth = cpu_to_le16(srv->queue_depth),
                .max_io_size = cpu_to_le32(max_chunk_size - MAX_HDR_SIZE),
                .max_hdr_size = cpu_to_le32(MAX_HDR_SIZE),
        };

        if (always_invalidate)
                msg.flags = cpu_to_le32(RTRS_MSG_NEW_RKEY_F);

        err = rdma_accept(cm_id, &param);
        if (err)
                pr_err("rdma_accept(), err: %pe\n", ERR_PTR(err));

        return err;
}

static int rtrs_rdma_do_reject(struct rdma_cm_id *cm_id, int errno)
{
        struct rtrs_msg_conn_rsp msg;
        int err;

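        /*
         * Carry the error code to the client in the consumer-defined
         * reject payload, so it can tell why it was refused.
         */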
        msg = (struct rtrs_msg_conn_rsp) {
                .magic = cpu_to_le16(RTRS_MAGIC),
                .version = cpu_to_le16(RTRS_PROTO_VER),
                .errno = cpu_to_le16(errno),
        };

        err = rdma_reject(cm_id, &msg, sizeof(msg), IB_CM_REJ_CONSUMER_DEFINED);
        if (err)
                pr_err("rdma_reject(), err: %pe\n", ERR_PTR(err));

        /* Bounce errno back */
        return errno;
}

static struct rtrs_srv_path *
__find_path(struct rtrs_srv_sess *srv, const uuid_t *sess_uuid)
{
        struct rtrs_srv_path *srv_path;

        list_for_each_entry(srv_path, &srv->paths_list, s.entry) {
                if (uuid_equal(&srv_path->s.uuid, sess_uuid))
                        return srv_path;
        }

        return NULL;
}

static int create_con(struct rtrs_srv_path *srv_path,
                      struct rdma_cm_id *cm_id,
                      unsigned int cid)
{
        struct rtrs_srv_sess *srv = srv_path->srv;
        struct rtrs_path *s = &srv_path->s;
        struct rtrs_srv_con *con;

        u32 cq_num, max_send_wr, max_recv_wr, wr_limit;
        int err, cq_vector;

        con = kzalloc_obj(*con);
        if (!con) {
                err = -ENOMEM;
                goto err;
        }

        spin_lock_init(&con->rsp_wr_wait_lock);
        INIT_LIST_HEAD(&con->rsp_wr_wait_list);
        con->c.cm_id = cm_id;
        con->c.path = &srv_path->s;
        con->c.cid = cid;
        atomic_set(&con->c.wr_cnt, 1);
        wr_limit = srv_path->s.dev->ib_dev->attrs.max_qp_wr;

        if (con->c.cid == 0) {
                /*
                 * All receive and all send (each requiring invalidate)
                 * + 2 for drain and heartbeat
                 */
                max_send_wr = min_t(int, wr_limit,
                                    SERVICE_CON_QUEUE_DEPTH * 2 + 2);
                max_recv_wr = max_send_wr;
                s->signal_interval = min_not_zero(srv->queue_depth,
                                                  (size_t)SERVICE_CON_QUEUE_DEPTH);
        } else {
                /* when always_invalidate is enabled, we need linv+rinv+mr+imm */
                if (always_invalidate)
                        max_send_wr =
                                min_t(int, wr_limit,
                                      srv->queue_depth * (1 + 4) + 1);
                else
                        max_send_wr =
                                min_t(int, wr_limit,
                                      srv->queue_depth * (1 + 2) + 1);

                max_recv_wr = srv->queue_depth + 1;
        }
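        /* Size the CQ to hold a completion for every posted send/recv WR */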
        cq_num = max_send_wr + max_recv_wr;
        atomic_set(&con->c.sq_wr_avail, max_send_wr);
        cq_vector = rtrs_srv_get_next_cq_vector(srv_path);

        /* TODO: SOFTIRQ can be faster, but be careful with softirq context */
        err = rtrs_cq_qp_create(&srv_path->s, &con->c, 1, cq_vector, cq_num,
                                 max_send_wr, max_recv_wr,
                                 IB_POLL_WORKQUEUE);
        if (err) {
                rtrs_err(s, "rtrs_cq_qp_create(), err: %pe\n", ERR_PTR(err));
                goto free_con;
        }
        if (con->c.cid == 0) {
                err = post_recv_info_req(con);
                if (err)
                        goto free_cqqp;
        }
        WARN_ON(srv_path->s.con[cid]);
        srv_path->s.con[cid] = &con->c;

        /*
         * Change context from server to current connection.  The other
         * way is to use cm_id->qp->qp_context, which does not work on OFED.
         */
        cm_id->context = &con->c;

        return 0;

free_cqqp:
        rtrs_cq_qp_destroy(&con->c);
free_con:
        kfree(con);

err:
        return err;
}

static struct rtrs_srv_path *__alloc_path(struct rtrs_srv_sess *srv,
                                           struct rdma_cm_id *cm_id,
                                           unsigned int con_num,
                                           unsigned int recon_cnt,
                                           const uuid_t *uuid)
{
        struct rtrs_srv_path *srv_path;
        int err = -ENOMEM;
        char str[NAME_MAX];
        struct rtrs_addr path;

        if (srv->paths_num >= MAX_PATHS_NUM) {
                err = -ECONNRESET;
                goto err;
        }
        if (__is_path_w_addr_exists(srv, &cm_id->route.addr)) {
                err = -EEXIST;
                goto err;
        }
        srv_path = kzalloc_obj(*srv_path);
        if (!srv_path)
                goto err;

        srv_path->stats = kzalloc_obj(*srv_path->stats);
        if (!srv_path->stats)
                goto err_free_sess;

        srv_path->stats->rdma_stats = alloc_percpu(struct rtrs_srv_stats_rdma_stats);
        if (!srv_path->stats->rdma_stats)
                goto err_free_stats;

        srv_path->stats->srv_path = srv_path;

        srv_path->dma_addr = kzalloc_objs(*srv_path->dma_addr, srv->queue_depth);
        if (!srv_path->dma_addr)
                goto err_free_percpu;

        srv_path->s.con = kzalloc_objs(*srv_path->s.con, con_num);
        if (!srv_path->s.con)
                goto err_free_dma_addr;

        srv_path->state = RTRS_SRV_CONNECTING;
        srv_path->srv = srv;
        srv_path->cur_cq_vector = -1;
        srv_path->s.dst_addr = cm_id->route.addr.dst_addr;
        srv_path->s.src_addr = cm_id->route.addr.src_addr;

        /* Temporary name until the session-name is received from the client */
        path.src = &srv_path->s.src_addr;
        path.dst = &srv_path->s.dst_addr;
        rtrs_addr_to_str(&path, str, sizeof(str));
        strscpy(srv_path->s.sessname, str, sizeof(srv_path->s.sessname));

        srv_path->s.con_num = con_num;
        srv_path->s.irq_con_num = con_num;
        srv_path->s.recon_cnt = recon_cnt;
        uuid_copy(&srv_path->s.uuid, uuid);
        spin_lock_init(&srv_path->state_lock);
        INIT_WORK(&srv_path->close_work, rtrs_srv_close_work);
        rtrs_srv_init_hb(srv_path);
        srv_path->connection_timeout = 0;

        srv_path->s.dev = rtrs_ib_dev_find_or_add(cm_id->device, &dev_pd);
        if (!srv_path->s.dev) {
                err = -ENOMEM;
                goto err_free_con;
        }
        err = map_cont_bufs(srv_path);
        if (err)
                goto err_put_dev;

        err = rtrs_srv_alloc_ops_ids(srv_path);
        if (err)
                goto err_unmap_bufs;

        __add_path_to_srv(srv, srv_path);

        return srv_path;

err_unmap_bufs:
        unmap_cont_bufs(srv_path);
err_put_dev:
        rtrs_ib_dev_put(srv_path->s.dev);
err_free_con:
        kfree(srv_path->s.con);
err_free_dma_addr:
        kfree(srv_path->dma_addr);
err_free_percpu:
        free_percpu(srv_path->stats->rdma_stats);
err_free_stats:
        kfree(srv_path->stats);
err_free_sess:
        kfree(srv_path);
err:
        return ERR_PTR(err);
}

static int rtrs_rdma_connect(struct rdma_cm_id *cm_id,
                              const struct rtrs_msg_conn_req *msg,
                              size_t len)
{
        struct rtrs_srv_ctx *ctx = cm_id->context;
        struct rtrs_srv_path *srv_path;
        struct rtrs_srv_sess *srv;

        u16 version, con_num, cid;
        u16 recon_cnt;
        int err = -ECONNRESET;

        if (len < sizeof(*msg)) {
                pr_err("Invalid RTRS connection request\n");
                goto reject_w_err;
        }
        if (le16_to_cpu(msg->magic) != RTRS_MAGIC) {
                pr_err("Invalid RTRS magic\n");
                goto reject_w_err;
        }
        version = le16_to_cpu(msg->version);
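        /* The major version is in the high byte, the minor in the low byte */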
        if (version >> 8 != RTRS_PROTO_VER_MAJOR) {
                pr_err("Unsupported major RTRS version: %d, expected %d\n",
                       version >> 8, RTRS_PROTO_VER_MAJOR);
                goto reject_w_err;
        }
        con_num = le16_to_cpu(msg->cid_num);
        if (con_num > 4096) {
                /* Sanity check */
                pr_err("Too many connections requested: %d\n", con_num);
                goto reject_w_err;
        }
        cid = le16_to_cpu(msg->cid);
        if (cid >= con_num) {
                /* Sanity check */
                pr_err("Incorrect cid: %d >= %d\n", cid, con_num);
                goto reject_w_err;
        }
        recon_cnt = le16_to_cpu(msg->recon_cnt);
        srv = get_or_create_srv(ctx, &msg->paths_uuid, msg->first_conn);
        if (IS_ERR(srv)) {
                err = PTR_ERR(srv);
                pr_err("get_or_create_srv(), error %d\n", err);
                goto reject_w_err;
        }
        mutex_lock(&srv->paths_mutex);
        srv_path = __find_path(srv, &msg->sess_uuid);
        if (srv_path) {
                struct rtrs_path *s = &srv_path->s;

                /* Session already holds a reference */
                put_srv(srv);

                if (srv_path->state != RTRS_SRV_CONNECTING) {
                        rtrs_err(s, "Session in wrong state: %s\n",
                                  rtrs_srv_state_str(srv_path->state));
                        mutex_unlock(&srv->paths_mutex);
                        goto reject_w_err;
                }
                /*
                 * Sanity checks
                 */
                if (con_num != s->con_num || cid >= s->con_num) {
                        rtrs_err(s, "Incorrect request: %d, %d\n",
                                  cid, con_num);
                        mutex_unlock(&srv->paths_mutex);
                        goto reject_w_err;
                }
                if (s->con[cid]) {
                        rtrs_err(s, "Connection (%s) already exists: %d (lifetime %u)\n",
                                 rtrs_srv_state_str(srv_path->state), cid,
                                 jiffies_to_msecs(jiffies - srv_path->connection_timeout));
                        rtrs_srv_check_close_path(srv_path);
                        mutex_unlock(&srv->paths_mutex);
                        goto reject_w_err;
                }
        } else {
                srv_path = __alloc_path(srv, cm_id, con_num, recon_cnt,
                                    &msg->sess_uuid);
                if (IS_ERR(srv_path)) {
                        mutex_unlock(&srv->paths_mutex);
                        put_srv(srv);
                        err = PTR_ERR(srv_path);
                        pr_err("RTRS server session allocation failed: %d\n", err);
                        goto reject_w_err;
                }
        }

        /*
         * Start of any connection creation resets the timeout for the path.
         */
        srv_path->connection_timeout = jiffies;

        err = create_con(srv_path, cm_id, cid);
        if (err) {
                rtrs_err(&srv_path->s, "create_con(), error %pe\n", ERR_PTR(err));
                rtrs_rdma_do_reject(cm_id, err);
                /*
                 * Since the session has other connections we follow the
                 * normal path through the workqueue, but still return an
                 * error to tell cma.c to call rdma_destroy_id() for the
                 * current connection.
                 */
                goto close_and_return_err;
        }
        err = rtrs_rdma_do_accept(srv_path, cm_id);
        if (err) {
                rtrs_err(&srv_path->s, "rtrs_rdma_do_accept(), error %pe\n",
                         ERR_PTR(err));
                rtrs_rdma_do_reject(cm_id, err);
                /*
                 * Since the current connection was successfully added to
                 * the session we follow the normal path through the
                 * workqueue to close the session, thus return 0 to tell
                 * cma.c that we call rdma_destroy_id() ourselves.
                 */
                err = 0;
                goto close_and_return_err;
        }
        mutex_unlock(&srv->paths_mutex);

        return 0;

reject_w_err:
        return rtrs_rdma_do_reject(cm_id, err);

close_and_return_err:
        mutex_unlock(&srv->paths_mutex);
        close_path(srv_path);

        return err;
}

static int rtrs_srv_rdma_cm_handler(struct rdma_cm_id *cm_id,
                                     struct rdma_cm_event *ev)
{
        struct rtrs_srv_path *srv_path = NULL;
        struct rtrs_path *s = NULL;
        struct rtrs_con *c = NULL;

        if (ev->event == RDMA_CM_EVENT_CONNECT_REQUEST)
                /*
                 * In case of error cma.c will destroy cm_id,
                 * see cma_process_remove()
                 */
                return rtrs_rdma_connect(cm_id, ev->param.conn.private_data,
                                          ev->param.conn.private_data_len);

        c = cm_id->context;
        s = c->path;
        srv_path = to_srv_path(s);

        switch (ev->event) {
        case RDMA_CM_EVENT_ESTABLISHED:
                /* Nothing here */
                break;
        case RDMA_CM_EVENT_REJECTED:
        case RDMA_CM_EVENT_CONNECT_ERROR:
        case RDMA_CM_EVENT_UNREACHABLE:
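                /*
                 * A negative status is a kernel errno, a positive one an
                 * RDMA CM reject reason.
                 */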
                if (ev->status < 0) {
                        rtrs_err(s, "CM error (CM event: %s, err: %pe)\n",
                                        rdma_event_msg(ev->event),
                                        ERR_PTR(ev->status));
                } else if (ev->status > 0) {
                        rtrs_err(s, "CM error (CM event: %s, err: %s)\n",
                                        rdma_event_msg(ev->event),
                                        rdma_reject_msg(cm_id, ev->status));
                }
                fallthrough;
        case RDMA_CM_EVENT_DISCONNECTED:
        case RDMA_CM_EVENT_ADDR_CHANGE:
        case RDMA_CM_EVENT_TIMEWAIT_EXIT:
        case RDMA_CM_EVENT_DEVICE_REMOVAL:
                close_path(srv_path);
                break;
        default:
                if (ev->status < 0) {
                        pr_err("Ignoring unexpected CM event %s, err %pe\n",
                                        rdma_event_msg(ev->event),
                                        ERR_PTR(ev->status));
                } else if (ev->status > 0) {
                        pr_err("Ignoring unexpected CM event %s, err %s\n",
                                        rdma_event_msg(ev->event),
                                        rdma_reject_msg(cm_id, ev->status));
                }
                break;
        }

        return 0;
}

static struct rdma_cm_id *rtrs_srv_cm_init(struct rtrs_srv_ctx *ctx,
                                            struct sockaddr *addr,
                                            enum rdma_ucm_port_space ps)
{
        struct rdma_cm_id *cm_id;
        int ret;

        cm_id = rdma_create_id(&init_net, rtrs_srv_rdma_cm_handler,
                               ctx, ps, IB_QPT_RC);
        if (IS_ERR(cm_id)) {
                ret = PTR_ERR(cm_id);
                pr_err("Creating id for RDMA connection failed, err: %d\n",
                       ret);
                goto err_out;
        }
        ret = rdma_bind_addr(cm_id, addr);
        if (ret) {
                pr_err("Binding RDMA address failed, err: %pe\n", ERR_PTR(ret));
                goto err_cm;
        }
        ret = rdma_listen(cm_id, 64);
        if (ret) {
                pr_err("Listening on RDMA connection failed, err: %pe\n",
                       ERR_PTR(ret));
                goto err_cm;
        }

        return cm_id;

err_cm:
        rdma_destroy_id(cm_id);
err_out:

        return ERR_PTR(ret);
}

static int rtrs_srv_rdma_init(struct rtrs_srv_ctx *ctx, u16 port)
{
        struct sockaddr_in6 sin = {
                .sin6_family    = AF_INET6,
                .sin6_addr      = IN6ADDR_ANY_INIT,
                .sin6_port      = htons(port),
        };
        struct sockaddr_ib sib = {
                .sib_family                     = AF_IB,
                .sib_sid        = cpu_to_be64(RDMA_IB_IP_PS_IB | port),
                .sib_sid_mask   = cpu_to_be64(0xffffffffffffffffULL),
                .sib_pkey       = cpu_to_be16(0xffff),
        };
        struct rdma_cm_id *cm_ip, *cm_ib;
        int ret;

        /*
         * We accept both IPoIB and IB connections, so we need to keep
         * two CM IDs, one for each socket type and port space.
         * If the CM initialization of one of the IDs fails, we abort
         * everything.
         */
        cm_ip = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sin, RDMA_PS_TCP);
        if (IS_ERR(cm_ip))
                return PTR_ERR(cm_ip);

        cm_ib = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sib, RDMA_PS_IB);
        if (IS_ERR(cm_ib)) {
                ret = PTR_ERR(cm_ib);
                goto free_cm_ip;
        }

        ctx->cm_id_ip = cm_ip;
        ctx->cm_id_ib = cm_ib;

        return 0;

free_cm_ip:
        rdma_destroy_id(cm_ip);

        return ret;
}

static struct rtrs_srv_ctx *alloc_srv_ctx(struct rtrs_srv_ops *ops)
{
        struct rtrs_srv_ctx *ctx;

        ctx = kzalloc_obj(*ctx);
        if (!ctx)
                return NULL;

        ctx->ops = *ops;
        mutex_init(&ctx->srv_mutex);
        INIT_LIST_HEAD(&ctx->srv_list);

        return ctx;
}

static void free_srv_ctx(struct rtrs_srv_ctx *ctx)
{
        WARN_ON(!list_empty(&ctx->srv_list));
        mutex_destroy(&ctx->srv_mutex);
        kfree(ctx);
}

static int rtrs_srv_add_one(struct ib_device *device)
{
        struct rtrs_srv_ctx *ctx;
        int ret = 0;

        mutex_lock(&ib_ctx.ib_dev_mutex);
        if (ib_ctx.ib_dev_count)
                goto out;

        /*
         * Since our CM IDs are NOT bound to any ib device we will create them
         * only once
         */
        ctx = ib_ctx.srv_ctx;
        ret = rtrs_srv_rdma_init(ctx, ib_ctx.port);
        if (ret) {
                /*
                 * According to the IB core code, an error returned here is
                 * ignored and no further calls to our ops are made.
                 */
                pr_err("Failed to initialize RDMA connection\n");
                goto err_out;
        }

out:
        /*
         * Keep track of the number of IB devices added
         */
        ib_ctx.ib_dev_count++;
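        /* Deliberate fall-through to err_out; ret is 0 on this path */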

err_out:
        mutex_unlock(&ib_ctx.ib_dev_mutex);
        return ret;
}

static void rtrs_srv_remove_one(struct ib_device *device, void *client_data)
{
        struct rtrs_srv_ctx *ctx;

        mutex_lock(&ib_ctx.ib_dev_mutex);
        ib_ctx.ib_dev_count--;

        if (ib_ctx.ib_dev_count)
                goto out;

        /*
         * Since our CM IDs are NOT bound to any ib device we will remove them
         * only once, when the last device is removed
         */
        ctx = ib_ctx.srv_ctx;
        rdma_destroy_id(ctx->cm_id_ip);
        rdma_destroy_id(ctx->cm_id_ib);

out:
        mutex_unlock(&ib_ctx.ib_dev_mutex);
}

static struct ib_client rtrs_srv_client = {
        .name   = "rtrs_server",
        .add    = rtrs_srv_add_one,
        .remove = rtrs_srv_remove_one
};

/**
 * rtrs_srv_open() - open RTRS server context
 * @ops:                callback functions
 * @port:               port to listen on
 *
 * Creates server context with specified callbacks.
 *
 * Return: a valid pointer on success, an ERR_PTR otherwise.
 */
struct rtrs_srv_ctx *rtrs_srv_open(struct rtrs_srv_ops *ops, u16 port)
{
        struct rtrs_srv_ctx *ctx;
        int err;

        ctx = alloc_srv_ctx(ops);
        if (!ctx)
                return ERR_PTR(-ENOMEM);

        mutex_init(&ib_ctx.ib_dev_mutex);
        ib_ctx.srv_ctx = ctx;
        ib_ctx.port = port;

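        /*
         * The listening CM IDs are created lazily in rtrs_srv_add_one()
         * when the first IB device is registered with the client below.
         */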
        err = ib_register_client(&rtrs_srv_client);
        if (err) {
                free_srv_ctx(ctx);
                return ERR_PTR(err);
        }

        return ctx;
}
EXPORT_SYMBOL(rtrs_srv_open);

static void close_paths(struct rtrs_srv_sess *srv)
{
        struct rtrs_srv_path *srv_path;

        mutex_lock(&srv->paths_mutex);
        list_for_each_entry(srv_path, &srv->paths_list, s.entry)
                close_path(srv_path);
        mutex_unlock(&srv->paths_mutex);
}

static void close_ctx(struct rtrs_srv_ctx *ctx)
{
        struct rtrs_srv_sess *srv;

        mutex_lock(&ctx->srv_mutex);
        list_for_each_entry(srv, &ctx->srv_list, ctx_list)
                close_paths(srv);
        mutex_unlock(&ctx->srv_mutex);
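        /*
         * close_path() queues close_work rather than closing synchronously,
         * so flush the workqueue to wait for the teardown to finish.
         */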
        flush_workqueue(rtrs_wq);
}

/**
 * rtrs_srv_close() - close RTRS server context
 * @ctx: pointer to server context
 *
 * Closes RTRS server context with all client sessions.
 */
void rtrs_srv_close(struct rtrs_srv_ctx *ctx)
{
        ib_unregister_client(&rtrs_srv_client);
        mutex_destroy(&ib_ctx.ib_dev_mutex);
        close_ctx(ctx);
        free_srv_ctx(ctx);
}
EXPORT_SYMBOL(rtrs_srv_close);

static int check_module_params(void)
{
        if (sess_queue_depth < 1 || sess_queue_depth > MAX_SESS_QUEUE_DEPTH) {
                pr_err("Invalid sess_queue_depth value %d, has to be >= %d, <= %d.\n",
                       sess_queue_depth, 1, MAX_SESS_QUEUE_DEPTH);
                return -EINVAL;
        }
        if (max_chunk_size < MIN_CHUNK_SIZE || !is_power_of_2(max_chunk_size)) {
                pr_err("Invalid max_chunk_size value %d, has to be >= %d and should be power of two.\n",
                       max_chunk_size, MIN_CHUNK_SIZE);
                return -EINVAL;
        }

        /*
         * Check if IB immediate data size is enough to hold the mem_id and the
         * offset inside the memory chunk
         */
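        /*
         * For example, with the defaults (sess_queue_depth = 512,
         * max_chunk_size = 128 KB) this needs 9 + 17 = 26 bits.
         */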
        if ((ilog2(sess_queue_depth - 1) + 1) +
            (ilog2(max_chunk_size - 1) + 1) > MAX_IMM_PAYL_BITS) {
                pr_err("RDMA immediate size (%db) not enough to encode %d buffers of size %dB. Reduce 'sess_queue_depth' or 'max_chunk_size' parameters.\n",
                       MAX_IMM_PAYL_BITS, sess_queue_depth, max_chunk_size);
                return -EINVAL;
        }

        return 0;
}

void rtrs_srv_ib_event_handler(struct ib_event_handler *handler,
                               struct ib_event *ibevent)
{
        struct ib_device *idev = ibevent->device;
        u32 port_num = ibevent->element.port_num;

        pr_info("Handling event: %s (%d). HCA name: %s, port num: %u\n",
                        ib_event_msg(ibevent->event), ibevent->event, idev->name, port_num);
}

static int rtrs_srv_ib_dev_init(struct rtrs_ib_dev *dev)
{
        INIT_IB_EVENT_HANDLER(&dev->event_handler, dev->ib_dev,
                              rtrs_srv_ib_event_handler);
        ib_register_event_handler(&dev->event_handler);

        return 0;
}

static void rtrs_srv_ib_dev_deinit(struct rtrs_ib_dev *dev)
{
        ib_unregister_event_handler(&dev->event_handler);
}

static const struct rtrs_rdma_dev_pd_ops dev_pd_ops = {
        .init = rtrs_srv_ib_dev_init,
        .deinit = rtrs_srv_ib_dev_deinit
};

static int __init rtrs_server_init(void)
{
        int err;

        pr_info("Loading module %s, proto %s: (max_chunk_size: %d (pure IO %ld, headers %ld) , sess_queue_depth: %d, always_invalidate: %d)\n",
                KBUILD_MODNAME, RTRS_PROTO_VER_STRING,
                max_chunk_size, max_chunk_size - MAX_HDR_SIZE, MAX_HDR_SIZE,
                sess_queue_depth, always_invalidate);

        rtrs_rdma_dev_pd_init(0, &dev_pd);

        err = check_module_params();
        if (err) {
                pr_err("Failed to load module, invalid module parameters, err: %pe\n",
                       ERR_PTR(err));
                return err;
        }
        err = class_register(&rtrs_dev_class);
        if (err)
                goto out_err;

        rtrs_wq = alloc_workqueue("rtrs_server_wq", 0, 0);
        if (!rtrs_wq) {
                err = -ENOMEM;
                goto out_dev_class;
        }

        return 0;

out_dev_class:
        class_unregister(&rtrs_dev_class);
out_err:
        return err;
}

static void __exit rtrs_server_exit(void)
{
        destroy_workqueue(rtrs_wq);
        class_unregister(&rtrs_dev_class);
        rtrs_rdma_dev_pd_deinit(&dev_pd);
}

module_init(rtrs_server_init);
module_exit(rtrs_server_exit);