#include <linux/skbuff.h>
#include <linux/list.h>
#include <linux/hashtable.h>
#include <net/ip6_route.h>
#include "ovpnpriv.h"
#include "bind.h"
#include "pktid.h"
#include "crypto.h"
#include "io.h"
#include "main.h"
#include "netlink.h"
#include "peer.h"
#include "socket.h"
/*
 * Drop ovpn->lock, then tear down every peer collected on @release_list.
 *
 * Socket release and the final reference drop are performed after the
 * spinlock has been released.
 *
 * The _safe iterator is required: ovpn_peer_put() may drop the last
 * reference, which ends in ovpn_peer_release() -> call_rcu() freeing the
 * peer. This loop is not inside an RCU read-side section, so the next
 * list pointer must be fetched before the peer is put.
 */
static void unlock_ovpn(struct ovpn_priv *ovpn,
			struct llist_head *release_list)
	__releases(&ovpn->lock)
{
	struct ovpn_peer *peer, *next;

	spin_unlock_bh(&ovpn->lock);

	llist_for_each_entry_safe(peer, next, release_list->first,
				  release_entry) {
		ovpn_socket_release(peer);
		ovpn_peer_put(peer);
	}
}
/*
 * Configure the keepalive timers of @peer and kick the device-wide
 * keepalive worker so the new deadlines are evaluated immediately.
 *
 * @peer: peer to configure
 * @interval: seconds of TX silence after which a keepalive is sent
 * @timeout: seconds of RX silence after which the peer is expired
 */
void ovpn_peer_keepalive_set(struct ovpn_peer *peer, u32 interval, u32 timeout)
{
	time64_t ts = ktime_get_real_seconds();

	netdev_dbg(peer->ovpn->dev,
		   "scheduling keepalive for peer %u: interval=%u timeout=%u\n",
		   peer->id, interval, timeout);

	/* TX side: restart the "last sent" clock */
	peer->keepalive_interval = interval;
	WRITE_ONCE(peer->last_sent, ts);
	peer->keepalive_xmit_exp = ts + interval;

	/* RX side: restart the "last received" clock */
	peer->keepalive_timeout = timeout;
	WRITE_ONCE(peer->last_recv, ts);
	peer->keepalive_recv_exp = ts + timeout;

	/* run the worker now (delay 0) to pick up the new deadlines */
	mod_delayed_work(system_percpu_wq, &peer->ovpn->keepalive_work, 0);
}
/*
 * Work handler: transmit one keepalive message to the peer owning @work.
 *
 * ovpn_peer_keepalive_work_single() takes a peer reference when it queues
 * this work. That reference must be dropped here once the message has been
 * handed to ovpn_xmit_special(). Otherwise every keepalive leaks a
 * reference and the peer can never be freed.
 *
 * Bottom halves are disabled around the transmit call (presumably the xmit
 * path expects to run with BH disabled).
 */
static void ovpn_peer_keepalive_send(struct work_struct *work)
{
	struct ovpn_peer *peer = container_of(work, struct ovpn_peer,
					      keepalive_work);

	local_bh_disable();
	ovpn_xmit_special(peer, ovpn_keepalive_message,
			  sizeof(ovpn_keepalive_message));
	local_bh_enable();

	/* release the reference taken when this work was queued */
	ovpn_peer_put(peer);
}
/*
 * Allocate and initialise a new peer object.
 *
 * @ovpn: the device the peer belongs to
 * @id: the peer ID
 *
 * The returned peer starts with one reference (kref_init()) and holds a
 * tracked reference on ovpn->dev. Its VPN addresses are the "any"
 * addresses and it has no transport binding yet.
 *
 * Return: the new peer, or an ERR_PTR() on failure.
 */
struct ovpn_peer *ovpn_peer_new(struct ovpn_priv *ovpn, u32 id)
{
	struct ovpn_peer *peer;
	int err;

	peer = kzalloc_obj(*peer);
	if (!peer)
		return ERR_PTR(-ENOMEM);

	peer->id = id;
	peer->ovpn = ovpn;

	/* no VPN address assigned yet */
	peer->vpn_addrs.ipv4.s_addr = htonl(INADDR_ANY);
	peer->vpn_addrs.ipv6 = in6addr_any;

	RCU_INIT_POINTER(peer->bind, NULL);
	ovpn_crypto_state_init(&peer->crypto);
	spin_lock_init(&peer->lock);
	kref_init(&peer->refcount);
	ovpn_peer_stats_init(&peer->vpn_stats);
	ovpn_peer_stats_init(&peer->link_stats);
	INIT_WORK(&peer->keepalive_work, ovpn_peer_keepalive_send);

	err = dst_cache_init(&peer->dst_cache, GFP_KERNEL);
	if (err < 0)
		goto err_free;

	netdev_hold(ovpn->dev, &peer->dev_tracker, GFP_KERNEL);
	return peer;

err_free:
	netdev_err(ovpn->dev,
		   "cannot initialize dst cache for peer %u\n",
		   peer->id);
	kfree(peer);
	return ERR_PTR(err);
}
/*
 * Replace the transport binding of @peer with one built from @ss.
 *
 * @peer: peer to update; peer->lock must be held
 * @ss: new remote endpoint
 * @local_ip: optional local address to store in the new binding, a
 *            struct in_addr or struct in6_addr matching ss->ss_family;
 *            may be NULL
 *
 * Return: 0 on success or a negative error code (the new binding is freed
 * on error).
 */
