root/drivers/net/ethernet/qualcomm/rmnet/rmnet_config.c
// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2013-2018, The Linux Foundation. All rights reserved.
 *
 * RMNET configuration engine
 */

#include <net/sock.h>
#include <linux/module.h>
#include <linux/netlink.h>
#include <linux/netdevice.h>
#include "rmnet_config.h"
#include "rmnet_handlers.h"
#include "rmnet_vnd.h"
#include "rmnet_private.h"
#include "rmnet_map.h"

/* Local Definitions and Declarations */

/* Netlink policy for the IFLA_INFO_DATA attributes of an "rmnet" link:
 * IFLA_RMNET_MUX_ID is a u16, and IFLA_RMNET_FLAGS carries a
 * struct ifla_rmnet_flags payload (length checked via .len).
 */
static const struct nla_policy rmnet_policy[IFLA_RMNET_MAX + 1] = {
        [IFLA_RMNET_FLAGS]      = { .len = sizeof(struct ifla_rmnet_flags) },
        [IFLA_RMNET_MUX_ID]     = { .type = NLA_U16 },
};

/* Return nonzero when rmnet's rx_handler is installed on @real_dev, which is
 * how this file marks a device whose rx_handler_data is an rmnet port.
 */
static int rmnet_is_real_dev_registered(const struct net_device *real_dev)
{
        /* The handler pointer is only compared, never dereferenced, so
         * rcu_access_pointer() is sufficient here.
         */
        return rmnet_rx_handler == rcu_access_pointer(real_dev->rx_handler);
}

/* Return the rmnet port stored in @real_dev's rx_handler_data.
 * Needs rtnl lock.  The result is only an rmnet port when
 * rmnet_is_real_dev_registered(@real_dev) is true.
 */
struct rmnet_port*
rmnet_get_port_rtnl(const struct net_device *real_dev)
{
        struct rmnet_port *port = rtnl_dereference(real_dev->rx_handler_data);

        return port;
}

/* Detach rmnet from @real_dev and free its port.
 *
 * Refuses with -EINVAL while the port still has rmnet devices attached.
 * Otherwise shuts down TX aggregation, removes the rx_handler and frees the
 * port.  Needs rtnl lock (uses rmnet_get_port_rtnl()).
 */
static int rmnet_unregister_real_device(struct net_device *real_dev)
{
        struct rmnet_port *p;

        p = rmnet_get_port_rtnl(real_dev);
        if (p->nr_rmnet_devs != 0)
                return -EINVAL;

        rmnet_map_tx_aggregate_exit(p);

        /* @p is the rx_handler's data: unhook the handler before freeing. */
        netdev_rx_handler_unregister(real_dev);
        kfree(p);

        netdev_dbg(real_dev, "Removed from rmnet\n");
        return 0;
}

/* Attach rmnet to @real_dev by installing rmnet_rx_handler with a new port.
 *
 * If rmnet is already attached, succeed only when the existing port is in
 * VND mode; a port in any other mode is rejected with -EINVAL.
 * Returns -ENOMEM if the port cannot be allocated, and -EBUSY if
 * netdev_rx_handler_register() fails (its own error code is not propagated).
 */
static int rmnet_register_real_device(struct net_device *real_dev,
                                      struct netlink_ext_ack *extack)
{
        struct rmnet_port *port;
        int i;

        ASSERT_RTNL();

        if (rmnet_is_real_dev_registered(real_dev)) {
                port = rmnet_get_port_rtnl(real_dev);
                if (port->rmnet_mode == RMNET_EPMODE_VND)
                        return 0;

                NL_SET_ERR_MSG_MOD(extack, "bridge device already exists");
                return -EINVAL;
        }

        port = kzalloc_obj(*port);
        if (!port)
                return -ENOMEM;

        port->dev = real_dev;
        if (netdev_rx_handler_register(real_dev, rmnet_rx_handler, port)) {
                kfree(port);
                return -EBUSY;
        }

        /* One hash bucket per logical endpoint (MUX ID). */
        for (i = 0; i < RMNET_MAX_LOGICAL_EP; i++)
                INIT_HLIST_HEAD(&port->muxed_ep[i]);

        rmnet_map_tx_aggregate_init(port);

        netdev_dbg(real_dev, "registered with rmnet\n");
        return 0;
}

