sock: struct proto hash function may error

In order to support fast reuseport lookups in TCP, the hash function
defined in struct proto must be capable of returning an error code.
This patch changes the function signature of all related hash functions
to return an integer and handles or propagates this return value at
all call sites.

Signed-off-by: Craig Gallek <kraig@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
Craig Gallek 2016-02-10 11:50:35 -05:00 committed by David S. Miller
parent 30c1de08dd
commit 086c653f58
15 changed files with 53 additions and 25 deletions

View file

@ -208,7 +208,7 @@ void inet_hashinfo_init(struct inet_hashinfo *h);
bool inet_ehash_insert(struct sock *sk, struct sock *osk); bool inet_ehash_insert(struct sock *sk, struct sock *osk);
bool inet_ehash_nolisten(struct sock *sk, struct sock *osk); bool inet_ehash_nolisten(struct sock *sk, struct sock *osk);
void __inet_hash(struct sock *sk, struct sock *osk); void __inet_hash(struct sock *sk, struct sock *osk);
void inet_hash(struct sock *sk); int inet_hash(struct sock *sk);
void inet_unhash(struct sock *sk); void inet_unhash(struct sock *sk);
struct sock *__inet_lookup_listener(struct net *net, struct sock *__inet_lookup_listener(struct net *net,

View file

@ -51,7 +51,7 @@ void pn_sock_init(void);
struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *sa); struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *sa);
void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb); void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb);
void phonet_get_local_port_range(int *min, int *max); void phonet_get_local_port_range(int *min, int *max);
void pn_sock_hash(struct sock *sk); int pn_sock_hash(struct sock *sk);
void pn_sock_unhash(struct sock *sk); void pn_sock_unhash(struct sock *sk);
int pn_sock_get_port(struct sock *sk, unsigned short sport); int pn_sock_get_port(struct sock *sk, unsigned short sport);

View file

@ -65,7 +65,7 @@ struct pingfakehdr {
}; };
int ping_get_port(struct sock *sk, unsigned short ident); int ping_get_port(struct sock *sk, unsigned short ident);
void ping_hash(struct sock *sk); int ping_hash(struct sock *sk);
void ping_unhash(struct sock *sk); void ping_unhash(struct sock *sk);
int ping_init_sock(struct sock *sk); int ping_init_sock(struct sock *sk);

View file