int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer,
			     const struct sockaddr_storage *ss,
			     const void *local_ip)
{
	struct ovpn_bind *bind;
	size_t ip_len;

	lockdep_assert_held(&peer->lock);

	bind = ovpn_bind_from_sockaddr(ss);
	if (IS_ERR(bind))
		return PTR_ERR(bind);

	if (!local_ip)
		goto install;

	switch (ss->ss_family) {
	case AF_INET:
		ip_len = sizeof(struct in_addr);
		break;
	case AF_INET6:
		ip_len = sizeof(struct in6_addr);
		break;
	default:
		net_dbg_ratelimited("%s: invalid family %u for remote endpoint for peer %u\n",
				    netdev_name(peer->ovpn->dev),
				    ss->ss_family, peer->id);
		kfree(bind);
		return -EINVAL;
	}
	memcpy(&bind->local, local_ip, ip_len);

install:
	ovpn_bind_reset(peer, bind);
	return 0;
}
/*
 * Bucket index for a key of @_key_len bytes at @_key in the hashtable array
 * @_tbl: jhash() with seed 0, modulo the number of buckets.
 * @_tbl is taken by address and dereferenced again, so HASH_SIZE() is
 * applied to the array type rather than to a decayed pointer.
 */
#define ovpn_get_hash_slot(_tbl, _key, _key_len) ({ \
typeof(_tbl) *__tbl2 = &(_tbl); \
jhash(_key, _key_len, 0) % HASH_SIZE(*__tbl2); \
})
/*
 * Pointer to the bucket head of @_tbl selected by ovpn_get_hash_slot().
 * The local is named __tbl1 (not __tbl2) because this macro expands
 * ovpn_get_hash_slot() with *__tbl1 as its table argument. A shared name
 * would make the inner declaration refer to itself.
 */
#define ovpn_get_hash_head(_tbl, _key, _key_len) ({ \
typeof(_tbl) *__tbl1 = &(_tbl); \
&(*__tbl1)[ovpn_get_hash_slot(*__tbl1, _key, _key_len)];\
})
void ovpn_peer_endpoints_update(struct ovpn_peer *peer, struct sk_buff *skb)
{
struct hlist_nulls_head *nhead;
struct sockaddr_storage ss;
struct sockaddr_in6 *sa6;
bool reset_cache = false;
struct sockaddr_in *sa;
struct ovpn_bind *bind;
const void *local_ip;
size_t salen = 0;
spin_lock_bh(&peer->lock);
bind = rcu_dereference_protected(peer->bind,
lockdep_is_held(&peer->lock));
if (unlikely(!bind))
goto unlock;
switch (skb->protocol) {
case htons(ETH_P_IP):
if (unlikely(!ovpn_bind_skb_src_match(bind, skb))) {
local_ip = &ip_hdr(skb)->daddr;
sa = (struct sockaddr_in *)&ss;
sa->sin_family = AF_INET;
sa->sin_addr.s_addr = ip_hdr(skb)->saddr;
sa->sin_port = udp_hdr(skb)->source;
salen = sizeof(*sa);
reset_cache = true;
break;
}
if (unlikely(bind->local.ipv4.s_addr != ip_hdr(skb)->daddr)) {
net_dbg_ratelimited("%s: learning local IPv4 for peer %d (%pI4 -> %pI4)\n",
netdev_name(peer->ovpn->dev),
peer->id, &bind->local.ipv4.s_addr,
&ip_hdr(skb)->daddr);
bind->local.ipv4.s_addr = ip_hdr(skb)->daddr;
reset_cache = true;
}
break;
case htons(ETH_P_IPV6):
if (unlikely(!ovpn_bind_skb_src_match(bind, skb))) {
local_ip = &ipv6_hdr(skb)->daddr;
sa6 = (struct sockaddr_in6 *)&ss;
sa6->sin6_family = AF_INET6;
sa6->sin6_addr = ipv6_hdr(skb)->saddr;
sa6->sin6_port = udp_hdr(skb)->source;
sa6->sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr,
skb->skb_iif);
salen = sizeof(*sa6);
reset_cache = true;
break;
}
if (unlikely(!ipv6_addr_equal(&bind->local.ipv6,
&ipv6_hdr(skb)->daddr))) {
net_dbg_ratelimited("%s: learning local IPv6 for peer %d (%pI6c -> %pI6c)\n",
netdev_name(peer->ovpn->dev),
peer->id, &bind->local.ipv6,
&ipv6_hdr(skb)->daddr);
bind->local.ipv6 = ipv6_hdr(skb)->daddr;
reset_cache = true;
}
break;
default:
goto unlock;
}
if (unlikely(reset_cache))
dst_cache_reset(&peer->dst_cache);
if (likely(!salen))
goto unlock;
if (unlikely(ovpn_peer_reset_sockaddr(peer,
(struct sockaddr_storage *)&ss,
local_ip) < 0))
goto unlock;
net_dbg_ratelimited("%s: peer %d floated to %pIScp",
netdev_name(peer->ovpn->dev), peer->id, &ss);
spin_unlock_bh(&peer->lock);
if (peer->ovpn->mode == OVPN_MODE_MP) {
spin_lock_bh(&peer->ovpn->lock);
spin_lock_bh(&peer->lock);
bind = rcu_dereference_protected(peer->bind,
lockdep_is_held(&peer->lock));
if (unlikely(!bind)) {
spin_unlock_bh(&peer->lock);
spin_unlock_bh(&peer->ovpn->lock);
return;
}
switch (bind->remote.in4.sin_family) {
case AF_INET:
salen = sizeof(*sa);
break;
case AF_INET6:
salen = sizeof(*sa6);
break;
}
hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr);
nhead = ovpn_get_hash_head(peer->ovpn->peers->by_transp_addr,
&bind->remote, salen);
hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead);
spin_unlock_bh(&peer->lock);
spin_unlock_bh(&peer->ovpn->lock);
}
return;
unlock:
spin_unlock_bh(&peer->lock);
}
/*
 * RCU callback queued by ovpn_peer_release(): destroy the dst cache and
 * free the peer memory after the grace period.
 */
