Check for very large chunk sizes in WAV decoding
author Pete Warden <petewarden@google.com>
Thu, 15 Mar 2018 21:45:34 +0000 (14:45 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 21:55:27 +0000 (14:55 -0700)
Change how chunk sizes larger than 2GB are handled, since they're stored as
unsigned int32s, so there are lots of ways for conversions to confuse the
decoding logic. The new behavior is to fail with an error, since such
large WAV files are not common, and are unsupported by many readers.

PiperOrigin-RevId: 189248857

tensorflow/core/lib/wav/wav_io.cc
tensorflow/core/lib/wav/wav_io.h
tensorflow/core/lib/wav/wav_io_test.cc

index 77d3c88..51b9c6c 100644 (file)
@@ -81,13 +81,38 @@ inline float Int16SampleToFloat(int16 data) {
   return data * kMultiplier;
 }
 
+}  // namespace
+
+// Handles moving the data index forward, validating the arguments, and avoiding
+// overflow or underflow. *new_offset is only written when OK is returned.
+Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
+                       int* new_offset) {
+  if (old_offset < 0) {
+    return errors::InvalidArgument("Negative offsets are not allowed: ",
+                                   old_offset);
+  }
+  const size_t start = static_cast<size_t>(old_offset);
+  if (start > max_size) {
+    return errors::InvalidArgument("Initial offset is outside data range: ",
+                                   old_offset);
+  }
+  // Compare with the remaining space; forming the sum first could wrap around.
+  if (increment > max_size - start) {
+    return errors::InvalidArgument("Data too short when trying to read string");
+  }
+  const size_t end = start + increment;  // Cannot wrap: end <= max_size.
+  if (end > static_cast<size_t>(std::numeric_limits<int>::max())) {
+    return errors::InvalidArgument("Offset too large, overflowed: ", end);
+  }
+  *new_offset = static_cast<int>(end);
+  return Status::OK();
+}
+
 Status ExpectText(const string& data, const string& expected_text,
                   int* offset) {
-  const int new_offset = *offset + expected_text.size();
-  if (new_offset > data.size()) {
-    return errors::InvalidArgument("Data too short when trying to read ",
-                                   expected_text);
-  }
+  int new_offset;
+  TF_RETURN_IF_ERROR(
+      IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
   const string found_text(data.begin() + *offset, data.begin() + new_offset);
   if (found_text != expected_text) {
     return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
@@ -97,40 +122,16 @@ Status ExpectText(const string& data, const string& expected_text,
   return Status::OK();
 }
 
-template <class T>
-Status ReadValue(const string& data, T* value, int* offset) {
-  const int new_offset = *offset + sizeof(T);
-  if (new_offset > data.size()) {
-    return errors::InvalidArgument("Data too short when trying to read value");
-  }
-  if (port::kLittleEndian) {
-    memcpy(value, data.data() + *offset, sizeof(T));
-  } else {
-    *value = 0;
-    const uint8* data_buf =
-        reinterpret_cast<const uint8*>(data.data() + *offset);
-    int shift = 0;
-    for (int i = 0; i < sizeof(T); ++i, shift += 8) {
-      *value = *value | (data_buf[i] << shift);
-    }
-  }
-  *offset = new_offset;
-  return Status::OK();
-}
-
 Status ReadString(const string& data, int expected_length, string* value,
                   int* offset) {
-  const int new_offset = *offset + expected_length;
-  if (new_offset > data.size()) {
-    return errors::InvalidArgument("Data too short when trying to read string");
-  }
+  int new_offset;
+  TF_RETURN_IF_ERROR(
+      IncrementOffset(*offset, expected_length, data.size(), &new_offset));
   *value = string(data.begin() + *offset, data.begin() + new_offset);
   *offset = new_offset;
   return Status::OK();
 }
 
-}  // namespace
-
 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
                              size_t num_channels, size_t num_frames,
                              string* wav_string) {
@@ -272,6 +273,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
     TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
     uint32 chunk_size;
     TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
+    if (chunk_size > std::numeric_limits<int32>::max()) {
+      return errors::InvalidArgument(
+          "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
+          " bytes, but the limit is ", std::numeric_limits<int32>::max());
+    }
     if (chunk_id == kDataChunkId) {
       if (was_data_found) {
         return errors::InvalidArgument("More than one data chunk found in WAV");
index adca0ee..f004524 100644 (file)
@@ -21,6 +21,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -55,6 +58,36 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
                                     uint32* sample_count, uint16* channel_count,
                                     uint32* sample_rate);
 
+// Everything below here is only exposed publicly for testing purposes.
+
+// Advances old_offset by increment into *new_offset. Returns InvalidArgument if
+// old_offset is negative or if either offset lies beyond max_size.
+Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
+                       int* new_offset);
+
+// This function is only exposed in the header for testing purposes, as a
+// template that needs to be instantiated. Reads a little-endian integer of type
+// T from data at *offset and, on success, moves *offset past the bytes read.
+template <class T>
+Status ReadValue(const string& data, T* value, int* offset) {
+  int new_offset;
+  TF_RETURN_IF_ERROR(
+      IncrementOffset(*offset, sizeof(T), data.size(), &new_offset));
+  if (port::kLittleEndian) {
+    memcpy(value, data.data() + *offset, sizeof(T));
+  } else {
+    // Accumulate in uint64: shifting a promoted int can overflow for wide T.
+    const char* src = data.data() + *offset;
+    uint64 bits = 0;
+    for (size_t i = 0; i < sizeof(T); ++i) {
+      bits |= static_cast<uint64>(static_cast<uint8>(src[i])) << (8 * i);
+    }
+    *value = static_cast<T>(bits);
+  }
+  *offset = new_offset;
+  return Status::OK();
+}
+
 }  // namespace wav
 }  // namespace tensorflow
 
