diff --git a/drivers/pci/pci.c b/drivers/pci/pci.c
index 5cb5820fae40..8b47f70b7d8f 100644
--- a/drivers/pci/pci.c
+++ b/drivers/pci/pci.c
@@ -2069,6 +2069,9 @@ void pci_free_cap_save_buffers(struct pci_dev *dev)
 /**
  * pci_enable_ari - enable ARI forwarding if hardware support it
  * @dev: the PCI device
+ *
+ * If @dev and its upstream bridge both support ARI, enable ARI in the
+ * bridge.  Otherwise, disable ARI in the bridge.
  */
 void pci_enable_ari(struct pci_dev *dev)
 {
@@ -2078,9 +2081,6 @@ void pci_enable_ari(struct pci_dev *dev)
 	if (pcie_ari_disabled || !pci_is_pcie(dev) || dev->devfn)
 		return;
 
-	if (!pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI))
-		return;
-
 	bridge = dev->bus->self;
 	if (!bridge)
 		return;
@@ -2089,8 +2089,15 @@ void pci_enable_ari(struct pci_dev *dev)
 	if (!(cap & PCI_EXP_DEVCAP2_ARI))
 		return;
 
-	pcie_capability_set_word(bridge, PCI_EXP_DEVCTL2, PCI_EXP_DEVCTL2_ARI);
-	bridge->ari_enabled = 1;
+	if (pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI)) {
+		pcie_capability_set_word(bridge, PCI_EXP_DEVCTL2,
+					 PCI_EXP_DEVCTL2_ARI);
+		bridge->ari_enabled = 1;
+	} else {
+		pcie_capability_clear_word(bridge, PCI_EXP_DEVCTL2,
+					   PCI_EXP_DEVCTL2_ARI);
+		bridge->ari_enabled = 0;
+	}
 }
 
 /**