static void ovpn_peer_release_rcu(struct rcu_head *head)
{
	struct ovpn_peer *p;

	p = container_of(head, struct ovpn_peer, rcu);
	dst_cache_destroy(&p->dst_cache);
	kfree(p);
}
/*
 * Release every resource held by @peer; called when its refcount hits zero.
 *
 * Crypto state and the transport binding are released immediately, then
 * the tracked reference on ovpn->dev is dropped. The peer memory (and its
 * dst cache) is freed by ovpn_peer_release_rcu() after an RCU grace period.
 *
 * netdev_put() must happen before call_rcu(). This function is not
 * necessarily running inside an RCU read-side section. Once call_rcu() is
 * queued, the grace period may elapse and the callback may kfree() the
 * peer before any later read of peer->ovpn or peer->dev_tracker.
 */
void ovpn_peer_release(struct ovpn_peer *peer)
{
	ovpn_crypto_state_release(&peer->crypto);
	spin_lock_bh(&peer->lock);
	ovpn_bind_reset(peer, NULL);
	spin_unlock_bh(&peer->lock);
	netdev_put(peer->ovpn->dev, &peer->dev_tracker);
	/* last access to @peer in this context */
	call_rcu(&peer->rcu, ovpn_peer_release_rcu);
}
/* kref release callback: map the embedded refcount back to its peer. */
void ovpn_peer_release_kref(struct kref *kref)
{
	ovpn_peer_release(container_of(kref, struct ovpn_peer, refcount));
}
/*
 * Fill @ss with the source IP address and UDP source port of @skb.
 *
 * Only family, address and port are written. The caller visible in this
 * file passes a zeroed @ss so the result can be used as a hash key.
 *
 * Return: the length of the sockaddr written, or -1 if @skb is neither
 * IPv4 nor IPv6.
 */
static int ovpn_peer_skb_to_sockaddr(struct sk_buff *skb,
				     struct sockaddr_storage *ss)
{
	if (skb->protocol == htons(ETH_P_IP)) {
		struct sockaddr_in *in4 = (struct sockaddr_in *)ss;

		in4->sin_family = AF_INET;
		in4->sin_addr.s_addr = ip_hdr(skb)->saddr;
		in4->sin_port = udp_hdr(skb)->source;
		return sizeof(*in4);
	}

	if (skb->protocol == htons(ETH_P_IPV6)) {
		struct sockaddr_in6 *in6 = (struct sockaddr_in6 *)ss;

		in6->sin6_family = AF_INET6;
		in6->sin6_addr = ipv6_hdr(skb)->saddr;
		in6->sin6_port = udp_hdr(skb)->source;
		return sizeof(*in6);
	}

	return -1;
}
/*
 * IPv4 next hop of an outgoing skb: the route's gateway when the attached
 * route uses one, otherwise the packet's destination address.
 */
static __be32 ovpn_nexthop_from_skb4(struct sk_buff *skb)
{
	const struct rtable *rt = skb_rtable(skb);

	return (rt && rt->rt_uses_gateway) ? rt->rt_gw4 : ip_hdr(skb)->daddr;
}
/*
 * IPv6 next hop of an outgoing skb: the route's gateway when the attached
 * route has RTF_GATEWAY set, otherwise the packet's destination address.
 */
static struct in6_addr ovpn_nexthop_from_skb6(struct sk_buff *skb)
{
	const struct rt6_info *rt = skb_rt6_info(skb);

	if (rt && (rt->rt6i_flags & RTF_GATEWAY))
		return rt->rt6i_gateway;

	return ipv6_hdr(skb)->daddr;
}
/*
 * Look up the peer whose VPN IPv4 address equals @addr in the MP
 * by_vpn_addr4 hashtable.
 *
 * Callers must hold rcu_read_lock() (both callers in this file do). No
 * reference is taken on the returned peer. A caller that needs to keep it
 * must ovpn_peer_hold() it inside the same RCU section.
 *
 * Return: the matching peer or NULL.
 */
static struct ovpn_peer *ovpn_peer_get_by_vpn_addr4(struct ovpn_priv *ovpn,
__be32 addr)
{
struct hlist_nulls_head *nhead;
struct hlist_nulls_node *ntmp;
struct ovpn_peer *tmp;
unsigned int slot;
begin:
slot = ovpn_get_hash_slot(ovpn->peers->by_vpn_addr4, &addr,
sizeof(addr));
nhead = &ovpn->peers->by_vpn_addr4[slot];
hlist_nulls_for_each_entry_rcu(tmp, ntmp, nhead, hash_entry_addr4)
if (addr == tmp->vpn_addrs.ipv4.s_addr)
return tmp;
/* The lockless walk ended on a nulls marker. If it does not carry the
 * bucket index we started from (assuming each bucket's nulls value was
 * initialised to its index — TODO confirm at table init), an entry was
 * moved to another bucket mid-walk (e.g. by a concurrent re-hash in
 * ovpn_peer_hash_vpn_ip()). Restart the lookup.
 */
if (get_nulls_value(ntmp) != slot)
goto begin;
return NULL;
}
/*
 * Look up the peer whose VPN IPv6 address equals *@addr in the MP
 * by_vpn_addr6 hashtable.
 *
 * Callers must hold rcu_read_lock() (both callers in this file do). No
 * reference is taken on the returned peer.
 *
 * Return: the matching peer or NULL.
 */