/* Dissolve the bridge that @port participates in; no-op unless @port is in
 * bridge mode.
 *
 * @port may be either end of the bridge: a port with rmnet devices is the
 * real device (its bridge_ep is the slave), a port without any is the slave
 * (its bridge_ep points back at the real device).  In both cases the real
 * device's port returns to VND mode, the slave is unlinked from the rmnet
 * device and the slave's own rmnet port is released.
 */
static void rmnet_unregister_bridge(struct rmnet_port *port)
{
        struct net_device *bridge_dev, *rmnet_dev;
        struct rmnet_port *real_port;

        if (port->rmnet_mode != RMNET_EPMODE_BRIDGE)
                return;

        rmnet_dev = port->rmnet_dev;

        if (port->nr_rmnet_devs) {
                /* @port belongs to the real device. */
                bridge_dev = port->bridge_ep;
                real_port = port;
        } else {
                /* @port belongs to the bridge (slave) device. */
                bridge_dev = port->dev;
                real_port = rmnet_get_port_rtnl(port->bridge_ep);
        }

        real_port->bridge_ep = NULL;
        real_port->rmnet_mode = RMNET_EPMODE_VND;

        netdev_upper_dev_unlink(bridge_dev, rmnet_dev);
        rmnet_unregister_real_device(bridge_dev);
}

static int rmnet_newlink(struct net_device *dev,
                         struct rtnl_newlink_params *params,
                         struct netlink_ext_ack *extack)
{
        struct net *link_net = rtnl_newlink_link_net(params);
        u32 data_format = RMNET_FLAGS_INGRESS_DEAGGREGATION;
        struct nlattr **data = params->data;
        struct nlattr **tb = params->tb;
        struct net_device *real_dev;
        int mode = RMNET_EPMODE_VND;
        struct rmnet_endpoint *ep;
        struct rmnet_port *port;
        int err = 0;
        u16 mux_id;

        if (!tb[IFLA_LINK]) {
                NL_SET_ERR_MSG_MOD(extack, "link not specified");
                return -EINVAL;
        }

        real_dev = __dev_get_by_index(link_net, nla_get_u32(tb[IFLA_LINK]));
        if (!real_dev) {
                NL_SET_ERR_MSG_MOD(extack, "link does not exist");
                return -ENODEV;
        }

        ep = kzalloc_obj(*ep);
        if (!ep)
                return -ENOMEM;

        mux_id = nla_get_u16(data[IFLA_RMNET_MUX_ID]);

        err = rmnet_register_real_device(real_dev, extack);
        if (err)
                goto err0;

        port = rmnet_get_port_rtnl(real_dev);
        err = rmnet_vnd_newlink(mux_id, dev, port, real_dev, ep, extack);
        if (err)
                goto err1;

        err = netdev_upper_dev_link(real_dev, dev, extack);
        if (err < 0)
                goto err2;

        port->rmnet_mode = mode;
        port->rmnet_dev = dev;

        hlist_add_head_rcu(&ep->hlnode, &port->muxed_ep[mux_id]);

        if (data[IFLA_RMNET_FLAGS]) {
                struct ifla_rmnet_flags *flags;

                flags = nla_data(data[IFLA_RMNET_FLAGS]);
                data_format &= ~flags->mask;
                data_format |= flags->flags & flags->mask;
        }

        netdev_dbg(dev, "data format [0x%08X]\n", data_format);
        port->data_format = data_format;

        return 0;

err2:
        unregister_netdevice(dev);
        rmnet_vnd_dellink(mux_id, port, ep);
err1:
        rmnet_unregister_real_device(real_dev);
err0:
        kfree(ep);
        return err;
}

/* rtnl_link_ops->dellink: tear down rmnet device @dev.
 *
 * Dissolves any bridge attached to the real device, removes @dev's endpoint
 * from the real device's MUX table and frees it, unlinks @dev from the real
 * device, and tries to release the real device's rmnet port.
 * rmnet_unregister_real_device() refuses while other rmnet devices remain.
 * Finally @dev is queued for unregistration on @head.
 */
