tftp: prevent overwriting reserved memory
[platform/kernel/u-boot.git] / net / tftp.c
index 18ce84c..a9335b1 100644 (file)
@@ -8,21 +8,24 @@
 
 #include <common.h>
 #include <command.h>
+#include <efi_loader.h>
 #include <mapmem.h>
 #include <net.h>
-#include "tftp.h"
+#include <net/tftp.h>
 #include "bootp.h"
 #ifdef CONFIG_SYS_DIRECT_FLASH_TFTP
 #include <flash.h>
 #endif
 
+DECLARE_GLOBAL_DATA_PTR;
+
 /* Well known TFTP port # */
 #define WELL_KNOWN_PORT        69
 /* Millisecs to timeout for lost pkt */
-#define TIMEOUT                100UL
+#define TIMEOUT                5000UL
 #ifndef        CONFIG_NET_RETRY_COUNT
 /* # of timeouts before giving up */
-# define TIMEOUT_COUNT 1000
+# define TIMEOUT_COUNT 10
 #else
 # define TIMEOUT_COUNT  (CONFIG_NET_RETRY_COUNT * 2)
 #endif
@@ -80,6 +83,10 @@ static ulong tftp_block_wrap;
 /* memory offset due to wrapping */
 static ulong   tftp_block_wrap_offset;
 static int     tftp_state;
+static ulong   tftp_load_addr;
+#ifdef CONFIG_LMB
+static ulong   tftp_load_size;
+#endif
 #ifdef CONFIG_TFTP_TSIZE
 /* The file size reported by the server */
 static int     tftp_tsize;
@@ -163,10 +170,11 @@ static void mcast_cleanup(void)
 
 #endif /* CONFIG_MCAST_TFTP */
 
-static inline void store_block(int block, uchar *src, unsigned len)
+static inline int store_block(int block, uchar *src, unsigned int len)
 {
        ulong offset = block * tftp_block_size + tftp_block_wrap_offset;
        ulong newsize = offset + len;
+       ulong store_addr = tftp_load_addr + offset;
 #ifdef CONFIG_SYS_DIRECT_FLASH_TFTP
        int i, rc = 0;
 
@@ -174,24 +182,32 @@ static inline void store_block(int block, uchar *src, unsigned len)
                /* start address in flash? */
                if (flash_info[i].flash_id == FLASH_UNKNOWN)
                        continue;
-               if (load_addr + offset >= flash_info[i].start[0]) {
+               if (store_addr >= flash_info[i].start[0]) {
                        rc = 1;
                        break;
                }
        }
 
        if (rc) { /* Flash is destination for this packet */
-               rc = flash_write((char *)src, (ulong)(load_addr+offset), len);
+               rc = flash_write((char *)src, store_addr, len);
                if (rc) {
                        flash_perror(rc);
-                       net_set_state(NETLOOP_FAIL);
-                       return;
+                       return rc;
                }
        } else
 #endif /* CONFIG_SYS_DIRECT_FLASH_TFTP */
        {
-               void *ptr = map_sysmem(load_addr + offset, len);
-
+               void *ptr;
+
+#ifdef CONFIG_LMB
+               if (store_addr < tftp_load_addr ||
+                   store_addr + len > tftp_load_addr + tftp_load_size) {
+                       puts("\nTFTP error: ");
+                       puts("trying to overwrite reserved memory...\n");
+                       return -1;
+               }
+#endif
+               ptr = map_sysmem(store_addr, len);
                memcpy(ptr, src, len);
                unmap_sysmem(ptr);
        }
@@ -202,6 +218,8 @@ static inline void store_block(int block, uchar *src, unsigned len)
 
        if (net_boot_file_size < newsize)
                net_boot_file_size = newsize;
+
+       return 0;
 }
 
 /* Clear our state ready for a new transfer */
@@ -602,10 +620,14 @@ static void tftp_handler(uchar *pkt, unsigned dest, struct in_addr sip,
                }
 
                tftp_prev_block = tftp_cur_block;
-               timeout_count_max = TIMEOUT_COUNT;
+               timeout_count_max = tftp_timeout_count_max;
                net_set_timeout_handler(timeout_ms, tftp_timeout_handler);
 
