diff --git a/fs/ecryptfs/miscdev.c b/fs/ecryptfs/miscdev.c index 940a82e63dc3..0dc5a3d554a4 100644 --- a/fs/ecryptfs/miscdev.c +++ b/fs/ecryptfs/miscdev.c @@ -409,11 +409,47 @@ ecryptfs_miscdev_write(struct file *file, const char __user *buf, ssize_t sz = 0; char *data; uid_t euid = current_euid(); + unsigned char packet_size_peek[3]; int rc; - if (count == 0) + if (count == 0) { goto out; + } else if (count == (1 + 4)) { + /* Likely a harmless MSG_HELO or MSG_QUIT - no packet length */ + goto memdup; + } else if (count < (1 + 4 + 1) + || count > (1 + 4 + 2 + sizeof(struct ecryptfs_message) + 4 + + ECRYPTFS_MAX_ENCRYPTED_KEY_BYTES)) { + printk(KERN_WARNING "%s: Acceptable packet size range is " + "[%d-%lu], but amount of data written is [%zu].", + __func__, (1 + 4 + 1), + (1 + 4 + 2 + sizeof(struct ecryptfs_message) + 4 + + ECRYPTFS_MAX_ENCRYPTED_KEY_BYTES), count); + return -EINVAL; + } + if (copy_from_user(packet_size_peek, (buf + 1 + 4), + sizeof(packet_size_peek))) { + printk(KERN_WARNING "%s: Error while inspecting packet size\n", + __func__); + return -EFAULT; + } + + rc = ecryptfs_parse_packet_length(packet_size_peek, &packet_size, + &packet_size_length); + if (rc) { + printk(KERN_WARNING "%s: Error parsing packet length; " + "rc = [%d]\n", __func__, rc); + return rc; + } + + if ((1 + 4 + packet_size_length + packet_size) != count) { + printk(KERN_WARNING "%s: Invalid packet size [%zu]\n", __func__, + packet_size); + return -EINVAL; + } + +memdup: data = memdup_user(buf, count); if (IS_ERR(data)) { printk(KERN_ERR "%s: memdup_user returned error [%ld]\n", @@ -435,23 +471,7 @@ ecryptfs_miscdev_write(struct file *file, const char __user *buf, } memcpy(&counter_nbo, &data[i], 4); seq = be32_to_cpu(counter_nbo); - i += 4; - rc = ecryptfs_parse_packet_length(&data[i], &packet_size, - &packet_size_length); - if (rc) { - printk(KERN_WARNING "%s: Error parsing packet length; " - "rc = [%d]\n", __func__, rc); - goto out_free; - } - i += packet_size_length; - if ((1 + 4 + packet_size_length + packet_size) != count) { - printk(KERN_WARNING "%s: (1 + packet_size_length([%zd])" - " + packet_size([%zd]))([%zd]) != " - "count([%zd]). Invalid packet format.\n", - __func__, packet_size_length, packet_size, - (1 + packet_size_length + packet_size), count); - goto out_free; - } + i += 4 + packet_size_length; rc = ecryptfs_miscdev_response(&data[i], packet_size, euid, current_user_ns(), task_pid(current), seq);