PCI: hv: Add validation for untrusted Hyper-V values
authorAndrea Parri (Microsoft) <parri.andrea@gmail.com>
Wed, 11 May 2022 22:32:06 +0000 (00:32 +0200)
committerWei Liu <wei.liu@kernel.org>
Fri, 13 May 2022 16:57:32 +0000 (16:57 +0000)
For additional robustness in the face of Hyper-V errors or malicious
behavior, validate all values that originate from packets that Hyper-V
has sent to the guest in the host-to-guest ring buffer.  Ensure that
invalid values cannot cause data being copied out of the bounds of the
source buffer in hv_pci_onchannelcallback().

While at it, remove a redundant validation in hv_pci_generic_compl():
hv_pci_onchannelcallback() already ensures that all processed incoming
packets are "at least as large as [in fact larger than] a response".

Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
Reviewed-by: Michael Kelley <mikelley@microsoft.com>
Acked-by: Lorenzo Pieralisi <lorenzo.pieralisi@arm.com>
Link: https://lore.kernel.org/r/20220511223207.3386-2-parri.andrea@gmail.com
Signed-off-by: Wei Liu <wei.liu@kernel.org>
drivers/pci/controller/pci-hyperv.c

index e439b81..a06e2cf 100644 (file)
@@ -981,11 +981,7 @@ static void hv_pci_generic_compl(void *context, struct pci_response *resp,
 {
        struct hv_pci_compl *comp_pkt = context;
 
-       if (resp_packet_size >= offsetofend(struct pci_response, status))
-               comp_pkt->completion_status = resp->status;
-       else
-               comp_pkt->completion_status = -1;
-
+       comp_pkt->completion_status = resp->status;
        complete(&comp_pkt->host_event);
 }
 
@@ -1606,8 +1602,13 @@ static void hv_pci_compose_compl(void *context, struct pci_response *resp,
        struct pci_create_int_response *int_resp =
                (struct pci_create_int_response *)resp;
 
+       if (resp_packet_size < sizeof(*int_resp)) {
+               comp_pkt->comp_pkt.completion_status = -1;
+               goto out;
+       }
        comp_pkt->comp_pkt.completion_status = resp->status;
        comp_pkt->int_desc = int_resp->int_desc;
+out:
        complete(&comp_pkt->comp_pkt.host_event);
 }
 
@@ -2291,12 +2292,14 @@ static void q_resource_requirements(void *context, struct pci_response *resp,
        struct q_res_req_compl *completion = context;
        struct pci_q_res_req_response *q_res_req =
                (struct pci_q_res_req_response *)resp;
+       s32 status;
        int i;
 
-       if (resp->status < 0) {
+       status = (resp_packet_size < sizeof(*q_res_req)) ? -1 : resp->status;
+       if (status < 0) {
                dev_err(&completion->hpdev->hbus->hdev->device,
                        "query resource requirements failed: %x\n",
-                       resp->status);
+                       status);
        } else {
                for (i = 0; i < PCI_STD_NUM_BARS; i++) {
                        completion->hpdev->probed_bar[i] =
@@ -2848,7 +2851,8 @@ static void hv_pci_onchannelcallback(void *context)
                        case PCI_BUS_RELATIONS:
 
                                bus_rel = (struct pci_bus_relations *)buffer;
-                               if (bytes_recvd <
+                               if (bytes_recvd < sizeof(*bus_rel) ||
+                                   bytes_recvd <
                                        struct_size(bus_rel, func,
                                                    bus_rel->device_count)) {
                                        dev_err(&hbus->hdev->device,
@@ -2862,7 +2866,8 @@ static void hv_pci_onchannelcallback(void *context)
                        case PCI_BUS_RELATIONS2:
 
                                bus_rel2 = (struct pci_bus_relations2 *)buffer;
-                               if (bytes_recvd <
+                               if (bytes_recvd < sizeof(*bus_rel2) ||
+                                   bytes_recvd <
                                        struct_size(bus_rel2, func,
                                                    bus_rel2->device_count)) {
                                        dev_err(&hbus->hdev->device,
@@ -2876,6 +2881,11 @@ static void hv_pci_onchannelcallback(void *context)
                        case PCI_EJECT:
 
                                dev_message = (struct pci_dev_incoming *)buffer;
+                               if (bytes_recvd < sizeof(*dev_message)) {
+                                       dev_err(&hbus->hdev->device,
+                                               "eject message too small\n");
+                                       break;
+                               }
                                hpdev = get_pcichild_wslot(hbus,
                                                      dev_message->wslot.slot);
                                if (hpdev) {
@@ -2887,6 +2897,11 @@ static void hv_pci_onchannelcallback(void *context)
                        case PCI_INVALIDATE_BLOCK:
 
                                inval = (struct pci_dev_inval_block *)buffer;
+                               if (bytes_recvd < sizeof(*inval)) {
+                                       dev_err(&hbus->hdev->device,
+                                               "invalidate message too small\n");
+                                       break;
+                               }
                                hpdev = get_pcichild_wslot(hbus,
                                                           inval->wslot.slot);
                                if (hpdev) {