root/net/devlink/rate.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Copyright (c) 2016 Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
 */

#include "devl_internal.h"

static inline bool
devlink_rate_is_leaf(struct devlink_rate *devlink_rate)
{
        return devlink_rate->type == DEVLINK_RATE_TYPE_LEAF;
}

bool devlink_rate_is_node(const struct devlink_rate *devlink_rate)
{
        return devlink_rate->type == DEVLINK_RATE_TYPE_NODE;
}

static struct devlink_rate *
devlink_rate_leaf_get_from_info(struct devlink *devlink, struct genl_info *info)
{
        struct devlink_rate *devlink_rate;
        struct devlink_port *devlink_port;

        devlink_port = devlink_port_get_from_attrs(devlink, info->attrs);
        if (IS_ERR(devlink_port))
                return ERR_CAST(devlink_port);
        devlink_rate = devlink_port->devlink_rate;
        return devlink_rate ?: ERR_PTR(-ENODEV);
}

static struct devlink_rate *
devlink_rate_node_get_by_name(struct devlink *devlink, const char *node_name)
{
        struct devlink_rate *devlink_rate;

        list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
                if (devlink_rate_is_node(devlink_rate) &&
                    !strcmp(node_name, devlink_rate->name))
                        return devlink_rate;
        }
        return ERR_PTR(-ENODEV);
}

static struct devlink_rate *
devlink_rate_node_get_from_attrs(struct devlink *devlink, struct nlattr **attrs)
{
        const char *rate_node_name;
        size_t len;

        if (!attrs[DEVLINK_ATTR_RATE_NODE_NAME])
                return ERR_PTR(-EINVAL);
        rate_node_name = nla_data(attrs[DEVLINK_ATTR_RATE_NODE_NAME]);
        len = strlen(rate_node_name);
        /* Name cannot be empty or decimal number */
        if (!len || strspn(rate_node_name, "0123456789") == len)
                return ERR_PTR(-EINVAL);

        return devlink_rate_node_get_by_name(devlink, rate_node_name);
}

static struct devlink_rate *
devlink_rate_node_get_from_info(struct devlink *devlink, struct genl_info *info)
{
        return devlink_rate_node_get_from_attrs(devlink, info->attrs);
}

static struct devlink_rate *
devlink_rate_get_from_info(struct devlink *devlink, struct genl_info *info)
{
        struct nlattr **attrs = info->attrs;

        if (attrs[DEVLINK_ATTR_PORT_INDEX])
                return devlink_rate_leaf_get_from_info(devlink, info);
        else if (attrs[DEVLINK_ATTR_RATE_NODE_NAME])
                return devlink_rate_node_get_from_info(devlink, info);
        else
                return ERR_PTR(-EINVAL);
}

static int devlink_rate_put_tc_bws(struct sk_buff *msg, u32 *tc_bw)
{
        struct nlattr *nla_tc_bw;
        int i;

        for (i = 0; i < DEVLINK_RATE_TCS_MAX; i++) {
                nla_tc_bw = nla_nest_start(msg, DEVLINK_ATTR_RATE_TC_BWS);
                if (!nla_tc_bw)
                        return -EMSGSIZE;

                if (nla_put_u8(msg, DEVLINK_RATE_TC_ATTR_INDEX, i) ||
                    nla_put_u32(msg, DEVLINK_RATE_TC_ATTR_BW, tc_bw[i]))
                        goto nla_put_failure;

                nla_nest_end(msg, nla_tc_bw);
        }
        return 0;

nla_put_failure:
        nla_nest_cancel(msg, nla_tc_bw);
        return -EMSGSIZE;
}

