Make sure tensor size match before inspecting their content.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 20:27:07 +0000 (13:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 20:29:38 +0000 (13:29 -0700)
PiperOrigin-RevId: 190823557

tensorflow/contrib/lite/testing/tflite_driver.cc

index 613223f..c399f4f 100644 (file)
@@ -56,12 +56,16 @@ void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
 
 class TfLiteDriver::Expectation {
  public:
-  Expectation() { data_.raw = nullptr; }
+  Expectation() {
+    data_.raw = nullptr;
+    num_elements_ = 0;
+  }
   ~Expectation() { delete[] data_.raw; }
   template <typename T>
   void SetData(const string& csv_values) {
     const auto& values = testing::Split<T>(csv_values, ",");
-    data_.raw = new char[values.size() * sizeof(T)];
+    num_elements_ = values.size();
+    data_.raw = new char[num_elements_ * sizeof(T)];
     SetTensorData(values, &data_);
   }
 
@@ -88,7 +92,13 @@ class TfLiteDriver::Expectation {
     constexpr double kRelativeThreshold = 1e-2f;
     constexpr double kAbsoluteThreshold = 1e-4f;
 
-    int tensor_size = tensor.bytes / sizeof(T);
+    size_t tensor_size = tensor.bytes / sizeof(T);
+
+    if (tensor_size != num_elements_) {
+      std::cerr << "Expected a tensor with " << num_elements_
+                << " elements, got " << tensor_size << std::endl;
+      return false;
+    }
 
     bool good_output = true;
     for (int i = 0; i < tensor_size; ++i) {
@@ -115,6 +125,7 @@ class TfLiteDriver::Expectation {
   }
 
   TfLitePtrUnion data_;
+  size_t num_elements_;
 };
 
 TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}