static struct ovpn_peer *ovpn_peer_get_by_vpn_addr6(struct ovpn_priv *ovpn,
struct in6_addr *addr)
{
struct hlist_nulls_head *nhead;
struct hlist_nulls_node *ntmp;
struct ovpn_peer *tmp;
unsigned int slot;
begin:
slot = ovpn_get_hash_slot(ovpn->peers->by_vpn_addr6, addr,
sizeof(*addr));
nhead = &ovpn->peers->by_vpn_addr6[slot];
hlist_nulls_for_each_entry_rcu(tmp, ntmp, nhead, hash_entry_addr6)
if (ipv6_addr_equal(addr, &tmp->vpn_addrs.ipv6))
return tmp;
/* The walk ended on a nulls marker from a different bucket (presumably
 * nulls values encode the bucket index): an entry moved mid-walk, so
 * restart the lookup.
 */
if (get_nulls_value(ntmp) != slot)
goto begin;
return NULL;
}
static bool ovpn_peer_transp_match(const struct ovpn_peer *peer,
const struct sockaddr_storage *ss)
{
struct ovpn_bind *bind = rcu_dereference(peer->bind);
struct sockaddr_in6 *sa6;
struct sockaddr_in *sa4;
if (unlikely(!bind))
return false;
if (ss->ss_family != bind->remote.in4.sin_family)
return false;
switch (ss->ss_family) {
case AF_INET:
sa4 = (struct sockaddr_in *)ss;
if (sa4->sin_addr.s_addr != bind->remote.in4.sin_addr.s_addr)
return false;
if (sa4->sin_port != bind->remote.in4.sin_port)
return false;
break;
case AF_INET6:
sa6 = (struct sockaddr_in6 *)ss;
if (!ipv6_addr_equal(&sa6->sin6_addr,
&bind->remote.in6.sin6_addr))
return false;
if (sa6->sin6_port != bind->remote.in6.sin6_port)
return false;
break;
default:
return false;
}
return true;
}
/*
 * P2P variant of the transport-address lookup: the single peer is returned,
 * with a reference held, only if its remote endpoint matches @ss.
 *
 * Return: the referenced peer or NULL.
 */
static struct ovpn_peer *
ovpn_peer_get_by_transp_addr_p2p(struct ovpn_priv *ovpn,
				 struct sockaddr_storage *ss)
{
	struct ovpn_peer *cand, *found = NULL;

	rcu_read_lock();
	cand = rcu_dereference(ovpn->peer);
	if (likely(cand && ovpn_peer_transp_match(cand, ss) &&
		   ovpn_peer_hold(cand)))
		found = cand;
	rcu_read_unlock();

	return found;
}
/*
 * Find the peer whose remote transport endpoint equals the source IP
 * address and UDP source port of @skb, and take a reference on it.
 *
 * Return: the peer with a reference held (drop it with ovpn_peer_put()),
 * or NULL if nothing matches or @skb is neither IPv4 nor IPv6.
 */
struct ovpn_peer *ovpn_peer_get_by_transp_addr(struct ovpn_priv *ovpn,
struct sk_buff *skb)
{
struct ovpn_peer *tmp, *peer = NULL;
/* zeroed so that bytes not written by ovpn_peer_skb_to_sockaddr()
 * (sin_zero, flowinfo, scope_id) are hashed as zero
 */
struct sockaddr_storage ss = { 0 };
struct hlist_nulls_head *nhead;
struct hlist_nulls_node *ntmp;
unsigned int slot;
ssize_t sa_len;
sa_len = ovpn_peer_skb_to_sockaddr(skb, &ss);
if (unlikely(sa_len < 0))
return NULL;
if (ovpn->mode == OVPN_MODE_P2P)
return ovpn_peer_get_by_transp_addr_p2p(ovpn, &ss);
rcu_read_lock();
begin:
slot = ovpn_get_hash_slot(ovpn->peers->by_transp_addr, &ss, sa_len);
nhead = &ovpn->peers->by_transp_addr[slot];
hlist_nulls_for_each_entry_rcu(tmp, ntmp, nhead,
hash_entry_transp_addr) {
if (!ovpn_peer_transp_match(tmp, &ss))
continue;
/* hold failure presumably means the peer is already being
 * released: keep scanning
 */
if (!ovpn_peer_hold(tmp))
continue;
peer = tmp;
break;
}
/* only when the walk ran to the end (no match) is ntmp a nulls marker;
 * a marker from a different bucket means an entry moved mid-walk, so
 * the lookup is restarted
 */
if (!peer && get_nulls_value(ntmp) != slot)
goto begin;
rcu_read_unlock();
return peer;
}
/*
 * P2P variant of the ID lookup: return the single peer, with a reference
 * held, only if its ID is @peer_id.
 *
 * Return: the referenced peer or NULL.
 */
static struct ovpn_peer *ovpn_peer_get_by_id_p2p(struct ovpn_priv *ovpn,
						 u32 peer_id)
{
	struct ovpn_peer *cand, *found = NULL;

	rcu_read_lock();
	cand = rcu_dereference(ovpn->peer);
	if (likely(cand && cand->id == peer_id && ovpn_peer_hold(cand)))
		found = cand;
	rcu_read_unlock();

	return found;
}
/*
 * Look up a peer by ID and take a reference on it.
 *
 * Peers whose reference cannot be taken are skipped.
 *
 * Return: the referenced peer (drop with ovpn_peer_put()) or NULL.
 */
struct ovpn_peer *ovpn_peer_get_by_id(struct ovpn_priv *ovpn, u32 peer_id)
{
	struct ovpn_peer *iter, *found = NULL;
	struct hlist_head *bucket;