static int devlink_nl_rate_fill(struct sk_buff *msg,
                                struct devlink_rate *devlink_rate,
                                enum devlink_command cmd, u32 portid, u32 seq,
                                int flags, struct netlink_ext_ack *extack)
{
        struct devlink *devlink = devlink_rate->devlink;
        void *hdr;

        hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd);
        if (!hdr)
                return -EMSGSIZE;

        if (devlink_nl_put_handle(msg, devlink))
                goto nla_put_failure;

        if (nla_put_u16(msg, DEVLINK_ATTR_RATE_TYPE, devlink_rate->type))
                goto nla_put_failure;

        if (devlink_rate_is_leaf(devlink_rate)) {
                if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX,
                                devlink_rate->devlink_port->index))
                        goto nla_put_failure;
        } else if (devlink_rate_is_node(devlink_rate)) {
                if (nla_put_string(msg, DEVLINK_ATTR_RATE_NODE_NAME,
                                   devlink_rate->name))
                        goto nla_put_failure;
        }

        if (devlink_nl_put_u64(msg, DEVLINK_ATTR_RATE_TX_SHARE,
                               devlink_rate->tx_share))
                goto nla_put_failure;

        if (devlink_nl_put_u64(msg, DEVLINK_ATTR_RATE_TX_MAX,
                               devlink_rate->tx_max))
                goto nla_put_failure;

        if (nla_put_u32(msg, DEVLINK_ATTR_RATE_TX_PRIORITY,
                        devlink_rate->tx_priority))
                goto nla_put_failure;

        if (nla_put_u32(msg, DEVLINK_ATTR_RATE_TX_WEIGHT,
                        devlink_rate->tx_weight))
                goto nla_put_failure;

        if (devlink_rate->parent)
                if (nla_put_string(msg, DEVLINK_ATTR_RATE_PARENT_NODE_NAME,
                                   devlink_rate->parent->name))
                        goto nla_put_failure;

        if (devlink_rate_put_tc_bws(msg, devlink_rate->tc_bw))
                goto nla_put_failure;

        genlmsg_end(msg, hdr);
        return 0;

nla_put_failure:
        genlmsg_cancel(msg, hdr);
        return -EMSGSIZE;
}

static void devlink_rate_notify(struct devlink_rate *devlink_rate,
                                enum devlink_command cmd)
{
        struct devlink *devlink = devlink_rate->devlink;
        struct sk_buff *msg;
        int err;

        WARN_ON(cmd != DEVLINK_CMD_RATE_NEW && cmd != DEVLINK_CMD_RATE_DEL);

        if (!devl_is_registered(devlink) || !devlink_nl_notify_need(devlink))
                return;

        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
        if (!msg)
                return;

        err = devlink_nl_rate_fill(msg, devlink_rate, cmd, 0, 0, 0, NULL);
        if (err) {
                nlmsg_free(msg);
                return;
        }

        devlink_nl_notify_send(devlink, msg);
}

void devlink_rates_notify_register(struct devlink *devlink)
{
        struct devlink_rate *rate_node;

        list_for_each_entry(rate_node, &devlink->rate_list, list)
                devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_NEW);
}

void devlink_rates_notify_unregister(struct devlink *devlink)
{
        struct devlink_rate *rate_node;

        list_for_each_entry_reverse(rate_node, &devlink->rate_list, list)
                devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_DEL);
}

static int
devlink_nl_rate_get_dump_one(struct sk_buff *msg, struct devlink *devlink,
                             struct netlink_callback *cb, int flags)
{
        struct devlink_nl_dump_state *state = devlink_dump_state(cb);
        struct devlink_rate *devlink_rate;
        int idx = 0;
        int err = 0;

        list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
                enum devlink_command cmd = DEVLINK_CMD_RATE_NEW;
                u32 id = NETLINK_CB(cb->skb).portid;

                if (idx < state->idx) {
                        idx++;
                        continue;
                }
                err = devlink_nl_rate_fill(msg, devlink_rate, cmd, id,
                                           cb->nlh->nlmsg_seq, flags, NULL);
                if (err) {
                        state->idx = idx;
                        break;
                }
                idx++;
        }

        return err;
}

int devlink_nl_rate_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb)
{
        return devlink_nl_dumpit(skb, cb, devlink_nl_rate_get_dump_one);
}

int devlink_nl_rate_get_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct devlink *devlink = info->user_ptr[0];
        struct devlink_rate *devlink_rate;
        struct sk_buff *msg;
        int err;

        devlink_rate = devlink_rate_get_from_info(devlink, info);
        if (IS_ERR(devlink_rate))
                return PTR_ERR(devlink_rate);

        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
        if (!msg)
                return -ENOMEM;

        err = devlink_nl_rate_fill(msg, devlink_rate, DEVLINK_CMD_RATE_NEW,
                                   info->snd_portid, info->snd_seq, 0,
                                   info->extack);
        if (err) {
                nlmsg_free(msg);
                return err;
        }

        return genlmsg_reply(msg, info);
}

