diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index cd778b1752b5..3db3d242c3ee 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -51,6 +51,7 @@ struct virtio_balloon
 	u32 pfns[256];
 
 	/* Memory statistics */
+	int need_stats_update;
 	struct virtio_balloon_stat stats[VIRTIO_BALLOON_S_NR];
 };
 
@@ -193,20 +194,30 @@ static void update_balloon_stats(struct virtio_balloon *vb)
  * the stats queue operates in reverse.  The driver initializes the virtqueue
  * with a single buffer.  From that point forward, all conversations consist of
  * a hypervisor request (a call to this function) which directs us to refill
- * the virtqueue with a fresh stats buffer.
+ * the virtqueue with a fresh stats buffer.  Since stats collection can sleep,
+ * we notify our kthread which does the actual work via stats_handle_request().
  */
-static void stats_ack(struct virtqueue *vq)
+static void stats_request(struct virtqueue *vq)
 {
 	struct virtio_balloon *vb;
 	unsigned int len;
-	struct scatterlist sg;
 
 	vb = vq->vq_ops->get_buf(vq, &len);
 	if (!vb)
 		return;
+	vb->need_stats_update = 1;
+	wake_up(&vb->config_change);
+}
 
+static void stats_handle_request(struct virtio_balloon *vb)
+{
+	struct virtqueue *vq;
+	struct scatterlist sg;
+
+	vb->need_stats_update = 0;
 	update_balloon_stats(vb);
 
+	vq = vb->stats_vq;
 	sg_init_one(&sg, vb->stats, sizeof(vb->stats));
 	if (vq->vq_ops->add_buf(vq, &sg, 1, 0, vb) < 0)
 		BUG();
@@ -249,8 +260,11 @@ static int balloon(void *_vballoon)
 		try_to_freeze();
 		wait_event_interruptible(vb->config_change,
 					 (diff = towards_target(vb)) != 0
+					 || vb->need_stats_update
 					 || kthread_should_stop()
 					 || freezing(current));
+		if (vb->need_stats_update)
+			stats_handle_request(vb);
 		if (diff > 0)
 			fill_balloon(vb, diff);
 		else if (diff < 0)
@@ -264,7 +278,7 @@ static int virtballoon_probe(struct virtio_device *vdev)
 {
 	struct virtio_balloon *vb;
 	struct virtqueue *vqs[3];
-	vq_callback_t *callbacks[] = { balloon_ack, balloon_ack, stats_ack };
+	vq_callback_t *callbacks[] = { balloon_ack, balloon_ack, stats_request };
 	const char *names[] = { "inflate", "deflate", "stats" };
 	int err, nvqs;