diff --git a/drivers/hv/channel_mgmt.c b/drivers/hv/channel_mgmt.c index 38b682bab85a..b6c1211b4df7 100644 --- a/drivers/hv/channel_mgmt.c +++ b/drivers/hv/channel_mgmt.c @@ -597,27 +597,55 @@ static void init_vp_index(struct vmbus_channel *channel, u16 dev_type) static void vmbus_wait_for_unload(void) { - int cpu = smp_processor_id(); - void *page_addr = hv_context.synic_message_page[cpu]; - struct hv_message *msg = (struct hv_message *)page_addr + - VMBUS_MESSAGE_SINT; + int cpu; + void *page_addr; + struct hv_message *msg; struct vmbus_channel_message_header *hdr; - bool unloaded = false; + u32 message_type; + /* + * CHANNELMSG_UNLOAD_RESPONSE is always delivered to the CPU which was + * used for initial contact or to CPU0 depending on host version. When + * we're crashing on a different CPU let's hope that IRQ handler on + * the cpu which receives CHANNELMSG_UNLOAD_RESPONSE is still + * functional and vmbus_unload_response() will complete + * vmbus_connection.unload_event. If not, the last thing we can do is + * read message pages for all CPUs directly. + */ while (1) { - if (READ_ONCE(msg->header.message_type) == HVMSG_NONE) { - mdelay(10); - continue; + if (completion_done(&vmbus_connection.unload_event)) + break; + + for_each_online_cpu(cpu) { + page_addr = hv_context.synic_message_page[cpu]; + msg = (struct hv_message *)page_addr + + VMBUS_MESSAGE_SINT; + + message_type = READ_ONCE(msg->header.message_type); + if (message_type == HVMSG_NONE) + continue; + + hdr = (struct vmbus_channel_message_header *) + msg->u.payload; + + if (hdr->msgtype == CHANNELMSG_UNLOAD_RESPONSE) + complete(&vmbus_connection.unload_event); + + vmbus_signal_eom(msg, message_type); } - hdr = (struct vmbus_channel_message_header *)msg->u.payload; - if (hdr->msgtype == CHANNELMSG_UNLOAD_RESPONSE) - unloaded = true; + mdelay(10); + } - vmbus_signal_eom(msg); - - if (unloaded) - break; + /* + * We're crashing and already got the UNLOAD_RESPONSE, cleanup all + * maybe-pending messages on all CPUs to be able to receive new + * messages after we reconnect. + */ + for_each_online_cpu(cpu) { + page_addr = hv_context.synic_message_page[cpu]; + msg = (struct hv_message *)page_addr + VMBUS_MESSAGE_SINT; + msg->header.message_type = HVMSG_NONE; } } diff --git a/drivers/hv/hyperv_vmbus.h b/drivers/hv/hyperv_vmbus.h index e5203e4d6782..718b5c72f0c8 100644 --- a/drivers/hv/hyperv_vmbus.h +++ b/drivers/hv/hyperv_vmbus.h @@ -625,9 +625,21 @@ extern struct vmbus_channel_message_table_entry channel_message_table[CHANNELMSG_COUNT]; /* Free the message slot and signal end-of-message if required */ -static inline void vmbus_signal_eom(struct hv_message *msg) +static inline void vmbus_signal_eom(struct hv_message *msg, u32 old_msg_type) { - msg->header.message_type = HVMSG_NONE; + /* + * On crash we're reading some other CPU's message page and we need + * to be careful: this other CPU may already had cleared the header + * and the host may already had delivered some other message there. + * In case we blindly write msg->header.message_type we're going + * to lose it. We can still lose a message of the same type but + * we count on the fact that there can only be one + * CHANNELMSG_UNLOAD_RESPONSE and we don't care about other messages + * on crash. + */ + if (cmpxchg(&msg->header.message_type, old_msg_type, + HVMSG_NONE) != old_msg_type) + return; /* * Make sure the write to MessageType (ie set to diff --git a/drivers/hv/vmbus_drv.c b/drivers/hv/vmbus_drv.c index a29a6c0abc01..952f20fdc7e3 100644 --- a/drivers/hv/vmbus_drv.c +++ b/drivers/hv/vmbus_drv.c @@ -712,7 +712,7 @@ static void hv_process_timer_expiration(struct hv_message *msg, int cpu) if (dev->event_handler) dev->event_handler(dev); - vmbus_signal_eom(msg); + vmbus_signal_eom(msg, HVMSG_TIMER_EXPIRED); } void vmbus_on_msg_dpc(unsigned long data) @@ -724,8 +724,9 @@ void vmbus_on_msg_dpc(unsigned long data) struct vmbus_channel_message_header *hdr; struct vmbus_channel_message_table_entry *entry; struct onmessage_work_context *ctx; + u32 message_type = msg->header.message_type; - if (msg->header.message_type == HVMSG_NONE) + if (message_type == HVMSG_NONE) /* no msg */ return; @@ -750,7 +751,7 @@ void vmbus_on_msg_dpc(unsigned long data) entry->message_handler(hdr); msg_handled: - vmbus_signal_eom(msg); + vmbus_signal_eom(msg, message_type); } static void vmbus_isr(void)