static bool
devlink_rate_is_parent_node(struct devlink_rate *devlink_rate,
                            struct devlink_rate *parent)
{
        while (parent) {
                if (parent == devlink_rate)
                        return true;
                parent = parent->parent;
        }
        return false;
}

static int
devlink_nl_rate_parent_node_set(struct devlink_rate *devlink_rate,
                                struct genl_info *info,
                                struct nlattr *nla_parent)
{
        struct devlink *devlink = devlink_rate->devlink;
        const char *parent_name = nla_data(nla_parent);
        const struct devlink_ops *ops = devlink->ops;
        size_t len = strlen(parent_name);
        struct devlink_rate *parent;
        int err = -EOPNOTSUPP;

        parent = devlink_rate->parent;

        if (parent && !len) {
                if (devlink_rate_is_leaf(devlink_rate))
                        err = ops->rate_leaf_parent_set(devlink_rate, NULL,
                                                        devlink_rate->priv, NULL,
                                                        info->extack);
                else if (devlink_rate_is_node(devlink_rate))
                        err = ops->rate_node_parent_set(devlink_rate, NULL,
                                                        devlink_rate->priv, NULL,
                                                        info->extack);
                if (err)
                        return err;

                refcount_dec(&parent->refcnt);
                devlink_rate->parent = NULL;
        } else if (len) {
                parent = devlink_rate_node_get_by_name(devlink, parent_name);
                if (IS_ERR(parent))
                        return -ENODEV;

                if (parent == devlink_rate) {
                        NL_SET_ERR_MSG(info->extack, "Parent to self is not allowed");
                        return -EINVAL;
                }

                if (devlink_rate_is_node(devlink_rate) &&
                    devlink_rate_is_parent_node(devlink_rate, parent->parent)) {
                        NL_SET_ERR_MSG(info->extack, "Node is already a parent of parent node.");
                        return -EEXIST;
                }

                if (devlink_rate_is_leaf(devlink_rate))
                        err = ops->rate_leaf_parent_set(devlink_rate, parent,
                                                        devlink_rate->priv, parent->priv,
                                                        info->extack);
                else if (devlink_rate_is_node(devlink_rate))
                        err = ops->rate_node_parent_set(devlink_rate, parent,
                                                        devlink_rate->priv, parent->priv,
                                                        info->extack);
                if (err)
                        return err;

                if (devlink_rate->parent)
                        /* we're reassigning to other parent in this case */
                        refcount_dec(&devlink_rate->parent->refcnt);

                refcount_inc(&parent->refcnt);
                devlink_rate->parent = parent;
        }

        return 0;
}

static int devlink_nl_rate_tc_bw_parse(struct nlattr *parent_nest, u32 *tc_bw,
                                       unsigned long *bitmap,
                                       struct netlink_ext_ack *extack)
{
        struct nlattr *tb[DEVLINK_RATE_TC_ATTR_MAX + 1];
        u8 tc_index;
        int err;

        err = nla_parse_nested(tb, DEVLINK_RATE_TC_ATTR_MAX, parent_nest,
                               devlink_dl_rate_tc_bws_nl_policy, extack);
        if (err)
                return err;

        if (!tb[DEVLINK_RATE_TC_ATTR_INDEX]) {
                NL_SET_ERR_ATTR_MISS(extack, parent_nest,
                                     DEVLINK_RATE_TC_ATTR_INDEX);
                return -EINVAL;
        }

        tc_index = nla_get_u8(tb[DEVLINK_RATE_TC_ATTR_INDEX]);

        if (!tb[DEVLINK_RATE_TC_ATTR_BW]) {
                NL_SET_ERR_ATTR_MISS(extack, parent_nest,
                                     DEVLINK_RATE_TC_ATTR_BW);
                return -EINVAL;
        }

        if (test_and_set_bit(tc_index, bitmap)) {
                NL_SET_ERR_MSG_FMT(extack,
                                   "Duplicate traffic class index specified (%u)",
                                   tc_index);
                return -EINVAL;
        }

        tc_bw[tc_index] = nla_get_u32(tb[DEVLINK_RATE_TC_ATTR_BW]);

        return 0;
}

