diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c index 1f97b559892a..0a0f2ec75370 100644 --- a/kernel/bpf/sockmap.c +++ b/kernel/bpf/sockmap.c @@ -132,6 +132,7 @@ struct smap_psock { struct work_struct gc_work; struct proto *sk_proto; + void (*save_unhash)(struct sock *sk); void (*save_close)(struct sock *sk, long timeout); void (*save_data_ready)(struct sock *sk); void (*save_write_space)(struct sock *sk); @@ -143,6 +144,7 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); static int bpf_tcp_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags); +static void bpf_tcp_unhash(struct sock *sk); static void bpf_tcp_close(struct sock *sk, long timeout); static inline struct smap_psock *smap_psock_sk(const struct sock *sk) @@ -184,6 +186,7 @@ static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS], struct proto *base) { prot[SOCKMAP_BASE] = *base; + prot[SOCKMAP_BASE].unhash = bpf_tcp_unhash; prot[SOCKMAP_BASE].close = bpf_tcp_close; prot[SOCKMAP_BASE].recvmsg = bpf_tcp_recvmsg; prot[SOCKMAP_BASE].stream_memory_read = bpf_tcp_stream_read; @@ -217,6 +220,7 @@ static int bpf_tcp_init(struct sock *sk) return -EBUSY; } + psock->save_unhash = sk->sk_prot->unhash; psock->save_close = sk->sk_prot->close; psock->sk_proto = sk->sk_prot; @@ -305,30 +309,12 @@ static struct smap_psock_map_entry *psock_map_pop(struct sock *sk, return e; } -static void bpf_tcp_close(struct sock *sk, long timeout) +static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock) { - void (*close_fun)(struct sock *sk, long timeout); struct smap_psock_map_entry *e; struct sk_msg_buff *md, *mtmp; - struct smap_psock *psock; struct sock *osk; - lock_sock(sk); - rcu_read_lock(); - psock = smap_psock_sk(sk); - if (unlikely(!psock)) { - rcu_read_unlock(); - release_sock(sk); - return sk->sk_prot->close(sk, timeout); - } - - /* The psock may be destroyed anytime after exiting the RCU critial - * section so by the time we use close_fun the psock may no longer - * be valid. However, bpf_tcp_close is called with the sock lock - * held so the close hook and sk are still valid. - */ - close_fun = psock->save_close; - if (psock->cork) { free_start_sg(psock->sock, psock->cork, true); kfree(psock->cork); @@ -379,6 +365,42 @@ static void bpf_tcp_close(struct sock *sk, long timeout) kfree(e); e = psock_map_pop(sk, psock); } +} + +static void bpf_tcp_unhash(struct sock *sk) +{ + void (*unhash_fun)(struct sock *sk); + struct smap_psock *psock; + + rcu_read_lock(); + psock = smap_psock_sk(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + if (sk->sk_prot->unhash) + sk->sk_prot->unhash(sk); + return; + } + unhash_fun = psock->save_unhash; + bpf_tcp_remove(sk, psock); + rcu_read_unlock(); + unhash_fun(sk); +} + +static void bpf_tcp_close(struct sock *sk, long timeout) +{ + void (*close_fun)(struct sock *sk, long timeout); + struct smap_psock *psock; + + lock_sock(sk); + rcu_read_lock(); + psock = smap_psock_sk(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + release_sock(sk); + return sk->sk_prot->close(sk, timeout); + } + close_fun = psock->save_close; + bpf_tcp_remove(sk, psock); rcu_read_unlock(); release_sock(sk); close_fun(sk, timeout);