@ -57,7 +57,7 @@ int raw_seq_open(struct inode *ino, struct file *file,
#endif #endif
void raw_hash_sk(struct sock *sk); int raw_hash_sk(struct sock *sk);
void raw_unhash_sk(struct sock *sk); void raw_unhash_sk(struct sock *sk);
struct raw_sock { struct raw_sock {

View file

@ -984,7 +984,7 @@ struct proto {
void (*release_cb)(struct sock *sk); void (*release_cb)(struct sock *sk);
/* Keeping track of sk's, looking them up, and port selection methods. */ /* Keeping track of sk's, looking them up, and port selection methods. */
void (*hash)(struct sock *sk); int (*hash)(struct sock *sk);
void (*unhash)(struct sock *sk); void (*unhash)(struct sock *sk);
void (*rehash)(struct sock *sk); void (*rehash)(struct sock *sk);
int (*get_port)(struct sock *sk, unsigned short snum); int (*get_port)(struct sock *sk, unsigned short snum);
@ -1194,10 +1194,10 @@ static inline void sock_prot_inuse_add(struct net *net, struct proto *prot,
/* With per-bucket locks this operation is not-atomic, so that /* With per-bucket locks this operation is not-atomic, so that
* this version is not worse. * this version is not worse.
*/ */
static inline void __sk_prot_rehash(struct sock *sk) static inline int __sk_prot_rehash(struct sock *sk)
{ {
sk->sk_prot->unhash(sk); sk->sk_prot->unhash(sk);
sk->sk_prot->hash(sk); return sk->sk_prot->hash(sk);
} }
void sk_prot_clear_portaddr_nulls(struct sock *sk, int size); void sk_prot_clear_portaddr_nulls(struct sock *sk, int size);

View file

@ -177,9 +177,10 @@ static inline struct udphdr *udp_gro_udphdr(struct sk_buff *skb)
} }
/* hash routines shared between UDPv4/6 and UDP-Litev4/6 */ /* hash routines shared between UDPv4/6 and UDP-Litev4/6 */
static inline void udp_lib_hash(struct sock *sk) static inline int udp_lib_hash(struct sock *sk)
{ {
BUG(); BUG();
return 0;
} }
void udp_lib_unhash(struct sock *sk); void udp_lib_unhash(struct sock *sk);

View file

@ -182,12 +182,14 @@ static int ieee802154_sock_ioctl(struct socket *sock, unsigned int cmd,
static HLIST_HEAD(raw_head); static HLIST_HEAD(raw_head);
static DEFINE_RWLOCK(raw_lock); static DEFINE_RWLOCK(raw_lock);
static void raw_hash(struct sock *sk) static int raw_hash(struct sock *sk)
{ {
write_lock_bh(&raw_lock); write_lock_bh(&raw_lock);
sk_add_node(sk, &raw_head); sk_add_node(sk, &raw_head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
write_unlock_bh(&raw_lock); write_unlock_bh(&raw_lock);
return 0;
} }
static void raw_unhash(struct sock *sk) static void raw_unhash(struct sock *sk)
@ -462,12 +464,14 @@ static inline struct dgram_sock *dgram_sk(const struct sock *sk)
return container_of(sk, struct dgram_sock, sk); return container_of(sk, struct dgram_sock, sk);
} }
static void dgram_hash(struct sock *sk) static int dgram_hash(struct sock *sk)
{ {
write_lock_bh(&dgram_lock); write_lock_bh(&dgram_lock);
sk_add_node(sk, &dgram_head); sk_add_node(sk, &dgram_head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
write_unlock_bh(&dgram_lock); write_unlock_bh(&dgram_lock);
return 0;
} }
static void dgram_unhash(struct sock *sk) static void dgram_unhash(struct sock *sk)
@ -1026,8 +1030,13 @@ static int ieee802154_create(struct net *net, struct socket *sock,
/* Checksums on by default */ /* Checksums on by default */
sock_set_flag(sk, SOCK_ZAPPED); sock_set_flag(sk, SOCK_ZAPPED);
if (sk->sk_prot->hash) if (sk->sk_prot->hash) {
sk->sk_prot->hash(sk); rc = sk->sk_prot->hash(sk);
if (rc) {
sk_common_release(sk);
goto out;
}
}
if (sk->sk_prot->init) { if (sk->sk_prot->init) {
rc = sk->sk_prot->init(sk); rc = sk->sk_prot->init(sk);

View file

@ -370,7 +370,11 @@ static int inet_create(struct net *net, struct socket *sock, int protocol,
*/ */
inet->inet_sport = htons(inet->inet_num); inet->inet_sport = htons(inet->inet_num);
/* Add to protocol hash chains. */ /* Add to protocol hash chains. */
sk->sk_prot->hash(sk); err = sk->sk_prot->hash(sk);
if (err) {
sk_common_release(sk);
goto out;
}
} }
if (sk->sk_prot->init) { if (sk->sk_prot->init) {
@ -1142,8 +1146,7 @@ static int inet_sk_reselect_saddr(struct sock *sk)
* Besides that, it does not check for connection * Besides that, it does not check for connection
* uniqueness. Wait for troubles. * uniqueness. Wait for troubles.
*/ */
__sk_prot_rehash(sk); return __sk_prot_rehash(sk);
return 0;
} }
int inet_sk_rebuild_header(struct sock *sk) int inet_sk_rebuild_header(struct sock *sk)

View file

@ -734,6 +734,7 @@ int inet_csk_listen_start(struct sock *sk, int backlog)
{ {
struct inet_connection_sock *icsk = inet_csk(sk); struct inet_connection_sock *icsk = inet_csk(sk);
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
int err = -EADDRINUSE;
reqsk_queue_alloc(&icsk->icsk_accept_queue); reqsk_queue_alloc(&icsk->icsk_accept_queue);
@ -751,13 +752,14 @@ int inet_csk_listen_start(struct sock *sk, int backlog)
inet->inet_sport = htons(inet->inet_num); inet->inet_sport = htons(inet->inet_num);
sk_dst_reset(sk); sk_dst_reset(sk);
sk->sk_prot->hash(sk); err = sk->sk_prot->hash(sk);
if (likely(!err))
return 0; return 0;
} }
sk->sk_state = TCP_CLOSE; sk->sk_state = TCP_CLOSE;
return -EADDRINUSE; return err;
} }
EXPORT_SYMBOL_GPL(inet_csk_listen_start); EXPORT_SYMBOL_GPL(inet_csk_listen_start);

View file

@ -468,13 +468,15 @@ void __inet_hash(struct sock *sk, struct sock *osk)
} }
EXPORT_SYMBOL(__inet_hash); EXPORT_SYMBOL(__inet_hash);
void inet_hash(struct sock *sk) int inet_hash(struct sock *sk)
{ {
if (sk->sk_state != TCP_CLOSE) { if (sk->sk_state != TCP_CLOSE) {
local_bh_disable(); local_bh_disable();
__inet_hash(sk, NULL); __inet_hash(sk, NULL);
local_bh_enable(); local_bh_enable();
} }
return 0;
} }
EXPORT_SYMBOL_GPL(inet_hash); EXPORT_SYMBOL_GPL(inet_hash);

View file

@ -145,10 +145,12 @@ int ping_get_port(struct sock *sk, unsigned short ident)
} }
EXPORT_SYMBOL_GPL(ping_get_port); EXPORT_SYMBOL_GPL(ping_get_port);
void ping_hash(struct sock *sk) int ping_hash(struct sock *sk)
{ {
pr_debug("ping_hash(sk->port=%u)\n", inet_sk(sk)->inet_num); pr_debug("ping_hash(sk->port=%u)\n", inet_sk(sk)->inet_num);
BUG(); /* "Please do not press this button again." */ BUG(); /* "Please do not press this button again." */
return 0;
} }
void ping_unhash(struct sock *sk) void ping_unhash(struct sock *sk)