static int devlink_nl_rate_tc_bw_set(struct devlink_rate *devlink_rate,
                                     struct genl_info *info)
{
        DECLARE_BITMAP(bitmap, DEVLINK_RATE_TCS_MAX) = {};
        struct devlink *devlink = devlink_rate->devlink;
        const struct devlink_ops *ops = devlink->ops;
        u32 tc_bw[DEVLINK_RATE_TCS_MAX] = {};
        int rem, err = -EOPNOTSUPP, i;
        struct nlattr *attr;

        nlmsg_for_each_attr_type(attr, DEVLINK_ATTR_RATE_TC_BWS, info->nlhdr,
                                 GENL_HDRLEN, rem) {
                err = devlink_nl_rate_tc_bw_parse(attr, tc_bw, bitmap,
                                                  info->extack);
                if (err)
                        return err;
        }

        for (i = 0; i < DEVLINK_RATE_TCS_MAX; i++) {
                if (!test_bit(i, bitmap)) {
                        NL_SET_ERR_MSG_FMT(info->extack,
                                           "Bandwidth values must be specified for all %u traffic classes",
                                           DEVLINK_RATE_TCS_MAX);
                        return -EINVAL;
                }
        }

        if (devlink_rate_is_leaf(devlink_rate))
                err = ops->rate_leaf_tc_bw_set(devlink_rate, devlink_rate->priv,
                                               tc_bw, info->extack);
        else if (devlink_rate_is_node(devlink_rate))
                err = ops->rate_node_tc_bw_set(devlink_rate, devlink_rate->priv,
                                               tc_bw, info->extack);

        if (err)
                return err;

        memcpy(devlink_rate->tc_bw, tc_bw, sizeof(tc_bw));

        return 0;
}

static int devlink_nl_rate_set(struct devlink_rate *devlink_rate,
                               const struct devlink_ops *ops,
                               struct genl_info *info)
{
        struct nlattr *nla_parent, **attrs = info->attrs;
        int err = -EOPNOTSUPP;
        u32 priority;
        u32 weight;
        u64 rate;

        if (attrs[DEVLINK_ATTR_RATE_TX_SHARE]) {
                rate = nla_get_u64(attrs[DEVLINK_ATTR_RATE_TX_SHARE]);
                if (devlink_rate_is_leaf(devlink_rate))
                        err = ops->rate_leaf_tx_share_set(devlink_rate, devlink_rate->priv,
                                                          rate, info->extack);
                else if (devlink_rate_is_node(devlink_rate))
                        err = ops->rate_node_tx_share_set(devlink_rate, devlink_rate->priv,
                                                          rate, info->extack);
                if (err)
                        return err;
                devlink_rate->tx_share = rate;
        }

        if (attrs[DEVLINK_ATTR_RATE_TX_MAX]) {
                rate = nla_get_u64(attrs[DEVLINK_ATTR_RATE_TX_MAX]);
                if (devlink_rate_is_leaf(devlink_rate))
                        err = ops->rate_leaf_tx_max_set(devlink_rate, devlink_rate->priv,
                                                        rate, info->extack);
                else if (devlink_rate_is_node(devlink_rate))
                        err = ops->rate_node_tx_max_set(devlink_rate, devlink_rate->priv,
                                                        rate, info->extack);
                if (err)
                        return err;
                devlink_rate->tx_max = rate;
        }

        if (attrs[DEVLINK_ATTR_RATE_TX_PRIORITY]) {
                priority = nla_get_u32(attrs[DEVLINK_ATTR_RATE_TX_PRIORITY]);
                if (devlink_rate_is_leaf(devlink_rate))
                        err = ops->rate_leaf_tx_priority_set(devlink_rate, devlink_rate->priv,
                                                             priority, info->extack);
                else if (devlink_rate_is_node(devlink_rate))
                        err = ops->rate_node_tx_priority_set(devlink_rate, devlink_rate->priv,
                                                             priority, info->extack);

                if (err)
                        return err;
                devlink_rate->tx_priority = priority;
        }

        if (attrs[DEVLINK_ATTR_RATE_TX_WEIGHT]) {
                weight = nla_get_u32(attrs[DEVLINK_ATTR_RATE_TX_WEIGHT]);
                if (devlink_rate_is_leaf(devlink_rate))
                        err = ops->rate_leaf_tx_weight_set(devlink_rate, devlink_rate->priv,
                                                           weight, info->extack);
                else if (devlink_rate_is_node(devlink_rate))
                        err = ops->rate_node_tx_weight_set(devlink_rate, devlink_rate->priv,
                                                           weight, info->extack);

                if (err)
                        return err;
                devlink_rate->tx_weight = weight;
        }

