#include <linux/net.h>
#include <linux/netdevice.h>
#include <linux/udp.h>
#include "ovpnpriv.h"
#include "main.h"
#include "io.h"
#include "peer.h"
#include "socket.h"
#include "tcp.h"
#include "udp.h"
/**
 * ovpn_socket_release_kref - kref release callback of an ovpn_socket
 * @kref: the refcount member embedded in the ovpn_socket
 *
 * Invokes the detach routine matching the protocol of the underlying sock.
 * Nothing is done for protocols other than UDP and TCP, and the ovpn_socket
 * object itself is not freed here.
 */
static void ovpn_socket_release_kref(struct kref *kref)
{
	struct ovpn_socket *ovpn_sock;
	unsigned int proto;

	ovpn_sock = container_of(kref, struct ovpn_socket, refcount);
	proto = ovpn_sock->sk->sk_protocol;

	if (proto == IPPROTO_UDP) {
		ovpn_udp_socket_detach(ovpn_sock);
		return;
	}

	if (proto == IPPROTO_TCP)
		ovpn_tcp_socket_detach(ovpn_sock);
}
/**
 * ovpn_socket_put - drop one reference to an ovpn_socket
 * @peer: peer the socket belongs to (not used by this helper)
 * @sock: the ovpn_socket whose reference counter is decreased
 *
 * When the counter reaches zero, ovpn_socket_release_kref() is invoked by
 * kref_put().
 *
 * Return: true if this call dropped the last reference, false otherwise
 */
static bool ovpn_socket_put(struct ovpn_peer *peer, struct ovpn_socket *sock)
{
	bool last_ref;

	last_ref = kref_put(&sock->refcount, ovpn_socket_release_kref);

	return last_ref;
}
/**
 * ovpn_socket_release - release the socket currently linked to a peer
 * @peer: the peer whose socket reference is dropped
 *
 * Clears peer->sock, drops the peer's reference to the ovpn_socket and,
 * if that was the last reference, releases everything the ovpn_socket was
 * holding before freeing it.
 *
 * This function may sleep (it takes the sock lock and waits for an RCU
 * grace period).
 */
void ovpn_socket_release(struct ovpn_peer *peer)
{
	struct ovpn_socket *ovpn_sock;
	bool last_ref;

	might_sleep();

	/* unlink the socket from the peer; NULL means nothing is attached */
	ovpn_sock = rcu_replace_pointer(peer->sock, NULL, true);
	if (!ovpn_sock)
		return;

	/* ovpn_socket_new() looks up and holds an existing ovpn_socket under
	 * the same sock lock, so the reference is dropped with the lock held
	 */
	lock_sock(ovpn_sock->sk);
	last_ref = ovpn_socket_put(peer, ovpn_sock);
	release_sock(ovpn_sock->sk);

	/* wait for a grace period before any further cleanup */
	synchronize_rcu();

	/* other users still hold a reference: nothing else to do */
	if (!last_ref)
		return;

	/* the remaining cleanup runs with the sock lock released */
	if (ovpn_sock->sk->sk_protocol == IPPROTO_UDP) {
		/* pairs with netdev_hold() in ovpn_socket_new() */
		netdev_put(ovpn_sock->ovpn->dev, &ovpn_sock->dev_tracker);
	} else if (ovpn_sock->sk->sk_protocol == IPPROTO_TCP) {
		/* presumably waits for pending TCP work - see tcp.c */
		ovpn_tcp_socket_wait_finish(ovpn_sock);
		/* pairs with ovpn_peer_hold() in ovpn_socket_new() */
		ovpn_peer_put(ovpn_sock->peer);
	}

	/* pairs with sock_hold() in ovpn_socket_new() */
	sock_put(ovpn_sock->sk);
	/* synchronize_rcu() above already waited for a grace period */
	kfree(ovpn_sock);
}
static bool ovpn_socket_hold(struct ovpn_socket *sock)
{
return kref_get_unless_zero(&sock->refcount);
}
/**
 * ovpn_socket_attach - run the protocol specific attach routine
 * @ovpn_sock: the ovpn_socket being set up
 * @sock: the kernel socket backing @ovpn_sock
 * @peer: the peer the socket is being attached for
 *
 * Return: the value returned by the UDP or TCP attach routine, or
 *	   -EOPNOTSUPP if the socket protocol is neither UDP nor TCP
 */
static int ovpn_socket_attach(struct ovpn_socket *ovpn_sock,
			      struct socket *sock,
			      struct ovpn_peer *peer)
{
	unsigned int proto = sock->sk->sk_protocol;
	int ret = -EOPNOTSUPP;

	if (proto == IPPROTO_TCP)
		ret = ovpn_tcp_socket_attach(ovpn_sock, peer);
	else if (proto == IPPROTO_UDP)
		ret = ovpn_udp_socket_attach(ovpn_sock, sock, peer->ovpn);

