diff --git a/include/linux/mm.h b/include/linux/mm.h
index 78adec4ba9f4..2ec5138badab 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1490,6 +1490,7 @@ struct page *follow_page(struct vm_area_struct *, unsigned long address,
 #define FOLL_DUMP	0x08	/* give error on hole if it would be zero */
 #define FOLL_FORCE	0x10	/* get_user_pages read/write w/o permission */
 #define FOLL_MLOCK	0x40	/* mark page as mlocked */
+#define FOLL_SPLIT	0x80	/* don't return transhuge pages, split them */
 
 typedef int (*pte_fn_t)(pte_t *pte, pgtable_t token, unsigned long addr,
 			void *data);
diff --git a/mm/memory.c b/mm/memory.c
index c1a80e00458d..12ee1ea237f5 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -1305,6 +1305,10 @@ struct page *follow_page(struct vm_area_struct *vma, unsigned long address,
 		goto out;
 	}
 	if (pmd_trans_huge(*pmd)) {
+		if (flags & FOLL_SPLIT) {
+			split_huge_page_pmd(mm, pmd);
+			goto split_fallthrough;
+		}
 		spin_lock(&mm->page_table_lock);
 		if (likely(pmd_trans_huge(*pmd))) {
 			if (unlikely(pmd_trans_splitting(*pmd))) {
@@ -1320,6 +1324,7 @@ struct page *follow_page(struct vm_area_struct *vma, unsigned long address,
 			spin_unlock(&mm->page_table_lock);
 		/* fall through */
 	}
+split_fallthrough:
 	if (unlikely(pmd_bad(*pmd)))
 		goto no_page_table;
 
diff --git a/mm/migrate.c b/mm/migrate.c
index 690d0de993af..1a531b760b3b 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -113,6 +113,8 @@ static int remove_migration_pte(struct page *new, struct vm_area_struct *vma,
 			goto out;
 
 		pmd = pmd_offset(pud, addr);
+		if (pmd_trans_huge(*pmd))
+			goto out;
 		if (!pmd_present(*pmd))
 			goto out;
 
@@ -632,6 +634,9 @@ static int unmap_and_move(new_page_t get_new_page, unsigned long private,
 		/* page was freed from under us. So we are done. */
 		goto move_newpage;
 	}
+	if (unlikely(PageTransHuge(page)))
+		if (unlikely(split_huge_page(page)))
+			goto move_newpage;
 
 	/* prepare cgroup just returns 0 or -ENOMEM */
 	rc = -EAGAIN;
@@ -1063,7 +1068,7 @@ static int do_move_page_to_node_array(struct mm_struct *mm,
 		if (!vma || pp->addr < vma->vm_start || !vma_migratable(vma))
 			goto set_status;
 
-		page = follow_page(vma, pp->addr, FOLL_GET);
+		page = follow_page(vma, pp->addr, FOLL_GET|FOLL_SPLIT);
 
 		err = PTR_ERR(page);
 		if (IS_ERR(page))