        nla_parent = attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME];
        if (nla_parent) {
                err = devlink_nl_rate_parent_node_set(devlink_rate, info,
                                                      nla_parent);
                if (err)
                        return err;
        }

        if (attrs[DEVLINK_ATTR_RATE_TC_BWS]) {
                err = devlink_nl_rate_tc_bw_set(devlink_rate, info);
                if (err)
                        return err;
        }

        return 0;
}

static bool devlink_rate_set_ops_supported(const struct devlink_ops *ops,
                                           struct genl_info *info,
                                           enum devlink_rate_type type)
{
        struct nlattr **attrs = info->attrs;

        if (type == DEVLINK_RATE_TYPE_LEAF) {
                if (attrs[DEVLINK_ATTR_RATE_TX_SHARE] && !ops->rate_leaf_tx_share_set) {
                        NL_SET_ERR_MSG(info->extack, "TX share set isn't supported for the leafs");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TX_MAX] && !ops->rate_leaf_tx_max_set) {
                        NL_SET_ERR_MSG(info->extack, "TX max set isn't supported for the leafs");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME] &&
                    !ops->rate_leaf_parent_set) {
                        NL_SET_ERR_MSG(info->extack, "Parent set isn't supported for the leafs");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TX_PRIORITY] && !ops->rate_leaf_tx_priority_set) {
                        NL_SET_ERR_MSG_ATTR(info->extack,
                                            attrs[DEVLINK_ATTR_RATE_TX_PRIORITY],
                                            "TX priority set isn't supported for the leafs");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TX_WEIGHT] && !ops->rate_leaf_tx_weight_set) {
                        NL_SET_ERR_MSG_ATTR(info->extack,
                                            attrs[DEVLINK_ATTR_RATE_TX_WEIGHT],
                                            "TX weight set isn't supported for the leafs");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TC_BWS] &&
                    !ops->rate_leaf_tc_bw_set) {
                        NL_SET_ERR_MSG_ATTR(info->extack,
                                            attrs[DEVLINK_ATTR_RATE_TC_BWS],
                                            "TC bandwidth set isn't supported for the leafs");
                        return false;
                }
        } else if (type == DEVLINK_RATE_TYPE_NODE) {
                if (attrs[DEVLINK_ATTR_RATE_TX_SHARE] && !ops->rate_node_tx_share_set) {
                        NL_SET_ERR_MSG(info->extack, "TX share set isn't supported for the nodes");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TX_MAX] && !ops->rate_node_tx_max_set) {
                        NL_SET_ERR_MSG(info->extack, "TX max set isn't supported for the nodes");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME] &&
                    !ops->rate_node_parent_set) {
                        NL_SET_ERR_MSG(info->extack, "Parent set isn't supported for the nodes");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TX_PRIORITY] && !ops->rate_node_tx_priority_set) {
                        NL_SET_ERR_MSG_ATTR(info->extack,
                                            attrs[DEVLINK_ATTR_RATE_TX_PRIORITY],
                                            "TX priority set isn't supported for the nodes");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TX_WEIGHT] && !ops->rate_node_tx_weight_set) {
                        NL_SET_ERR_MSG_ATTR(info->extack,
                                            attrs[DEVLINK_ATTR_RATE_TX_WEIGHT],
                                            "TX weight set isn't supported for the nodes");
                        return false;
                }
                if (attrs[DEVLINK_ATTR_RATE_TC_BWS] &&
                    !ops->rate_node_tc_bw_set) {
                        NL_SET_ERR_MSG_ATTR(info->extack,
                                            attrs[DEVLINK_ATTR_RATE_TC_BWS],
                                            "TC bandwidth set isn't supported for the nodes");
                        return false;
                }
        } else {
                WARN(1, "Unknown type of rate object");
                return false;
        }

        return true;
}

int devlink_nl_rate_set_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct devlink *devlink = info->user_ptr[0];
        struct devlink_rate *devlink_rate;
        const struct devlink_ops *ops;
        int err;

        devlink_rate = devlink_rate_get_from_info(devlink, info);
        if (IS_ERR(devlink_rate))
                return PTR_ERR(devlink_rate);

        ops = devlink->ops;
        if (!ops || !devlink_rate_set_ops_supported(ops, info, devlink_rate->type))
                return -EOPNOTSUPP;

        err = devlink_nl_rate_set(devlink_rate, ops, info);

        if (!err)
                devlink_rate_notify(devlink_rate, DEVLINK_CMD_RATE_NEW);
        return err;
}