View file

@ -93,7 +93,7 @@ static struct raw_hashinfo raw_v4_hashinfo = {
.lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock), .lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock),
}; };
void raw_hash_sk(struct sock *sk) int raw_hash_sk(struct sock *sk)
{ {
struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; struct raw_hashinfo *h = sk->sk_prot->h.raw_hash;
struct hlist_head *head; struct hlist_head *head;
@ -104,6 +104,8 @@ void raw_hash_sk(struct sock *sk)
sk_add_node(sk, head); sk_add_node(sk, head);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
write_unlock_bh(&h->lock); write_unlock_bh(&h->lock);
return 0;
} }
EXPORT_SYMBOL_GPL(raw_hash_sk); EXPORT_SYMBOL_GPL(raw_hash_sk);

View file

@ -235,7 +235,11 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol,
* creation time automatically shares. * creation time automatically shares.
*/ */
inet->inet_sport = htons(inet->inet_num); inet->inet_sport = htons(inet->inet_num);
sk->sk_prot->hash(sk); err = sk->sk_prot->hash(sk);
if (err) {
sk_common_release(sk);
goto out;
}
} }
if (sk->sk_prot->init) { if (sk->sk_prot->init) {
err = sk->sk_prot->init(sk); err = sk->sk_prot->init(sk);

View file

@ -140,13 +140,15 @@ void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb)
rcu_read_unlock(); rcu_read_unlock();
} }
void pn_sock_hash(struct sock *sk) int pn_sock_hash(struct sock *sk)
{ {
struct hlist_head *hlist = pn_hash_list(pn_sk(sk)->sobject); struct hlist_head *hlist = pn_hash_list(pn_sk(sk)->sobject);
mutex_lock(&pnsocks.lock); mutex_lock(&pnsocks.lock);
sk_add_node_rcu(sk, hlist); sk_add_node_rcu(sk, hlist);
mutex_unlock(&pnsocks.lock); mutex_unlock(&pnsocks.lock);
return 0;
} }
EXPORT_SYMBOL(pn_sock_hash); EXPORT_SYMBOL(pn_sock_hash);
@ -200,7 +202,7 @@ static int pn_socket_bind(struct socket *sock, struct sockaddr *addr, int len)
pn->resource = spn->spn_resource; pn->resource = spn->spn_resource;
/* Enable RX on the socket */ /* Enable RX on the socket */
sk->sk_prot->hash(sk); err = sk->sk_prot->hash(sk);
out_port: out_port:
mutex_unlock(&port_mutex); mutex_unlock(&port_mutex);
out: out:

View file

@ -6101,9 +6101,10 @@ static int sctp_getsockopt(struct sock *sk, int level, int optname,
return retval; return retval;
} }
static void sctp_hash(struct sock *sk) static int sctp_hash(struct sock *sk)
{ {
/* STUB */ /* STUB */
return 0;
} }
static void sctp_unhash(struct sock *sk) static void sctp_unhash(struct sock *sk)