diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index 6e9bebcdd5a6..10b22619e7ec 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -1179,6 +1179,15 @@ static void update_exception_bitmap(struct kvm_vcpu *vcpu)
 		eb &= ~(1u << PF_VECTOR); /* bypass_guest_pf = 0 */
 	if (vcpu->fpu_active)
 		eb &= ~(1u << NM_VECTOR);
+
+	/* When we are running a nested L2 guest and L1 specified for it a
+	 * certain exception bitmap, we must trap the same exceptions and pass
+	 * them to L1. When running L2, we will only handle the exceptions
+	 * specified above if L1 did not want them.
+	 */
+	if (is_guest_mode(vcpu))
+		eb |= get_vmcs12(vcpu)->exception_bitmap;
+
 	vmcs_write32(EXCEPTION_BITMAP, eb);
 }
 
@@ -1473,6 +1482,9 @@ static void vmx_fpu_activate(struct kvm_vcpu *vcpu)
 	vmcs_writel(GUEST_CR0, cr0);
 	update_exception_bitmap(vcpu);
 	vcpu->arch.cr0_guest_owned_bits = X86_CR0_TS;
+	if (is_guest_mode(vcpu))
+		vcpu->arch.cr0_guest_owned_bits &=
+			~get_vmcs12(vcpu)->cr0_guest_host_mask;
 	vmcs_writel(CR0_GUEST_HOST_MASK, ~vcpu->arch.cr0_guest_owned_bits);
 }
 
@@ -1496,12 +1508,29 @@ static inline unsigned long nested_read_cr4(struct vmcs12 *fields)
 
 static void vmx_fpu_deactivate(struct kvm_vcpu *vcpu)
 {
+	/* Note that there is no vcpu->fpu_active = 0 here. The caller must
+	 * set this *before* calling this function.
+	 */
 	vmx_decache_cr0_guest_bits(vcpu);
 	vmcs_set_bits(GUEST_CR0, X86_CR0_TS | X86_CR0_MP);
 	update_exception_bitmap(vcpu);
 	vcpu->arch.cr0_guest_owned_bits = 0;
 	vmcs_writel(CR0_GUEST_HOST_MASK, ~vcpu->arch.cr0_guest_owned_bits);
-	vmcs_writel(CR0_READ_SHADOW, vcpu->arch.cr0);
+	if (is_guest_mode(vcpu)) {
+		/*
+		 * L1's specified read shadow might not contain the TS bit,
+		 * so now that we turned on shadowing of this bit, we need to
+		 * set this bit of the shadow. Like in nested_vmx_run we need
+		 * nested_read_cr0(vmcs12), but vmcs12->guest_cr0 is not yet
+		 * up-to-date here because we just decached cr0.TS (and we'll
+		 * only update vmcs12->guest_cr0 on nested exit).
+		 */
+		struct vmcs12 *vmcs12 = get_vmcs12(vcpu);
+		vmcs12->guest_cr0 = (vmcs12->guest_cr0 & ~X86_CR0_TS) |
+			(vcpu->arch.cr0 & X86_CR0_TS);
+		vmcs_writel(CR0_READ_SHADOW, nested_read_cr0(vmcs12));
+	} else
+		vmcs_writel(CR0_READ_SHADOW, vcpu->arch.cr0);
 }
 
 static unsigned long vmx_get_rflags(struct kvm_vcpu *vcpu)