int devlink_nl_rate_new_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct devlink *devlink = info->user_ptr[0];
        struct devlink_rate *rate_node;
        const struct devlink_ops *ops;
        int err;

        ops = devlink->ops;
        if (!ops || !ops->rate_node_new || !ops->rate_node_del) {
                NL_SET_ERR_MSG(info->extack, "Rate nodes aren't supported");
                return -EOPNOTSUPP;
        }

        if (!devlink_rate_set_ops_supported(ops, info, DEVLINK_RATE_TYPE_NODE))
                return -EOPNOTSUPP;

        rate_node = devlink_rate_node_get_from_attrs(devlink, info->attrs);
        if (!IS_ERR(rate_node))
                return -EEXIST;
        else if (rate_node == ERR_PTR(-EINVAL))
                return -EINVAL;

        rate_node = kzalloc_obj(*rate_node);
        if (!rate_node)
                return -ENOMEM;

        rate_node->devlink = devlink;
        rate_node->type = DEVLINK_RATE_TYPE_NODE;
        rate_node->name = nla_strdup(info->attrs[DEVLINK_ATTR_RATE_NODE_NAME], GFP_KERNEL);
        if (!rate_node->name) {
                err = -ENOMEM;
                goto err_strdup;
        }

        err = ops->rate_node_new(rate_node, &rate_node->priv, info->extack);
        if (err)
                goto err_node_new;

        err = devlink_nl_rate_set(rate_node, ops, info);
        if (err)
                goto err_rate_set;

        refcount_set(&rate_node->refcnt, 1);
        list_add(&rate_node->list, &devlink->rate_list);
        devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_NEW);
        return 0;

err_rate_set:
        ops->rate_node_del(rate_node, rate_node->priv, info->extack);
err_node_new:
        kfree(rate_node->name);
err_strdup:
        kfree(rate_node);
        return err;
}

int devlink_nl_rate_del_doit(struct sk_buff *skb, struct genl_info *info)
{
        struct devlink *devlink = info->user_ptr[0];
        struct devlink_rate *rate_node;
        int err;

        rate_node = devlink_rate_node_get_from_info(devlink, info);
        if (IS_ERR(rate_node))
                return PTR_ERR(rate_node);

        if (refcount_read(&rate_node->refcnt) > 1) {
                NL_SET_ERR_MSG(info->extack, "Node has children. Cannot delete node.");
                return -EBUSY;
        }

        devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_DEL);
        err = devlink->ops->rate_node_del(rate_node, rate_node->priv,
                                          info->extack);
        if (rate_node->parent)
                refcount_dec(&rate_node->parent->refcnt);
        list_del(&rate_node->list);
        kfree(rate_node->name);
        kfree(rate_node);
        return err;
}

int devlink_rates_check(struct devlink *devlink,
                        bool (*rate_filter)(const struct devlink_rate *),
                        struct netlink_ext_ack *extack)
{
        struct devlink_rate *devlink_rate;

        list_for_each_entry(devlink_rate, &devlink->rate_list, list)
                if (!rate_filter || rate_filter(devlink_rate)) {
                        if (extack)
                                NL_SET_ERR_MSG(extack, "Rate node(s) exists.");
                        return -EBUSY;
                }
        return 0;
}

/**
 * devl_rate_node_create - create devlink rate node
 * @devlink: devlink instance
 * @priv: driver private data
 * @node_name: name of the resulting node
 * @parent: parent devlink_rate struct
 *
 * Create devlink rate object of type node
 */
struct devlink_rate *
devl_rate_node_create(struct devlink *devlink, void *priv, char *node_name,
                      struct devlink_rate *parent)
{
        struct devlink_rate *rate_node;

        rate_node = devlink_rate_node_get_by_name(devlink, node_name);
        if (!IS_ERR(rate_node))
                return ERR_PTR(-EEXIST);

        rate_node = kzalloc_obj(*rate_node);
        if (!rate_node)
                return ERR_PTR(-ENOMEM);

        if (parent) {
                rate_node->parent = parent;
                refcount_inc(&rate_node->parent->refcnt);
        }