static void rmnet_dellink(struct net_device *dev, struct list_head *head)
{
        struct rmnet_priv *priv = netdev_priv(dev);
        struct net_device *real_dev, *bridge_dev;
        struct rmnet_port *real_port, *bridge_port;
        struct rmnet_endpoint *ep;
        u8 mux_id = priv->mux_id;

        real_dev = priv->real_dev;

        if (!rmnet_is_real_dev_registered(real_dev))
                return;

        real_port = rmnet_get_port_rtnl(real_dev);
        bridge_dev = real_port->bridge_ep;
        if (bridge_dev) {
                bridge_port = rmnet_get_port_rtnl(bridge_dev);
                rmnet_unregister_bridge(bridge_port);
        }

        ep = rmnet_get_endpoint(real_port, mux_id);
        if (ep) {
                hlist_del_init_rcu(&ep->hlnode);
                /* The real device keeps receiving while we are here, and the
                 * endpoint table is walked by RCU readers (see
                 * rmnet_get_endpoint()).  A reader that already found @ep
                 * would otherwise see egress_dev cleared by
                 * rmnet_vnd_dellink() or touch @ep after kfree().  Wait for
                 * such readers before doing either.
                 */
                synchronize_net();
                rmnet_vnd_dellink(mux_id, real_port, ep);
                kfree(ep);
        }

        netdev_upper_dev_unlink(real_dev, dev);
        rmnet_unregister_real_device(real_dev);
        unregister_netdevice_queue(dev, head);
}

static void rmnet_force_unassociate_device(struct net_device *real_dev)
{
        struct hlist_node *tmp_ep;
        struct rmnet_endpoint *ep;
        struct rmnet_port *port;
        unsigned long bkt_ep;
        LIST_HEAD(list);

        port = rmnet_get_port_rtnl(real_dev);

        if (port->nr_rmnet_devs) {
                /* real device */
                rmnet_unregister_bridge(port);
                hash_for_each_safe(port->muxed_ep, bkt_ep, tmp_ep, ep, hlnode) {
                        unregister_netdevice_queue(ep->egress_dev, &list);
                        netdev_upper_dev_unlink(real_dev, ep->egress_dev);
                        rmnet_vnd_dellink(ep->mux_id, port, ep);
                        hlist_del_init_rcu(&ep->hlnode);
                        kfree(ep);
                }
                rmnet_unregister_real_device(real_dev);
                unregister_netdevice_many(&list);
        } else {
                rmnet_unregister_bridge(port);
        }
}

/* Netdevice notifier for devices rmnet is attached to.
 *
 * NETDEV_UNREGISTER drops every rmnet association of the device.
 * NETDEV_CHANGEMTU answers NOTIFY_BAD when rmnet_vnd_validate_real_dev_mtu()
 * rejects the device's new MTU.  All other events, and all events for devices
 * rmnet is not attached to, are ignored (NOTIFY_DONE).
 */
static int rmnet_config_notify_cb(struct notifier_block *nb,
                                  unsigned long event, void *data)
{
        struct net_device *dev = netdev_notifier_info_to_dev(data);

        if (!rmnet_is_real_dev_registered(dev))
                return NOTIFY_DONE;

        if (event == NETDEV_UNREGISTER) {
                netdev_dbg(dev, "Kernel unregister\n");
                rmnet_force_unassociate_device(dev);
        } else if (event == NETDEV_CHANGEMTU &&
                   rmnet_vnd_validate_real_dev_mtu(dev)) {
                return NOTIFY_BAD;
        }

        return NOTIFY_DONE;
}

static struct notifier_block rmnet_dev_notifier __read_mostly = {
        .notifier_call = rmnet_config_notify_cb,
};

/* rtnl_link_ops->validate: IFLA_RMNET_MUX_ID must be present (-EINVAL) and
 * index a valid endpoint slot, 0 .. RMNET_MAX_LOGICAL_EP - 1 (-ERANGE).
 */
static int rmnet_rtnl_validate(struct nlattr *tb[], struct nlattr *data[],
                               struct netlink_ext_ack *extack)
{
        if (data && data[IFLA_RMNET_MUX_ID]) {
                if (nla_get_u16(data[IFLA_RMNET_MUX_ID]) < RMNET_MAX_LOGICAL_EP)
                        return 0;

                NL_SET_ERR_MSG_MOD(extack, "invalid MUX ID");
                return -ERANGE;
        }

        NL_SET_ERR_MSG_MOD(extack, "MUX ID not specified");
        return -EINVAL;
}

static int rmnet_changelink(struct net_device *dev, struct nlattr *tb[],
                            struct nlattr *data[],
                            struct netlink_ext_ack *extack)
{
        struct rmnet_priv *priv = netdev_priv(dev);
        struct net_device *real_dev;
        struct rmnet_port *port;
        u16 mux_id;

        if (!dev)
                return -ENODEV;

