diff --git a/include/linux/netlink.h b/include/linux/netlink.h index 0fbecbbe8e9e..080f6ba9e73a 100644 --- a/include/linux/netlink.h +++ b/include/linux/netlink.h @@ -176,12 +176,16 @@ struct netlink_skb_parms #define NETLINK_CREDS(skb) (&NETLINK_CB((skb)).creds) +extern void netlink_table_grab(void); +extern void netlink_table_ungrab(void); + extern struct sock *netlink_kernel_create(struct net *net, int unit,unsigned int groups, void (*input)(struct sk_buff *skb), struct mutex *cb_mutex, struct module *module); extern void netlink_kernel_release(struct sock *sk); +extern int __netlink_change_ngroups(struct sock *sk, unsigned int groups); extern int netlink_change_ngroups(struct sock *sk, unsigned int groups); extern void netlink_clear_multicast_users(struct sock *sk, unsigned int group); extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err); diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index d0ff382c40ca..c5aab6a368ce 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -177,9 +177,11 @@ static void netlink_sock_destruct(struct sock *sk) * this, _but_ remember, it adds useless work on UP machines. */ -static void netlink_table_grab(void) +void netlink_table_grab(void) __acquires(nl_table_lock) { + might_sleep(); + write_lock_irq(&nl_table_lock); if (atomic_read(&nl_table_users)) { @@ -200,7 +202,7 @@ static void netlink_table_grab(void) } } -static void netlink_table_ungrab(void) +void netlink_table_ungrab(void) __releases(nl_table_lock) { write_unlock_irq(&nl_table_lock); @@ -1549,37 +1551,21 @@ static void netlink_free_old_listeners(struct rcu_head *rcu_head) kfree(lrh->ptr); } -/** - * netlink_change_ngroups - change number of multicast groups - * - * This changes the number of multicast groups that are available - * on a certain netlink family. Note that it is not possible to - * change the number of groups to below 32. Also note that it does - * not implicitly call netlink_clear_multicast_users() when the - * number of groups is reduced. - * - * @sk: The kernel netlink socket, as returned by netlink_kernel_create(). - * @groups: The new number of groups. - */ -int netlink_change_ngroups(struct sock *sk, unsigned int groups) +int __netlink_change_ngroups(struct sock *sk, unsigned int groups) { unsigned long *listeners, *old = NULL; struct listeners_rcu_head *old_rcu_head; struct netlink_table *tbl = &nl_table[sk->sk_protocol]; - int err = 0; if (groups < 32) groups = 32; - netlink_table_grab(); if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) { listeners = kzalloc(NLGRPSZ(groups) + sizeof(struct listeners_rcu_head), GFP_ATOMIC); - if (!listeners) { - err = -ENOMEM; - goto out_ungrab; - } + if (!listeners) + return -ENOMEM; old = tbl->listeners; memcpy(listeners, old, NLGRPSZ(tbl->groups)); rcu_assign_pointer(tbl->listeners, listeners); @@ -1597,8 +1583,29 @@ int netlink_change_ngroups(struct sock *sk, unsigned int groups) } tbl->groups = groups; - out_ungrab: + return 0; +} + +/** + * netlink_change_ngroups - change number of multicast groups + * + * This changes the number of multicast groups that are available + * on a certain netlink family. Note that it is not possible to + * change the number of groups to below 32. Also note that it does + * not implicitly call netlink_clear_multicast_users() when the + * number of groups is reduced. + * + * @sk: The kernel netlink socket, as returned by netlink_kernel_create(). + * @groups: The new number of groups. + */ +int netlink_change_ngroups(struct sock *sk, unsigned int groups) +{ + int err; + + netlink_table_grab(); + err = __netlink_change_ngroups(sk, groups); netlink_table_ungrab(); + return err; } diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c index 66f6ba0bab11..566941e03363 100644 --- a/net/netlink/genetlink.c +++ b/net/netlink/genetlink.c @@ -176,9 +176,10 @@ int genl_register_mc_group(struct genl_family *family, if (family->netnsok) { struct net *net; + netlink_table_grab(); rcu_read_lock(); for_each_net_rcu(net) { - err = netlink_change_ngroups(net->genl_sock, + err = __netlink_change_ngroups(net->genl_sock, mc_groups_longs * BITS_PER_LONG); if (err) { /* @@ -188,10 +189,12 @@ int genl_register_mc_group(struct genl_family *family, * increased on some sockets which is ok. */ rcu_read_unlock(); + netlink_table_ungrab(); goto out; } } rcu_read_unlock(); + netlink_table_ungrab(); } else { err = netlink_change_ngroups(init_net.genl_sock, mc_groups_longs * BITS_PER_LONG);