ptp: ocp: Add firmware header checks
authorVadim Fedorenko <vadfed@fb.com>
Thu, 19 May 2022 21:21:53 +0000 (14:21 -0700)
committerJakub Kicinski <kuba@kernel.org>
Tue, 24 May 2022 18:39:54 +0000 (11:39 -0700)
Right now it's possible to flash any kind of binary via devlink and
break the card easily. This diff adds an optional header check when
installing the firmware.

Signed-off-by: Vadim Fedorenko <vadfed@fb.com>
Signed-off-by: Jonathan Lemon <jonathan.lemon@gmail.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/ptp/ptp_ocp.c

index 9fac96f..4519ef4 100644 (file)
@@ -19,6 +19,7 @@
 #include <linux/i2c.h>
 #include <linux/mtd/mtd.h>
 #include <linux/nvmem-consumer.h>
+#include <linux/crc16.h>
 
 #define PCI_VENDOR_ID_FACEBOOK                 0x1d9b
 #define PCI_DEVICE_ID_FACEBOOK_TIMECARD                0x0400
@@ -213,6 +214,17 @@ struct ptp_ocp_flash_info {
        void *data;
 };
 
+struct ptp_ocp_firmware_header {
+       char magic[4];
+       __be16 pci_vendor_id;
+       __be16 pci_device_id;
+       __be32 image_size;
+       __be16 hw_revision;
+       __be16 crc;
+};
+
+#define OCP_FIRMWARE_MAGIC_HEADER "OCPC"
+
 struct ptp_ocp_i2c_info {
        const char *name;
        unsigned long fixed_rate;
@@ -1324,24 +1336,80 @@ ptp_ocp_find_flash(struct ptp_ocp *bp)
 }
 
 static int
+ptp_ocp_devlink_fw_image(struct devlink *devlink, const struct firmware *fw,
+                        const u8 **data, size_t *size)
+{
+       struct ptp_ocp *bp = devlink_priv(devlink);
+       const struct ptp_ocp_firmware_header *hdr;
+       size_t offset, length;
+       u16 crc;
+
+       hdr = (const struct ptp_ocp_firmware_header *)fw->data;
+       if (memcmp(hdr->magic, OCP_FIRMWARE_MAGIC_HEADER, 4)) {
+               devlink_flash_update_status_notify(devlink,
+                       "No firmware header found, flashing raw image",
+                       NULL, 0, 0);
+               offset = 0;
+               length = fw->size;
+               goto out;
+       }
+
+       if (be16_to_cpu(hdr->pci_vendor_id) != bp->pdev->vendor ||
+           be16_to_cpu(hdr->pci_device_id) != bp->pdev->device) {
+               devlink_flash_update_status_notify(devlink,
+                       "Firmware image compatibility check failed",
+                       NULL, 0, 0);
+               return -EINVAL;
+       }
+
+       offset = sizeof(*hdr);
+       length = be32_to_cpu(hdr->image_size);
+       if (length != (fw->size - offset)) {
+               devlink_flash_update_status_notify(devlink,
+                       "Firmware image size check failed",
+                       NULL, 0, 0);
+               return -EINVAL;
+       }
+
+       crc = crc16(0xffff, &fw->data[offset], length);
+       if (be16_to_cpu(hdr->crc) != crc) {
+               devlink_flash_update_status_notify(devlink,
+                       "Firmware image CRC check failed",
+                       NULL, 0, 0);
+               return -EINVAL;
+       }
+
+out:
+       *data = &fw->data[offset];
+       *size = length;
+
+       return 0;
+}
+
+static int
 ptp_ocp_devlink_flash(struct devlink *devlink, struct device *dev,
                      const struct firmware *fw)
 {
        struct mtd_info *mtd = dev_get_drvdata(dev);
        struct ptp_ocp *bp = devlink_priv(devlink);
-       size_t off, len, resid, wrote;
+       size_t off, len, size, resid, wrote;
        struct erase_info erase;
        size_t base, blksz;
-       int err = 0;
+       const u8 *data;
+       int err;
+
+       err = ptp_ocp_devlink_fw_image(devlink, fw, &data, &size);
+       if (err)
+               goto out;
 
        off = 0;
        base = bp->flash_start;
        blksz = 4096;
-       resid = fw->size;
+       resid = size;
 
        while (resid) {
                devlink_flash_update_status_notify(devlink, "Flashing",
-                                                  NULL, off, fw->size);
+                                                  NULL, off, size);
 
                len = min_t(size_t, resid, blksz);
                erase.addr = base + off;
@@ -1351,7 +1419,7 @@ ptp_ocp_devlink_flash(struct devlink *devlink, struct device *dev,
                if (err)
                        goto out;
 
-               err = mtd_write(mtd, base + off, len, &wrote, &fw->data[off]);
+               err = mtd_write(mtd, base + off, len, &wrote, data + off);
                if (err)
                        goto out;