UPSTREAM: wireguard: device: avoid circular netns references

Before, we took a reference to the creating netns if the new netns was
different. This caused issues with circular references, with two
wireguard interfaces swapping namespaces. The solution is to rather not
take any extra references at all, but instead simply invalidate the
creating netns pointer when that netns is deleted.

In order to prevent this from happening again, this commit improves the
rough object leak tracking by allowing it to account for created and
destroyed interfaces, aside from just peers and keys. That then makes it
possible to check for the object leak when having two interfaces take a
reference to each others' namespaces.

Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
(cherry picked from commit 900575aa33a3eaaef802b31de187a85c4a4b4bd0)
Bug: 152722841
[Jason: netlink notifier uses exit instead of pre_exit]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@google.com>
Change-Id: Iea52fe3ca0e41318c392d9e91edb1856de6c9528
This commit is contained in:
Jason A. Donenfeld 2020-06-23 03:59:45 -06:00 committed by Greg Kroah-Hartman
parent ccd1d7a910
commit 33fa89ac6b
5 changed files with 67 additions and 46 deletions

View file

@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev)
if (dev_v6) if (dev_v6)
dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE; dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
mutex_lock(&wg->device_update_lock);
ret = wg_socket_init(wg, wg->incoming_port); ret = wg_socket_init(wg, wg->incoming_port);
if (ret < 0) if (ret < 0)
return ret; goto out;
mutex_lock(&wg->device_update_lock);
list_for_each_entry(peer, &wg->peer_list, peer_list) { list_for_each_entry(peer, &wg->peer_list, peer_list) {
wg_packet_send_staged_packets(peer); wg_packet_send_staged_packets(peer);
if (peer->persistent_keepalive_interval) if (peer->persistent_keepalive_interval)
wg_packet_send_keepalive(peer); wg_packet_send_keepalive(peer);
} }
out:
mutex_unlock(&wg->device_update_lock); mutex_unlock(&wg->device_update_lock);
return 0; return ret;
} }
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev)
list_del(&wg->device_list); list_del(&wg->device_list);
rtnl_unlock(); rtnl_unlock();
mutex_lock(&wg->device_update_lock); mutex_lock(&wg->device_update_lock);
rcu_assign_pointer(wg->creating_net, NULL);
wg->incoming_port = 0; wg->incoming_port = 0;
wg_socket_reinit(wg, NULL, NULL); wg_socket_reinit(wg, NULL, NULL);
/* The final references are cleared in the below calls to destroy_workqueue. */ /* The final references are cleared in the below calls to destroy_workqueue. */
@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev)
skb_queue_purge(&wg->incoming_handshakes); skb_queue_purge(&wg->incoming_handshakes);
free_percpu(dev->tstats); free_percpu(dev->tstats);
free_percpu(wg->incoming_handshakes_worker); free_percpu(wg->incoming_handshakes_worker);
if (wg->have_creating_net_ref)
put_net(wg->creating_net);
kvfree(wg->index_hashtable); kvfree(wg->index_hashtable);
kvfree(wg->peer_hashtable); kvfree(wg->peer_hashtable);
mutex_unlock(&wg->device_update_lock); mutex_unlock(&wg->device_update_lock);
pr_debug("%s: Interface deleted\n", dev->name); pr_debug("%s: Interface destroyed\n", dev->name);
free_netdev(dev); free_netdev(dev);
} }
@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
struct wg_device *wg = netdev_priv(dev); struct wg_device *wg = netdev_priv(dev);
int ret = -ENOMEM; int ret = -ENOMEM;
wg->creating_net = src_net; rcu_assign_pointer(wg->creating_net, src_net);
init_rwsem(&wg->static_identity.lock); init_rwsem(&wg->static_identity.lock);
mutex_init(&wg->socket_update_lock); mutex_init(&wg->socket_update_lock);
mutex_init(&wg->device_update_lock); mutex_init(&wg->device_update_lock);
@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __read_mostly = {
.newlink = wg_newlink, .newlink = wg_newlink,
}; };
static int wg_netdevice_notification(struct notifier_block *nb, static void wg_netns_exit(struct net *net)
unsigned long action, void *data)
{ {
struct net_device *dev = ((struct netdev_notifier_info *)data)->dev; struct wg_device *wg;
struct wg_device *wg = netdev_priv(dev);
ASSERT_RTNL(); rtnl_lock();
list_for_each_entry(wg, &device_list, device_list) {
if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops) if (rcu_access_pointer(wg->creating_net) == net) {
return 0; pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
netif_carrier_off(wg->dev);
if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) { mutex_lock(&wg->device_update_lock);
put_net(wg->creating_net); rcu_assign_pointer(wg->creating_net, NULL);
wg->have_creating_net_ref = false; wg_socket_reinit(wg, NULL, NULL);
} else if (dev_net(dev) != wg->creating_net && mutex_unlock(&wg->device_update_lock);
!wg->have_creating_net_ref) {
wg->have_creating_net_ref = true;
get_net(wg->creating_net);
} }
return 0; }
rtnl_unlock();
} }
static struct notifier_block netdevice_notifier = { static struct pernet_operations pernet_ops = {
.notifier_call = wg_netdevice_notification .exit = wg_netns_exit
}; };
int __init wg_device_init(void) int __init wg_device_init(void)
@ -429,18 +425,18 @@ int __init wg_device_init(void)
return ret; return ret;
#endif #endif
ret = register_netdevice_notifier(&netdevice_notifier); ret = register_pernet_device(&pernet_ops);
if (ret) if (ret)
goto error_pm; goto error_pm;
ret = rtnl_link_register(&link_ops); ret = rtnl_link_register(&link_ops);
if (ret) if (ret)
goto error_netdevice; goto error_pernet;
return 0; return 0;
error_netdevice: error_pernet:
unregister_netdevice_notifier(&netdevice_notifier); unregister_pernet_device(&pernet_ops);
error_pm: error_pm:
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier); unregister_pm_notifier(&pm_notifier);
@ -451,7 +447,7 @@ int __init wg_device_init(void)
void wg_device_uninit(void) void wg_device_uninit(void)
{ {
rtnl_link_unregister(&link_ops); rtnl_link_unregister(&link_ops);
unregister_netdevice_notifier(&netdevice_notifier); unregister_pernet_device(&pernet_ops);
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
unregister_pm_notifier(&pm_notifier); unregister_pm_notifier(&pm_notifier);
#endif #endif

View file

@ -40,7 +40,7 @@ struct wg_device {
struct net_device *dev; struct net_device *dev;
struct crypt_queue encrypt_queue, decrypt_queue; struct crypt_queue encrypt_queue, decrypt_queue;
struct sock __rcu *sock4, *sock6; struct sock __rcu *sock4, *sock6;
struct net *creating_net; struct net __rcu *creating_net;
struct noise_static_identity static_identity; struct noise_static_identity static_identity;
struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
struct workqueue_struct *packet_crypt_wq; struct workqueue_struct *packet_crypt_wq;
@ -56,7 +56,6 @@ struct wg_device {
unsigned int num_peers, device_update_gen; unsigned int num_peers, device_update_gen;
u32 fwmark; u32 fwmark;
u16 incoming_port; u16 incoming_port;
bool have_creating_net_ref;
}; };
int wg_device_init(void); int wg_device_init(void);

View file

@ -521,11 +521,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
if (flags & ~__WGDEVICE_F_ALL) if (flags & ~__WGDEVICE_F_ALL)
goto out; goto out;
ret = -EPERM; if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
if ((info->attrs[WGDEVICE_A_LISTEN_PORT] || struct net *net;
info->attrs[WGDEVICE_A_FWMARK]) && rcu_read_lock();
!ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN)) net = rcu_dereference(wg->creating_net);
ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
rcu_read_unlock();
if (ret)
goto out; goto out;
}
++wg->device_update_gen; ++wg->device_update_gen;