        real_dev = priv->real_dev;
        if (!rmnet_is_real_dev_registered(real_dev))
                return -ENODEV;

        port = rmnet_get_port_rtnl(real_dev);

        if (data[IFLA_RMNET_MUX_ID]) {
                mux_id = nla_get_u16(data[IFLA_RMNET_MUX_ID]);

                if (mux_id != priv->mux_id) {
                        struct rmnet_endpoint *ep;

                        ep = rmnet_get_endpoint(port, priv->mux_id);
                        if (!ep)
                                return -ENODEV;

                        if (rmnet_get_endpoint(port, mux_id)) {
                                NL_SET_ERR_MSG_MOD(extack,
                                                   "MUX ID already exists");
                                return -EINVAL;
                        }

                        hlist_del_init_rcu(&ep->hlnode);
                        hlist_add_head_rcu(&ep->hlnode,
                                           &port->muxed_ep[mux_id]);

                        ep->mux_id = mux_id;
                        priv->mux_id = mux_id;
                }
        }

        if (data[IFLA_RMNET_FLAGS]) {
                struct ifla_rmnet_flags *flags;
                u32 old_data_format;

                old_data_format = port->data_format;
                flags = nla_data(data[IFLA_RMNET_FLAGS]);
                port->data_format &= ~flags->mask;
                port->data_format |= flags->flags & flags->mask;

                if (rmnet_vnd_update_dev_mtu(port, real_dev)) {
                        port->data_format = old_data_format;
                        NL_SET_ERR_MSG_MOD(extack, "Invalid MTU on real dev");
                        return -EINVAL;
                }
        }

        return 0;
}

static size_t rmnet_get_size(const struct net_device *dev)
{
        return
                /* IFLA_RMNET_MUX_ID */
                nla_total_size(2) +
                /* IFLA_RMNET_FLAGS */
                nla_total_size(sizeof(struct ifla_rmnet_flags));
}

static int rmnet_fill_info(struct sk_buff *skb, const struct net_device *dev)
{
        struct rmnet_priv *priv = netdev_priv(dev);
        struct net_device *real_dev;
        struct ifla_rmnet_flags f;
        struct rmnet_port *port;

        real_dev = priv->real_dev;

        if (nla_put_u16(skb, IFLA_RMNET_MUX_ID, priv->mux_id))
                goto nla_put_failure;

        if (rmnet_is_real_dev_registered(real_dev)) {
                port = rmnet_get_port_rtnl(real_dev);
                f.flags = port->data_format;
        } else {
                f.flags = 0;
        }

        f.mask  = ~0;

        if (nla_put(skb, IFLA_RMNET_FLAGS, sizeof(f), &f))
                goto nla_put_failure;

        return 0;

nla_put_failure:
        return -EMSGSIZE;
}

/* rtnetlink glue for the "rmnet" link kind, wiring up the callbacks defined
 * in this file.
 */
struct rtnl_link_ops rmnet_link_ops __read_mostly = {
        .kind           = "rmnet",
        .priv_size      = sizeof(struct rmnet_priv),
        .setup          = rmnet_vnd_setup,
        /* attribute parsing */
        .maxtype        = IFLA_RMNET_MAX,
        .policy         = rmnet_policy,
        .validate       = rmnet_rtnl_validate,
        /* lifecycle */
        .newlink        = rmnet_newlink,
        .changelink     = rmnet_changelink,
        .dellink        = rmnet_dellink,
        /* dumping */
        .get_size       = rmnet_get_size,
        .fill_info      = rmnet_fill_info,
};

/* Return @real_dev's rmnet port, or NULL if rmnet is not attached to it.
 * Reads rx_handler_data with rcu_dereference_bh(), so the caller is expected
 * to be inside an RCU-bh read-side section.
 */
struct rmnet_port *rmnet_get_port_rcu(struct net_device *real_dev)
{
        if (!rmnet_is_real_dev_registered(real_dev))
                return NULL;

        return rcu_dereference_bh(real_dev->rx_handler_data);
}

/* Return the endpoint registered for @mux_id on @port, or NULL if none.
 *
 * The bucket list is walked with hlist_for_each_entry_rcu().  @mux_id is a
 * u8, but muxed_ep[] is set up for indices 0 .. RMNET_MAX_LOGICAL_EP - 1
 * only.  Out-of-range IDs are therefore rejected here instead of relying on
 * every caller to have bounded them, which would otherwise index past the
 * table.
 */