	if (ovpn->mode == OVPN_MODE_P2P)
		return ovpn_peer_get_by_id_p2p(ovpn, peer_id);

	bucket = ovpn_get_hash_head(ovpn->peers->by_id, &peer_id,
				    sizeof(peer_id));

	rcu_read_lock();
	hlist_for_each_entry_rcu(iter, bucket, hash_entry_id) {
		if (iter->id == peer_id && ovpn_peer_hold(iter)) {
			found = iter;
			break;
		}
	}
	rcu_read_unlock();

	return found;
}
/*
 * Unlink @peer from the device's lookup structures and queue it for release.
 *
 * Caller must hold peer->ovpn->lock. If the peer is not linked (MP: already
 * unhashed from by_id; P2P: not the current ovpn->peer), the function
 * returns without doing anything, so repeated calls are no-ops. Otherwise
 * the deletion reason is recorded, userspace is notified and the peer is
 * added to @release_list. unlock_ovpn() later releases the peer's socket
 * and drops the reference owned by the table.
 */
static void ovpn_peer_remove(struct ovpn_peer *peer,
enum ovpn_del_peer_reason reason,
struct llist_head *release_list)
{
lockdep_assert_held(&peer->ovpn->lock);
switch (peer->ovpn->mode) {
case OVPN_MODE_MP:
/* unhashed by_id entry means the peer was already removed */
if (hlist_unhashed(&peer->hash_entry_id))
return;
hlist_del_init_rcu(&peer->hash_entry_id);
hlist_nulls_del_init_rcu(&peer->hash_entry_addr4);
hlist_nulls_del_init_rcu(&peer->hash_entry_addr6);
hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr);
break;
case OVPN_MODE_P2P:
/* only the currently installed peer can be removed */
if (peer != rcu_access_pointer(peer->ovpn->peer))
return;
RCU_INIT_POINTER(peer->ovpn->peer, NULL);
/* no peer left: the link has no carrier */
netif_carrier_off(peer->ovpn->dev);
break;
}
/* reason must be set before the notification that reports it */
peer->delete_reason = reason;
ovpn_nl_peer_del_notify(peer);
llist_add(&peer->release_entry, release_list);
}
/*
 * Select the peer an outgoing @skb must be sent to and take a reference.
 *
 * P2P: the single peer. MP: the peer whose VPN address equals the packet's
 * next hop (route gateway or destination address).
 *
 * Return: the referenced peer or NULL.
 */
struct ovpn_peer *ovpn_peer_get_by_dst(struct ovpn_priv *ovpn,
				       struct sk_buff *skb)
{
	struct ovpn_peer *peer = NULL;
	struct in6_addr addr6;
	__be32 addr4;

	rcu_read_lock();
	if (ovpn->mode == OVPN_MODE_P2P) {
		peer = rcu_dereference(ovpn->peer);
	} else if (skb->protocol == htons(ETH_P_IP)) {
		addr4 = ovpn_nexthop_from_skb4(skb);
		peer = ovpn_peer_get_by_vpn_addr4(ovpn, addr4);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		addr6 = ovpn_nexthop_from_skb6(skb);
		peer = ovpn_peer_get_by_vpn_addr6(ovpn, &addr6);
	}

	/* the lookup helpers take no reference: grab one before leaving RCU */
	if (unlikely(peer && !ovpn_peer_hold(peer)))
		peer = NULL;
	rcu_read_unlock();

	return peer;
}
/*
 * Resolve the IPv4 next hop towards @dest with a route lookup in the
 * device's netns.
 *
 * Return: the gateway if the route uses one, otherwise @dest (also
 * returned when there is no route).
 */
static __be32 ovpn_nexthop_from_rt4(struct ovpn_priv *ovpn, __be32 dest)
{
	struct flowi4 fl = {
		.daddr = dest
	};
	struct rtable *rt;
	__be32 hop;

	rt = ip_route_output_flow(dev_net(ovpn->dev), &fl, NULL);
	if (IS_ERR(rt)) {
		net_dbg_ratelimited("%s: no route to host %pI4\n",
				    netdev_name(ovpn->dev), &dest);
		return dest;
	}

	hop = rt->rt_uses_gateway ? rt->rt_gw4 : dest;
	ip_rt_put(rt);

	return hop;
}
/*
 * Resolve the IPv6 next hop towards @dest with a route lookup in the
 * device's netns.
 *
 * Return: the gateway if the route has RTF_GATEWAY set, otherwise @dest
 * (also returned when there is no route or IPv6 is disabled).
 */
static struct in6_addr ovpn_nexthop_from_rt6(struct ovpn_priv *ovpn,
					     struct in6_addr dest)
{
#if IS_ENABLED(CONFIG_IPV6)
	struct flowi6 fl = {
		.daddr = dest,
	};
	struct dst_entry *entry;
	struct rt6_info *rt;

	entry = ipv6_stub->ipv6_dst_lookup_flow(dev_net(ovpn->dev), NULL, &fl,
						NULL);
	if (IS_ERR(entry)) {
		net_dbg_ratelimited("%s: no route to host %pI6c\n",
				    netdev_name(ovpn->dev), &dest);
		return dest;
	}

	rt = dst_rt6_info(entry);
	if (rt->rt6i_flags & RTF_GATEWAY)
		dest = rt->rt6i_gateway;

	/* release through the dst_entry we were handed instead of casting */
	dst_release(entry);
#endif
	return dest;
}
/*
 * Check that @peer is the peer the source address of @skb routes back to.
 *
 * P2P: @peer must be the installed peer. MP: the next hop towards the
 * packet's source address must be @peer's VPN address.
 *
 * Return: true on match, false otherwise (including non-IP packets).
 */