        rate_node->type = DEVLINK_RATE_TYPE_NODE;
        rate_node->devlink = devlink;
        rate_node->priv = priv;

        rate_node->name = kstrdup(node_name, GFP_KERNEL);
        if (!rate_node->name) {
                kfree(rate_node);
                return ERR_PTR(-ENOMEM);
        }

        refcount_set(&rate_node->refcnt, 1);
        list_add(&rate_node->list, &devlink->rate_list);
        devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_NEW);
        return rate_node;
}
EXPORT_SYMBOL_GPL(devl_rate_node_create);

/**
 * devl_rate_leaf_create - create devlink rate leaf
 * @devlink_port: devlink port object to create rate object on
 * @priv: driver private data
 * @parent: parent devlink_rate struct
 *
 * Create devlink rate object of type leaf on provided @devlink_port.
 */
int devl_rate_leaf_create(struct devlink_port *devlink_port, void *priv,
                          struct devlink_rate *parent)
{
        struct devlink *devlink = devlink_port->devlink;
        struct devlink_rate *devlink_rate;

        devl_assert_locked(devlink_port->devlink);

        if (WARN_ON(devlink_port->devlink_rate))
                return -EBUSY;

        devlink_rate = kzalloc_obj(*devlink_rate);
        if (!devlink_rate)
                return -ENOMEM;

        if (parent) {
                devlink_rate->parent = parent;
                refcount_inc(&devlink_rate->parent->refcnt);
        }

        devlink_rate->type = DEVLINK_RATE_TYPE_LEAF;
        devlink_rate->devlink = devlink;
        devlink_rate->devlink_port = devlink_port;
        devlink_rate->priv = priv;
        list_add_tail(&devlink_rate->list, &devlink->rate_list);
        devlink_port->devlink_rate = devlink_rate;
        devlink_rate_notify(devlink_rate, DEVLINK_CMD_RATE_NEW);

        return 0;
}
EXPORT_SYMBOL_GPL(devl_rate_leaf_create);

/**
 * devl_rate_leaf_destroy - destroy devlink rate leaf
 *
 * @devlink_port: devlink port linked to the rate object
 *
 * Destroy the devlink rate object of type leaf on provided @devlink_port.
 */
void devl_rate_leaf_destroy(struct devlink_port *devlink_port)
{
        struct devlink_rate *devlink_rate = devlink_port->devlink_rate;

        devl_assert_locked(devlink_port->devlink);
        if (!devlink_rate)
                return;

        devlink_rate_notify(devlink_rate, DEVLINK_CMD_RATE_DEL);
        if (devlink_rate->parent)
                refcount_dec(&devlink_rate->parent->refcnt);
        list_del(&devlink_rate->list);
        devlink_port->devlink_rate = NULL;
        kfree(devlink_rate);
}
EXPORT_SYMBOL_GPL(devl_rate_leaf_destroy);

/**
 * devl_rate_nodes_destroy - destroy all devlink rate nodes on device
 * @devlink: devlink instance
 *
 * Unset parent for all rate objects and destroy all rate nodes
 * on specified device.
 */
void devl_rate_nodes_destroy(struct devlink *devlink)
{
        const struct devlink_ops *ops = devlink->ops;
        struct devlink_rate *devlink_rate, *tmp;

        devl_assert_locked(devlink);

        list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
                if (!devlink_rate->parent)
                        continue;

                if (devlink_rate_is_leaf(devlink_rate))
                        ops->rate_leaf_parent_set(devlink_rate, NULL, devlink_rate->priv,
                                                  NULL, NULL);
                else if (devlink_rate_is_node(devlink_rate))
                        ops->rate_node_parent_set(devlink_rate, NULL, devlink_rate->priv,
                                                  NULL, NULL);

                refcount_dec(&devlink_rate->parent->refcnt);
                devlink_rate->parent = NULL;
        }
        list_for_each_entry_safe(devlink_rate, tmp, &devlink->rate_list, list) {
                if (devlink_rate_is_node(devlink_rate)) {
                        ops->rate_node_del(devlink_rate, devlink_rate->priv, NULL);
                        list_del(&devlink_rate->list);
                        kfree(devlink_rate->name);
                        kfree(devlink_rate);
                }
        }
}
EXPORT_SYMBOL_GPL(devl_rate_nodes_destroy);