struct rmnet_endpoint *rmnet_get_endpoint(struct rmnet_port *port, u8 mux_id)
{
        struct rmnet_endpoint *ep;

        if (mux_id >= RMNET_MAX_LOGICAL_EP)
                return NULL;

        hlist_for_each_entry_rcu(ep, &port->muxed_ep[mux_id], hlnode) {
                if (ep->mux_id == mux_id)
                        return ep;
        }

        return NULL;
}

/* Bridge @slave_dev with the real device underneath @rmnet_dev.
 *
 * Refused when the real device carries more than one rmnet device, when it
 * is not in VND mode (already bridged), or when @slave_dev is already
 * attached to rmnet.  On success @slave_dev gets its own rmnet port in
 * bridge mode and is linked with @rmnet_dev as its master upper device.
 * The two ports then point at each other through bridge_ep.
 *
 * Return: 0 on success or a negative errno.  Errors from
 * rmnet_register_real_device() are passed through unchanged, so e.g.
 * -ENOMEM is not reported as -EBUSY.
 */
int rmnet_add_bridge(struct net_device *rmnet_dev,
                     struct net_device *slave_dev,
                     struct netlink_ext_ack *extack)
{
        struct rmnet_priv *priv = netdev_priv(rmnet_dev);
        struct net_device *real_dev = priv->real_dev;
        struct rmnet_port *port, *slave_port;
        int err;

        port = rmnet_get_port_rtnl(real_dev);

        /* If there is more than one rmnet dev attached, it's probably being
         * used for muxing. Skip the bridging in that case
         */
        if (port->nr_rmnet_devs > 1) {
                NL_SET_ERR_MSG_MOD(extack, "more than one rmnet dev attached");
                return -EINVAL;
        }

        if (port->rmnet_mode != RMNET_EPMODE_VND) {
                NL_SET_ERR_MSG_MOD(extack, "more than one bridge dev attached");
                return -EINVAL;
        }

        if (rmnet_is_real_dev_registered(slave_dev)) {
                NL_SET_ERR_MSG_MOD(extack,
                                   "slave cannot be another rmnet dev");

                return -EBUSY;
        }

        err = rmnet_register_real_device(slave_dev, extack);
        if (err)
                return err;

        err = netdev_master_upper_dev_link(slave_dev, rmnet_dev, NULL, NULL,
                                           extack);
        if (err) {
                rmnet_unregister_real_device(slave_dev);
                return err;
        }

        slave_port = rmnet_get_port_rtnl(slave_dev);
        slave_port->rmnet_mode = RMNET_EPMODE_BRIDGE;
        slave_port->bridge_ep = real_dev;
        slave_port->rmnet_dev = rmnet_dev;

        port->rmnet_mode = RMNET_EPMODE_BRIDGE;
        port->bridge_ep = slave_dev;

        netdev_dbg(slave_dev, "registered with rmnet as slave\n");
        return 0;
}

/* Tear down the bridge that @slave_dev takes part in.
 *
 * @rmnet_dev is not consulted: the association is found through
 * @slave_dev's own rmnet port.  Always returns 0.
 */
int rmnet_del_bridge(struct net_device *rmnet_dev,
                     struct net_device *slave_dev)
{
        rmnet_unregister_bridge(rmnet_get_port_rtnl(slave_dev));

        netdev_dbg(slave_dev, "removed from rmnet as slave\n");
        return 0;
}

/* Startup/Shutdown */

/* Module init: install the netdevice notifier, then register the "rmnet"
 * link type.  If the link type cannot be registered, the notifier is
 * removed again before the error is returned.
 */
static int __init rmnet_init(void)
{
        int rc;

        rc = register_netdevice_notifier(&rmnet_dev_notifier);
        if (rc)
                return rc;

        rc = rtnl_link_register(&rmnet_link_ops);
        if (rc)
                unregister_netdevice_notifier(&rmnet_dev_notifier);

        return rc;
}

/* Module exit: undo rmnet_init() in reverse order. */
static void __exit rmnet_exit(void)
{
        rtnl_link_unregister(&rmnet_link_ops);
        unregister_netdevice_notifier(&rmnet_dev_notifier);
}

module_init(rmnet_init)
module_exit(rmnet_exit)
MODULE_ALIAS_RTNL_LINK("rmnet");
MODULE_DESCRIPTION("Qualcomm RmNet MAP driver");
MODULE_LICENSE("GPL v2");