	return ret;
}
/**
 * ovpn_socket_new - create or reuse the ovpn_socket wrapping a kernel socket
 * @sock: the kernel socket to be used by @peer
 * @peer: the peer that will use the socket
 *
 * The whole body runs with the sock lock held.
 *
 * For a UDP socket, an ovpn_socket already stored in sk_user_data and
 * belonging to the same ovpn instance as @peer is reused after taking an
 * extra reference. In every other case a new ovpn_socket is allocated,
 * initialised and handed to ovpn_socket_attach().
 *
 * Return: a valid ovpn_socket pointer on success, otherwise an ERR_PTR():
 *  -EBUSY if a TCP socket already has sk_user_data set, if a UDP socket has
 *   a non-zero encap_type other than UDP_ENCAP_OVPNINUDP, or if the existing
 *   ovpn_socket belongs to a different ovpn instance;
 *  -EAGAIN if a reference to the existing ovpn_socket could not be taken;
 *  -ENOMEM if the allocation failed;
 *  or the negative error returned by ovpn_socket_attach() (-EOPNOTSUPP for
 *  protocols other than UDP and TCP).
 */
struct ovpn_socket *ovpn_socket_new(struct socket *sock, struct ovpn_peer *peer)
{
struct ovpn_socket *ovpn_sock;
struct sock *sk = sock->sk;
int ret;
/* ovpn_socket_release() drops its reference under this same lock */
lock_sock(sk);
/* a TCP socket that already carries user data is refused */
if (sk->sk_protocol == IPPROTO_TCP && sk->sk_user_data) {
ovpn_sock = ERR_PTR(-EBUSY);
goto sock_release;
}
if (sk->sk_protocol == IPPROTO_UDP) {
u8 type = READ_ONCE(udp_sk(sk)->encap_type);
/* a different UDP encapsulation is already configured on this socket */
if (type && type != UDP_ENCAP_OVPNINUDP) {
ovpn_sock = ERR_PTR(-EBUSY);
goto sock_release;
}
rcu_read_lock();
ovpn_sock = rcu_dereference_sk_user_data(sk);
if (ovpn_sock) {
/* the existing ovpn_socket belongs to another ovpn instance */
if (ovpn_sock->ovpn != peer->ovpn) {
ovpn_sock = ERR_PTR(-EBUSY);
rcu_read_unlock();
goto sock_release;
}
/* same instance: reuse the object by taking one more reference;
 * a failure here means the refcount was already zero and is
 * reported with a warning
 */
if (WARN_ON(!ovpn_socket_hold(ovpn_sock))) {
ovpn_sock = ERR_PTR(-EAGAIN);
rcu_read_unlock();
goto sock_release;
}
/* reference taken: return the existing object without re-attaching */
rcu_read_unlock();
goto sock_release;
}
rcu_read_unlock();
}
/* no reusable ovpn_socket was found: build a new one */
ovpn_sock = kzalloc_obj(*ovpn_sock);
if (!ovpn_sock) {
ovpn_sock = ERR_PTR(-ENOMEM);
goto sock_release;
}
ovpn_sock->sk = sk;
kref_init(&ovpn_sock->refcount);
/* protocol specific fields are filled in before ovpn_socket_attach() runs */
if (sk->sk_protocol == IPPROTO_TCP) {
INIT_WORK(&ovpn_sock->tcp_tx_work, ovpn_tcp_tx_work);
/* a TCP ovpn_socket is linked to its peer and holds a reference to it */
ovpn_sock->peer = peer;
ovpn_peer_hold(peer);
} else if (sk->sk_protocol == IPPROTO_UDP) {
/* a UDP ovpn_socket is linked to the ovpn instance and holds a
 * tracked reference to its net device
 */
ovpn_sock->ovpn = peer->ovpn;
netdev_hold(peer->ovpn->dev, &ovpn_sock->dev_tracker,
GFP_KERNEL);
}
/* reference on sk owned by the ovpn_socket; dropped by
 * ovpn_socket_release() or by the error path below
 */
sock_hold(sk);
ret = ovpn_socket_attach(ovpn_sock, sock, peer);
if (ret < 0) {
/* attach failed: undo every reference taken above, then free */
if (sk->sk_protocol == IPPROTO_TCP)
ovpn_peer_put(peer);
else if (sk->sk_protocol == IPPROTO_UDP)
netdev_put(peer->ovpn->dev, &ovpn_sock->dev_tracker);
sock_put(sk);
kfree(ovpn_sock);
ovpn_sock = ERR_PTR(ret);
}
/* common exit: ovpn_sock is either a valid pointer or an ERR_PTR() */
sock_release:
release_sock(sk);
return ovpn_sock;
}