bool ovpn_peer_check_by_src(struct ovpn_priv *ovpn, struct sk_buff *skb,
			    struct ovpn_peer *peer)
{
	struct in6_addr addr6;
	__be32 addr4;
	bool match;

	if (ovpn->mode == OVPN_MODE_P2P)
		return peer == rcu_access_pointer(ovpn->peer);

	if (skb->protocol == htons(ETH_P_IP)) {
		addr4 = ovpn_nexthop_from_rt4(ovpn, ip_hdr(skb)->saddr);
		rcu_read_lock();
		match = peer == ovpn_peer_get_by_vpn_addr4(ovpn, addr4);
		rcu_read_unlock();
		return match;
	}

	if (skb->protocol == htons(ETH_P_IPV6)) {
		addr6 = ovpn_nexthop_from_rt6(ovpn, ipv6_hdr(skb)->saddr);
		rcu_read_lock();
		match = peer == ovpn_peer_get_by_vpn_addr6(ovpn, &addr6);
		rcu_read_unlock();
		return match;
	}

	return false;
}
void ovpn_peer_hash_vpn_ip(struct ovpn_peer *peer)
{
struct hlist_nulls_head *nhead;
lockdep_assert_held(&peer->ovpn->lock);
if (peer->ovpn->mode != OVPN_MODE_MP)
return;
if (peer->vpn_addrs.ipv4.s_addr != htonl(INADDR_ANY)) {
hlist_nulls_del_init_rcu(&peer->hash_entry_addr4);
nhead = ovpn_get_hash_head(peer->ovpn->peers->by_vpn_addr4,
&peer->vpn_addrs.ipv4,
sizeof(peer->vpn_addrs.ipv4));
hlist_nulls_add_head_rcu(&peer->hash_entry_addr4, nhead);
}
if (!ipv6_addr_any(&peer->vpn_addrs.ipv6)) {
hlist_nulls_del_init_rcu(&peer->hash_entry_addr6);
nhead = ovpn_get_hash_head(peer->ovpn->peers->by_vpn_addr6,
&peer->vpn_addrs.ipv6,
sizeof(peer->vpn_addrs.ipv6));
hlist_nulls_add_head_rcu(&peer->hash_entry_addr6, nhead);
}
}
/*
 * Insert @peer into the MP hashtables: by ID, by transport address (only
 * when the peer has a binding) and by VPN address.
 *
 * Return: 0 on success, -EEXIST if a peer with the same ID is already
 * present, -EPROTONOSUPPORT if the binding's address family is neither
 * AF_INET nor AF_INET6.
 */
static int ovpn_peer_add_mp(struct ovpn_priv *ovpn, struct ovpn_peer *peer)
{
/* zeroed: only family/address/port are filled in below, so unused bytes
 * hash as zero, matching the key built by ovpn_peer_skb_to_sockaddr()
 */
struct sockaddr_storage sa = { 0 };
struct hlist_nulls_head *nhead;
struct sockaddr_in6 *sa6;
struct sockaddr_in *sa4;
struct ovpn_bind *bind;
struct ovpn_peer *tmp;
size_t salen;
int ret = 0;
spin_lock_bh(&ovpn->lock);
/* do not add duplicates */
tmp = ovpn_peer_get_by_id(ovpn, peer->id);
if (tmp) {
ovpn_peer_put(tmp);
ret = -EEXIST;
goto out;
}
/* no lockdep condition: presumably @peer is not yet published, so nobody
 * else can change its binding — TODO confirm
 */
bind = rcu_dereference_protected(peer->bind, true);
/* peers without a binding (presumably TCP peers) are not hashed by
 * transport address
 */
if (bind) {
switch (bind->remote.in4.sin_family) {
case AF_INET:
sa4 = (struct sockaddr_in *)&sa;
sa4->sin_family = AF_INET;
sa4->sin_addr.s_addr = bind->remote.in4.sin_addr.s_addr;
sa4->sin_port = bind->remote.in4.sin_port;
salen = sizeof(*sa4);
break;
case AF_INET6:
/* scope_id/flowinfo deliberately left zero: they are not part of
 * the lookup key
 */
sa6 = (struct sockaddr_in6 *)&sa;
sa6->sin6_family = AF_INET6;
sa6->sin6_addr = bind->remote.in6.sin6_addr;
sa6->sin6_port = bind->remote.in6.sin6_port;
salen = sizeof(*sa6);
break;
default:
ret = -EPROTONOSUPPORT;
goto out;
}
nhead = ovpn_get_hash_head(ovpn->peers->by_transp_addr, &sa,
salen);
hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead);
}
hlist_add_head_rcu(&peer->hash_entry_id,
ovpn_get_hash_head(ovpn->peers->by_id, &peer->id,
sizeof(peer->id)));
ovpn_peer_hash_vpn_ip(peer);
out:
spin_unlock_bh(&ovpn->lock);
return ret;
}
/*
 * Install @peer as the single peer of a P2P device, evicting the current
 * one (reason TEARDOWN) if present, and turn the carrier on.
 *
 * Return: always 0.
 */
static int ovpn_peer_add_p2p(struct ovpn_priv *ovpn, struct ovpn_peer *peer)
{
	LLIST_HEAD(release_list);
	struct ovpn_peer *old;

	spin_lock_bh(&ovpn->lock);

	old = rcu_dereference_protected(ovpn->peer,
					lockdep_is_held(&ovpn->lock));
	if (old)
		ovpn_peer_remove(old, OVPN_DEL_PEER_REASON_TEARDOWN,
				 &release_list);

	rcu_assign_pointer(ovpn->peer, peer);
	netif_carrier_on(ovpn->dev);

	/* drops ovpn->lock, then releases the evicted peer (if any) */
	unlock_ovpn(ovpn, &release_list);

	return 0;
}
/*
 * Add @peer to @ovpn according to the device mode.
 *
 * Return: the result of the mode-specific helper, or -EOPNOTSUPP for an
 * unknown mode.
 */