index 40ddd94..d8a83fc 100644 (file)
@@ -25,6 +25,12 @@ limitations under the License.
 namespace tensorflow {
 namespace wav {
 
+// These are defined in wav_io.cc, and the signatures are here so we don't have
+// to expose them in the public header.
+Status ExpectText(const string& data, const string& expected_text, int* offset);
+Status ReadString(const string& data, int expected_length, string* value,
+                  int* offset);
+
 TEST(WavIO, BadArguments) {
   float audio[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
   string result;
@@ -155,5 +161,318 @@ TEST(WavIO, BasicStereo) {
   EXPECT_EQ(expected, result);
 }
 
+// Test how chunk sizes larger than 2GB are handled, since they're stored as
+// unsigned int32s, so there are lots of ways for conversions to confuse the
+// decoding logic. The expected behavior is to fail with an error, since such
+// large WAV files are not common, and are unsupported by many readers.
+// See b/72655902.
+TEST(WavIO, ChunkSizeOverflow) {
+  std::vector<uint8> wav_data = {
+      'R', 'I', 'F', 'F',      // ChunkID
+      60, 0, 0, 0,             // ChunkSize: 36 + SubChunk2Size + 16
+      'W', 'A', 'V', 'E',      // Format
+      'f', 'm', 't', ' ',      // Subchunk1ID
+      16, 0, 0, 0,             // Subchunk1Size
+      1, 0,                    // AudioFormat: 1=PCM
+      1, 0,                    // NumChannels
+      0x44, 0xac, 0, 0,        // SampleRate: 44100
+      0x88, 0x58, 0x1, 0,      // BytesPerSecond: SampleRate * NumChannels *
+                               //                 BitsPerSample/8
+      2, 0,                    // BytesPerSample: NumChannels * BitsPerSample/8
+      16, 0,                   // BitsPerSample
+      'd', 'a', 't', 'a',      // Subchunk2ID
+      8, 0, 0, 0,              // Subchunk2Size: NumSamples * NumChannels *
+                               //                BitsPerSample/8
+      0, 0,                    // Sample 1: 0
+      0xff, 0x7f,              // Sample 2: 32767 (saturated)
+      0, 0,                    // Sample 3: 0
+      0x00, 0x80,              // Sample 4: -32768 (saturated)
+      'f', 'o', 'o', 'o',      // Subchunk3ID: not a 'data' chunk
+      0xff, 0xff, 0xff, 0xf8,  // Chunk size that could cause an infinite loop.
+      0, 0,                    // Sample 1: 0
+      0xff, 0x7f,              // Sample 2: 32767 (saturated)
+      0, 0,                    // Sample 3: 0
+      0x00, 0x80,              // Sample 4: -32768 (saturated)
+  };
+  string wav_data_string(wav_data.begin(), wav_data.end());
+  std::vector<float> decoded_audio;
+  uint32 decoded_sample_count;
+  uint16 decoded_channel_count;
+  uint32 decoded_sample_rate;
+  Status decode_status = DecodeLin16WaveAsFloatVector(
+      wav_data_string, &decoded_audio, &decoded_sample_count,
+      &decoded_channel_count, &decoded_sample_rate);
+  EXPECT_FALSE(decode_status.ok());
+  EXPECT_TRUE(StringPiece(decode_status.error_message()).contains("too large"))
+      << decode_status.error_message();
+}
+
+TEST(WavIO, IncrementOffset) {
+  int new_offset = -1;
+  TF_EXPECT_OK(IncrementOffset(0, 10, 20, &new_offset));
+  EXPECT_EQ(10, new_offset);
+
+  new_offset = -1;
+  TF_EXPECT_OK(IncrementOffset(10, 4, 20, &new_offset));
+  EXPECT_EQ(14, new_offset);
+
+  new_offset = -1;
+  TF_EXPECT_OK(IncrementOffset(99, 1, 100, &new_offset));
+  EXPECT_EQ(100, new_offset);
+
+  new_offset = -1;
+  EXPECT_FALSE(IncrementOffset(-1, 1, 100, &new_offset).ok());
+
+  new_offset = -1;
+  EXPECT_FALSE(IncrementOffset(0, -1, 100, &new_offset).ok());
+
+  new_offset = -1;
+  EXPECT_FALSE(IncrementOffset(std::numeric_limits<int>::max(), 1,
+                               std::numeric_limits<int>::max(), &new_offset)
+                   .ok());
+
+  new_offset = -1;
+  EXPECT_FALSE(IncrementOffset(101, 1, 100, &new_offset).ok());
+}
+
+TEST(WavIO, ExpectText) {
+  std::vector<uint8> test_data = {
+      'E', 'x', 'p', 'e', 'c', 't', 'e', 'd',
+  };
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  TF_EXPECT_OK(ExpectText(test_string, "Expected", &offset));
+  EXPECT_EQ(8, offset);
+
+  offset = 0;
+  Status expect_status = ExpectText(test_string, "Unexpected", &offset);
+  EXPECT_FALSE(expect_status.ok());
+
+  offset = 0;
+  TF_EXPECT_OK(ExpectText(test_string, "Exp", &offset));
+  EXPECT_EQ(3, offset);
+  TF_EXPECT_OK(ExpectText(test_string, "ected", &offset));
+  EXPECT_EQ(8, offset);
+  expect_status = ExpectText(test_string, "foo", &offset);
+  EXPECT_FALSE(expect_status.ok());
+}
+
+TEST(WavIO, ReadString) {
+  std::vector<uint8> test_data = {
+      'E', 'x', 'p', 'e', 'c', 't', 'e', 'd',
+  };
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  string read_value;
+  TF_EXPECT_OK(ReadString(test_string, 2, &read_value, &offset));
+  EXPECT_EQ("Ex", read_value);
+  EXPECT_EQ(2, offset);
+
+  TF_EXPECT_OK(ReadString(test_string, 6, &read_value, &offset));
+  EXPECT_EQ("pected", read_value);
+  EXPECT_EQ(8, offset);
+
+  Status read_status = ReadString(test_string, 3, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
+TEST(WavIO, ReadValueInt8) {
+  std::vector<uint8> test_data = {0x00, 0x05, 0xff, 0x80};
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  int8 read_value;
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(0, read_value);
+  EXPECT_EQ(1, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(5, read_value);
+  EXPECT_EQ(2, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(-1, read_value);
+  EXPECT_EQ(3, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(-128, read_value);
+  EXPECT_EQ(4, offset);
+
+  Status read_status = ReadValue(test_string, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
+TEST(WavIO, ReadValueUInt8) {
+  std::vector<uint8> test_data = {0x00, 0x05, 0xff, 0x80};
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  uint8 read_value;
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(0, read_value);
+  EXPECT_EQ(1, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(5, read_value);
+  EXPECT_EQ(2, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(255, read_value);
+  EXPECT_EQ(3, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(128, read_value);
+  EXPECT_EQ(4, offset);
+
+  Status read_status = ReadValue(test_string, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
+TEST(WavIO, ReadValueInt16) {
+  std::vector<uint8> test_data = {
+      0x00, 0x00,  // 0
+      0xff, 0x00,  // 255
+      0x00, 0x01,  // 256
+      0xff, 0xff,  // -1
+      0x00, 0x80,  // -32768
+  };
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  int16 read_value;
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(0, read_value);
+  EXPECT_EQ(2, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(255, read_value);
+  EXPECT_EQ(4, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(256, read_value);
+  EXPECT_EQ(6, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(-1, read_value);
+  EXPECT_EQ(8, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(-32768, read_value);
+  EXPECT_EQ(10, offset);
+
+  Status read_status = ReadValue(test_string, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
+TEST(WavIO, ReadValueUInt16) {
+  std::vector<uint8> test_data = {
+      0x00, 0x00,  // 0
+      0xff, 0x00,  // 255
+      0x00, 0x01,  // 256
+      0xff, 0xff,  // 65535
+      0x00, 0x80,  // 32768
+  };
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  uint16 read_value;
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(0, read_value);
+  EXPECT_EQ(2, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(255, read_value);
+  EXPECT_EQ(4, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(256, read_value);
+  EXPECT_EQ(6, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(65535, read_value);
+  EXPECT_EQ(8, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(32768, read_value);
+  EXPECT_EQ(10, offset);
+
+  Status read_status = ReadValue(test_string, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
+TEST(WavIO, ReadValueInt32) {
+  std::vector<uint8> test_data = {
+      0x00, 0x00, 0x00, 0x00,  // 0
+      0xff, 0x00, 0x00, 0x00,  // 255
+      0x00, 0xff, 0x00, 0x00,  // 65280
+      0x00, 0x00, 0xff, 0x00,  // 16,711,680
+      0xff, 0xff, 0xff, 0xff,  // -1
+  };
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  int32 read_value;
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(0, read_value);
+  EXPECT_EQ(4, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(255, read_value);
+  EXPECT_EQ(8, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(65280, read_value);
+  EXPECT_EQ(12, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(16711680, read_value);
+  EXPECT_EQ(16, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(-1, read_value);
+  EXPECT_EQ(20, offset);
+
+  Status read_status = ReadValue(test_string, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
+TEST(WavIO, ReadValueUInt32) {
+  std::vector<uint8> test_data = {
+      0x00, 0x00, 0x00, 0x00,  // 0
+      0xff, 0x00, 0x00, 0x00,  // 255
+      0x00, 0xff, 0x00, 0x00,  // 65280
+      0x00, 0x00, 0xff, 0x00,  // 16,711,680
+      0xff, 0xff, 0xff, 0xff,  // 4,294,967,295
+  };
+  string test_string(test_data.begin(), test_data.end());
+
+  int offset = 0;
+  uint32 read_value;
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(0, read_value);
+  EXPECT_EQ(4, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(255, read_value);
+  EXPECT_EQ(8, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(65280, read_value);
+  EXPECT_EQ(12, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(16711680, read_value);
+  EXPECT_EQ(16, offset);
+
+  TF_EXPECT_OK(ReadValue(test_string, &read_value, &offset));
+  EXPECT_EQ(4294967295, read_value);
+  EXPECT_EQ(20, offset);
+
+  Status read_status = ReadValue(test_string, &read_value, &offset);
+  EXPECT_FALSE(read_status.ok());
+}
+
 }  // namespace wav
 }  // namespace tensorflow