View file

@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock)
int wg_socket_init(struct wg_device *wg, u16 port) int wg_socket_init(struct wg_device *wg, u16 port)
{ {
struct net *net;
int ret; int ret;
struct udp_tunnel_sock_cfg cfg = { struct udp_tunnel_sock_cfg cfg = {
.sk_user_data = wg, .sk_user_data = wg,
@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port)
}; };
#endif #endif
rcu_read_lock();
net = rcu_dereference(wg->creating_net);
net = net ? maybe_get_net(net) : NULL;
rcu_read_unlock();
if (unlikely(!net))
return -ENONET;
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
retry: retry:
#endif #endif
ret = udp_sock_create(wg->creating_net, &port4, &new4); ret = udp_sock_create(net, &port4, &new4);
if (ret < 0) { if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name); pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
return ret; goto out;
} }
set_sock_opts(new4); set_sock_opts(new4);
setup_udp_tunnel_sock(wg->creating_net, new4, &cfg); setup_udp_tunnel_sock(net, new4, &cfg);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
if (ipv6_mod_enabled()) { if (ipv6_mod_enabled()) {
port6.local_udp_port = inet_sk(new4->sk)->inet_sport; port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
ret = udp_sock_create(wg->creating_net, &port6, &new6); ret = udp_sock_create(net, &port6, &new6);
if (ret < 0) { if (ret < 0) {
udp_tunnel_sock_release(new4); udp_tunnel_sock_release(new4);
if (ret == -EADDRINUSE && !port && retries++ < 100) if (ret == -EADDRINUSE && !port && retries++ < 100)
goto retry; goto retry;
pr_err("%s: Could not create IPv6 socket\n", pr_err("%s: Could not create IPv6 socket\n",
wg->dev->name); wg->dev->name);
return ret; goto out;
} }
set_sock_opts(new6); set_sock_opts(new6);
setup_udp_tunnel_sock(wg->creating_net, new6, &cfg); setup_udp_tunnel_sock(net, new6, &cfg);
} }
#endif #endif
wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL); wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
return 0; ret = 0;
out:
put_net(net);
return ret;
} }
void wg_socket_reinit(struct wg_device *wg, struct sock *new4, void wg_socket_reinit(struct wg_device *wg, struct sock *new4,

View file

@ -587,9 +587,20 @@ ip0 link set wg0 up
kill $ncat_pid kill $ncat_pid
ip0 link del wg0 ip0 link del wg0
# Ensure there aren't circular reference loops
ip1 link add wg1 type wireguard
ip2 link add wg2 type wireguard
ip1 link set wg1 netns $netns2
ip2 link set wg2 netns $netns1
pp ip netns delete $netns1
pp ip netns delete $netns2
pp ip netns add $netns1
pp ip netns add $netns2
sleep 2 # Wait for cleanup and grace periods
declare -A objects declare -A objects
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
done < /dev/kmsg done < /dev/kmsg
alldeleted=1 alldeleted=1