-               store_block(tftp_cur_block - 1, pkt + 2, len);
+               if (store_block(tftp_cur_block - 1, pkt + 2, len)) {
+                       eth_halt();
+                       net_set_state(NETLOOP_FAIL);
+                       break;
+               }
 
                /*
                 *      Acknowledge the block just received, which will prompt
@@ -694,56 +716,77 @@ static void tftp_timeout_handler(void)
        }
 }
 
+/* Initialize tftp_load_addr and tftp_load_size from load_addr and lmb */
+static int tftp_init_load_addr(void)
+{
+#ifdef CONFIG_LMB
+       struct lmb lmb;
+       phys_size_t max_size;
+
+       lmb_init_and_reserve(&lmb, gd->bd->bi_dram[0].start,
+                            gd->bd->bi_dram[0].size, (void *)gd->fdt_blob);
+
+       max_size = lmb_get_unreserved_size(&lmb, load_addr);
+       if (!max_size)
+               return -1;
+
+       tftp_load_size = max_size;
+#endif
+       tftp_load_addr = load_addr;
+       return 0;
+}
 
 void tftp_start(enum proto_t protocol)
 {
+#if CONFIG_NET_TFTP_VARS
        char *ep;             /* Environment pointer */
 
        /*
         * Allow the user to choose TFTP blocksize and timeout.
         * TFTP protocol has a minimal timeout of 1 second.
         */
-       ep = getenv("tftpblocksize");
+
+       ep = env_get("tftpblocksize");
        if (ep != NULL)
                tftp_block_size_option = simple_strtol(ep, NULL, 10);
 
-       ep = getenv("tftptimeout");
+       ep = env_get("tftptimeout");
        if (ep != NULL)
                timeout_ms = simple_strtol(ep, NULL, 10);
 
-       if (timeout_ms < 10) {
-               printf("TFTP timeout (%ld ms) too low, set min = 10 ms\n",
+       if (timeout_ms < 1000) {
+               printf("TFTP timeout (%ld ms) too low, set min = 1000 ms\n",
                       timeout_ms);
-               timeout_ms = 10;
+               timeout_ms = 1000;
        }
 
+       ep = env_get("tftptimeoutcountmax");
+       if (ep != NULL)
+               tftp_timeout_count_max = simple_strtol(ep, NULL, 10);
+
+       if (tftp_timeout_count_max < 0) {
+               printf("TFTP timeout count max (%d ms) negative, set to 0\n",
+                      tftp_timeout_count_max);
+               tftp_timeout_count_max = 0;
+       }
+#endif
+
        debug("TFTP blocksize = %i, timeout = %ld ms\n",
              tftp_block_size_option, timeout_ms);
 
        tftp_remote_ip = net_server_ip;
-       if (net_boot_file_name[0] == '\0') {
+       if (!net_parse_bootfile(&tftp_remote_ip, tftp_filename, MAX_LEN)) {
                sprintf(default_filename, "%02X%02X%02X%02X.img",
                        net_ip.s_addr & 0xFF,
                        (net_ip.s_addr >>  8) & 0xFF,
                        (net_ip.s_addr >> 16) & 0xFF,
                        (net_ip.s_addr >> 24) & 0xFF);
 
-               strncpy(tftp_filename, default_filename, MAX_LEN);
-               tftp_filename[MAX_LEN - 1] = 0;
+               strncpy(tftp_filename, default_filename, DEFAULT_NAME_LEN);
+               tftp_filename[DEFAULT_NAME_LEN - 1] = 0;
 
                printf("*** Warning: no boot file name; using '%s'\n",
                       tftp_filename);
-       } else {
-               char *p = strchr(net_boot_file_name, ':');
-
-               if (p == NULL) {
-                       strncpy(tftp_filename, net_boot_file_name, MAX_LEN);
-                       tftp_filename[MAX_LEN - 1] = 0;
-               } else {
-                       tftp_remote_ip = string_to_ip(net_boot_file_name);
-                       strncpy(tftp_filename, p + 1, MAX_LEN);
-                       tftp_filename[MAX_LEN - 1] = 0;
-               }
        }
 
        printf("Using %s device\n", eth_get_name());
@@ -788,9 +831,19 @@ void tftp_start(enum proto_t protocol)
        } else
 #endif
        {
-               printf("Load address: 0x%lx\n", load_addr);
+               if (tftp_init_load_addr()) {
+                       eth_halt();
+                       net_set_state(NETLOOP_FAIL);
+                       puts("\nTFTP error: ");
+                       puts("trying to overwrite reserved memory...\n");
+                       return;
+               }
+               printf("Load address: 0x%lx\n", tftp_load_addr);
                puts("Loading: *\b");
                tftp_state = STATE_SEND_RRQ;
+#ifdef CONFIG_CMD_BOOTEFI
+               efi_set_bootdev("Net", "", tftp_filename);
+#endif
        }
 
        time_start = get_timer(0);
@@ -807,10 +860,10 @@ void tftp_start(enum proto_t protocol)
        tftp_our_port = 1024 + (get_timer(0) % 3072);
 
 #ifdef CONFIG_TFTP_PORT
-       ep = getenv("tftpdstp");
+       ep = env_get("tftpdstp");
        if (ep != NULL)
                tftp_remote_port = simple_strtol(ep, NULL, 10);
-       ep = getenv("tftpsrcp");
+       ep = env_get("tftpsrcp");
        if (ep != NULL)
                tftp_our_port = simple_strtol(ep, NULL, 10);
 #endif
@@ -836,13 +889,19 @@ void tftp_start_server(void)
 {
        tftp_filename[0] = 0;
 
+       if (tftp_init_load_addr()) {
+               eth_halt();
+               net_set_state(NETLOOP_FAIL);
+               puts("\nTFTP error: trying to overwrite reserved memory...\n");
+               return;
+       }
        printf("Using %s device\n", eth_get_name());
        printf("Listening for TFTP transfer on %pI4\n", &net_ip);
-       printf("Load address: 0x%lx\n", load_addr);
+       printf("Load address: 0x%lx\n", tftp_load_addr);
 
        puts("Loading: *\b");
 
-       timeout_count_max = TIMEOUT_COUNT;
+       timeout_count_max = tftp_timeout_count_max;
        timeout_count = 0;
        timeout_ms = TIMEOUT;
        net_set_timeout_handler(timeout_ms, tftp_timeout_handler);