int ovpn_peer_add(struct ovpn_priv *ovpn, struct ovpn_peer *peer)
{
	if (ovpn->mode == OVPN_MODE_MP)
		return ovpn_peer_add_mp(ovpn, peer);
	if (ovpn->mode == OVPN_MODE_P2P)
		return ovpn_peer_add_p2p(ovpn, peer);

	return -EOPNOTSUPP;
}
static int ovpn_peer_del_mp(struct ovpn_peer *peer,
enum ovpn_del_peer_reason reason,
struct llist_head *release_list)
{
struct ovpn_peer *tmp;
int ret = -ENOENT;
lockdep_assert_held(&peer->ovpn->lock);
tmp = ovpn_peer_get_by_id(peer->ovpn, peer->id);
if (tmp == peer) {
ovpn_peer_remove(peer, reason, release_list);
ret = 0;
}
if (tmp)
ovpn_peer_put(tmp);
return ret;
}
/*
 * P2P delete: remove @peer only if it is the installed peer. Caller must
 * hold peer->ovpn->lock.
 *
 * Return: 0 if removed, -ENOENT otherwise.
 */
static int ovpn_peer_del_p2p(struct ovpn_peer *peer,
			     enum ovpn_del_peer_reason reason,
			     struct llist_head *release_list)
{
	struct ovpn_peer *cur;

	lockdep_assert_held(&peer->ovpn->lock);

	cur = rcu_dereference_protected(peer->ovpn->peer,
					lockdep_is_held(&peer->ovpn->lock));
	if (cur == peer) {
		ovpn_peer_remove(peer, reason, release_list);
		return 0;
	}

	return -ENOENT;
}
int ovpn_peer_del(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason)
{
LLIST_HEAD(release_list);
int ret = -EOPNOTSUPP;
spin_lock_bh(&peer->ovpn->lock);
switch (peer->ovpn->mode) {
case OVPN_MODE_MP:
ret = ovpn_peer_del_mp(peer, reason, &release_list);
break;
case OVPN_MODE_P2P:
ret = ovpn_peer_del_p2p(peer, reason, &release_list);
break;
default:
break;
}
unlock_ovpn(peer->ovpn, &release_list);
return ret;
}
static void ovpn_peer_release_p2p(struct ovpn_priv *ovpn, struct sock *sk,
enum ovpn_del_peer_reason reason)
{
struct ovpn_socket *ovpn_sock;
LLIST_HEAD(release_list);
struct ovpn_peer *peer;
spin_lock_bh(&ovpn->lock);
peer = rcu_dereference_protected(ovpn->peer,
lockdep_is_held(&ovpn->lock));
if (!peer) {
spin_unlock_bh(&ovpn->lock);
return;
}
if (sk) {
ovpn_sock = rcu_access_pointer(peer->sock);
if (!ovpn_sock || ovpn_sock->sk != sk) {
spin_unlock_bh(&ovpn->lock);
ovpn_peer_put(peer);
return;
}
}
ovpn_peer_remove(peer, reason, &release_list);
unlock_ovpn(ovpn, &release_list);
}
/*
 * Remove MP peers: all of them when @sk is NULL, otherwise only those
 * whose ovpn_socket wraps @sk.
 */
static void ovpn_peers_release_mp(struct ovpn_priv *ovpn, struct sock *sk,
				  enum ovpn_del_peer_reason reason)
{
	struct ovpn_socket *ovpn_sock;
	LLIST_HEAD(release_list);
	struct ovpn_peer *peer;
	struct hlist_node *next;
	int bkt;

	spin_lock_bh(&ovpn->lock);
	/* _safe: ovpn_peer_remove() unlinks the current entry */
	hash_for_each_safe(ovpn->peers->by_id, bkt, next, peer, hash_entry_id) {
		if (sk) {
			bool uses_sk;

			rcu_read_lock();
			ovpn_sock = rcu_dereference(peer->sock);
			uses_sk = ovpn_sock && ovpn_sock->sk == sk;
			rcu_read_unlock();

			if (!uses_sk)
				continue;
		}

		ovpn_peer_remove(peer, reason, &release_list);
	}
	unlock_ovpn(ovpn, &release_list);
}
/*
 * Remove peers from @ovpn according to the device mode; with a non-NULL
 * @sk, only peers using that socket are removed.
 */
void ovpn_peers_free(struct ovpn_priv *ovpn, struct sock *sk,
		     enum ovpn_del_peer_reason reason)
{
	if (ovpn->mode == OVPN_MODE_P2P)
		ovpn_peer_release_p2p(ovpn, sk, reason);
	else if (ovpn->mode == OVPN_MODE_MP)
		ovpn_peers_release_mp(ovpn, sk, reason);
}
/*
 * Evaluate the keepalive timers of a single peer.
 *
 * @peer: peer to check
 * @now: current wall-clock time, in seconds
 * @release_list: list on which an expired peer is queued for release
 *
 * Caller must hold peer->ovpn->lock (ovpn_peer_remove() asserts it).
 * If nothing was received for keepalive_timeout seconds, the peer is
 * removed. If nothing was sent for keepalive_interval seconds,
 * peer->keepalive_work is queued to transmit a keepalive.
 *
 * Return: the earliest time (seconds) at which this peer must be checked
 * again, or 0 if keepalive is not configured or the peer expired.
 */
static time64_t ovpn_peer_keepalive_work_single(struct ovpn_peer *peer,
						time64_t now,
						struct llist_head *release_list)
{
	time64_t last_recv, last_sent, next_run1, next_run2;
	unsigned long timeout, interval;
	bool expired;

	spin_lock_bh(&peer->lock);
	/* both timers are configured together: bail out if either is unset */
	if (!peer->keepalive_timeout || !peer->keepalive_interval) {
		spin_unlock_bh(&peer->lock);
		return 0;
	}

	/* RX side: has the peer been silent for too long? */
	expired = false;
	timeout = peer->keepalive_timeout;
	last_recv = READ_ONCE(peer->last_recv);
	if (now < last_recv + timeout) {
		peer->keepalive_recv_exp = last_recv + timeout;
		next_run1 = peer->keepalive_recv_exp;
	} else if (peer->keepalive_recv_exp > now) {
		next_run1 = peer->keepalive_recv_exp;
	} else {
		expired = true;
	}

	if (expired) {
		/* peer is dead: unlink it and let the caller release it */
		spin_unlock_bh(&peer->lock);
		netdev_dbg(peer->ovpn->dev, "peer %u expired\n",
			   peer->id);
		ovpn_peer_remove(peer, OVPN_DEL_PEER_REASON_EXPIRED,
				 release_list);
		return 0;
	}

	/* TX side: is a keepalive due? */
	expired = false;
	interval = peer->keepalive_interval;
	last_sent = READ_ONCE(peer->last_sent);
	if (now < last_sent + interval) {
		peer->keepalive_xmit_exp = last_sent + interval;
		next_run2 = peer->keepalive_xmit_exp;
	} else if (peer->keepalive_xmit_exp > now) {
		next_run2 = peer->keepalive_xmit_exp;
	} else {
		expired = true;
		next_run2 = now + interval;
	}
	spin_unlock_bh(&peer->lock);

	if (expired) {
		netdev_dbg(peer->ovpn->dev,
			   "sending keepalive to peer %u\n",
			   peer->id);
		/* The queued work owns a peer reference, meant to be dropped
		 * by the work handler. Take it *before* queueing, so the
		 * handler can never release a reference that has not been
		 * taken yet, and do not queue if no reference can be taken.
		 * If the work was already pending, it already owns one: give
		 * back the reference just taken. That put cannot be the last
		 * one, because the peer is still linked in the device, which
		 * owns a reference.
		 */
		if (ovpn_peer_hold(peer) &&
		    !schedule_work(&peer->keepalive_work))
			ovpn_peer_put(peer);
	}

	if (next_run1 < next_run2)
		return next_run1;

	return next_run2;
}
/*
 * Run keepalive processing on every MP peer. Caller holds ovpn->lock.
 *
 * Return: the earliest non-zero deadline reported by any peer, or 0 if
 * none reported one.
 */
static time64_t ovpn_peer_keepalive_work_mp(struct ovpn_priv *ovpn,
					    time64_t now,
					    struct llist_head *release_list)
{
	time64_t earliest = 0, candidate;
	struct ovpn_peer *peer;
	struct hlist_node *next;
	int bkt;

	lockdep_assert_held(&ovpn->lock);

	/* _safe: an expired peer is unlinked while iterating */
	hash_for_each_safe(ovpn->peers->by_id, bkt, next, peer, hash_entry_id) {
		candidate = ovpn_peer_keepalive_work_single(peer, now,
							    release_list);
		/* 0 means "nothing to schedule" for this peer */
		if (candidate && (!earliest || candidate < earliest))
			earliest = candidate;
	}

	return earliest;
}
/*
 * Run keepalive processing on the single P2P peer, if any. Caller holds
 * ovpn->lock.
 *
 * Return: the peer's next deadline, or 0 if there is no peer or nothing to
 * schedule.
 */
static time64_t ovpn_peer_keepalive_work_p2p(struct ovpn_priv *ovpn,
					     time64_t now,
					     struct llist_head *release_list)
{
	struct ovpn_peer *peer;

	lockdep_assert_held(&ovpn->lock);

	peer = rcu_dereference_protected(ovpn->peer,
					 lockdep_is_held(&ovpn->lock));
	if (!peer)
		return 0;

	return ovpn_peer_keepalive_work_single(peer, now, release_list);
}
/*
 * Delayed-work handler: run keepalive processing for the device's peers,
 * release peers that expired, and re-arm for the earliest next deadline.
 */
void ovpn_peer_keepalive_work(struct work_struct *work)
{
	struct ovpn_priv *ovpn = container_of(work, struct ovpn_priv,
					      keepalive_work.work);
	time64_t next_run = 0, now = ktime_get_real_seconds();
	LLIST_HEAD(release_list);

	spin_lock_bh(&ovpn->lock);
	switch (ovpn->mode) {
	case OVPN_MODE_MP:
		next_run = ovpn_peer_keepalive_work_mp(ovpn, now,
						       &release_list);
		break;
	case OVPN_MODE_P2P:
		next_run = ovpn_peer_keepalive_work_p2p(ovpn, now,
							&release_list);
		break;
	}

	/* 0 means there is no deadline to wait for: do not re-arm */
	if (next_run > 0) {
		/* time64_t is signed: print with %lld, arguments in the same
		 * order as the labels
		 */
		netdev_dbg(ovpn->dev,
			   "scheduling keepalive work: now=%lld next_run=%lld delta=%lld\n",
			   now, next_run, next_run - now);
		schedule_delayed_work(&ovpn->keepalive_work,
				      (next_run - now) * HZ);
	}

	/* drops ovpn->lock, then releases the peers that expired above */
	unlock_ovpn(ovpn, &release_list);
}