Pass error reporter to file copy allocation,
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 15 Mar 2018 22:29:39 +0000 (15:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 22:33:32 +0000 (15:33 -0700)
and avoid loading model from file twice

PiperOrigin-RevId: 189256489

tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/model.h
tensorflow/contrib/lite/model_test.cc

index 21bcff4..cc448b0 100644 (file)
@@ -334,6 +334,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
   return reinterpret_cast<jlong>(error_reporter);
 }
 
+// Verifies whether the model is a flatbuffer file.
+class JNIFlatBufferVerifier : public tflite::TfLiteVerifier {
+ public:
+  bool Verify(const char* data, int length,
+              tflite::ErrorReporter* reporter) override {
+    if (!VerifyModel(data, length)) {
+      reporter->Report("The model is not a valid Flatbuffer file");
+      return false;
+    }
+    return true;
+  }
+};
+
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
     JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
@@ -342,17 +355,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
   if (error_reporter == nullptr) return 0;
   const char* path = env->GetStringUTFChars(model_file, nullptr);
 
-  {
-    tflite::FileCopyAllocation allocation(path, nullptr);
-    if (!VerifyModel(allocation.base(), allocation.bytes())) {
-      throwException(env, kIllegalArgumentException,
-                     "Contents of %s is not a valid flatbuffer model", path);
-      env->ReleaseStringUTFChars(model_file, path);
-      return 0;
-    }
-  }
+  std::unique_ptr<tflite::TfLiteVerifier> verifier;
+  verifier.reset(new JNIFlatBufferVerifier());
 
-  auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
+  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
+      path, verifier.get(), error_reporter);
   if (!model) {
     throwException(env, kIllegalArgumentException,
                    "Contents of %s does not encode a valid TensorFlowLite "
index d6b4e9f..dbe45e5 100644 (file)
@@ -47,6 +47,9 @@ public final class NativeInterpreterWrapperTest {
   private static final String MODEL_WITH_CUSTOM_OP_PATH =
       "tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite";
 
+  private static final String NONEXISTING_MODEL_PATH =
+      "tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin";
+
   @Test
   public void testConstructor() {
     NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
@@ -60,7 +63,18 @@ public final class NativeInterpreterWrapperTest {
       NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
       fail();
     } catch (IllegalArgumentException e) {
-      assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model");
+      assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+    }
+  }
+
+  @Test
+  public void testConstructorWithNonexistingModel() {
+    try {
+      NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
+      fail();
+    } catch (IllegalArgumentException e) {
+      assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+      assertThat(e).hasMessageThat().contains("Could not open");
     }
   }
 
index 3cf6bcb..f28d56a 100644 (file)
@@ -32,11 +32,46 @@ namespace tflite {
 
 const char* kEmptyTensorName = "";
 
+// Loads a model from `filename`. If `mmap_file` is true then use mmap,
+// otherwise make a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
+                                                  bool mmap_file,
+                                                  ErrorReporter* error_reporter,
+                                                  bool use_nnapi) {
+  std::unique_ptr<Allocation> allocation;
+  if (mmap_file) {
+    if (use_nnapi && NNAPIExists())
+      allocation.reset(new NNAPIAllocation(filename, error_reporter));
+    else
+      allocation.reset(new MMAPAllocation(filename, error_reporter));
+  } else {
+    allocation.reset(new FileCopyAllocation(filename, error_reporter));
+  }
+  return allocation;
+}
+
 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
     const char* filename, ErrorReporter* error_reporter) {
   std::unique_ptr<FlatBufferModel> model;
-  model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
-                                  /*use_nnapi=*/true));
+  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+                                          error_reporter, /*use_nnapi=*/true);
+  model.reset(new FlatBufferModel(allocation.release(), error_reporter));
+  if (!model->initialized()) model.reset();
+  return model;
+}
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
+    const char* filename, TfLiteVerifier* verifier,
+    ErrorReporter* error_reporter) {
+  std::unique_ptr<FlatBufferModel> model;
+  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+                                          error_reporter, /*use_nnapi=*/true);
+  if (verifier &&
+      !verifier->Verify(static_cast<const char*>(allocation->base()),
+                        allocation->bytes(), error_reporter)) {
+    return model;
+  }
+  model.reset(new FlatBufferModel(allocation.release(), error_reporter));
   if (!model->initialized()) model.reset();
   return model;
 }
@@ -44,7 +79,9 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
     const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
   std::unique_ptr<FlatBufferModel> model;
-  model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
+  Allocation* allocation =
+      new MemoryAllocation(buffer, buffer_size, error_reporter);
+  model.reset(new FlatBufferModel(allocation, error_reporter));
   if (!model->initialized()) model.reset();
   return model;
 }
@@ -57,23 +94,6 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
   return model;
 }
 
-FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
-                                 ErrorReporter* error_reporter, bool use_nnapi)
-    : error_reporter_(error_reporter ? error_reporter
-                                     : DefaultErrorReporter()) {
-  if (mmap_file) {
-    if (use_nnapi && NNAPIExists())
-      allocation_ = new NNAPIAllocation(filename, error_reporter);
-    else
-      allocation_ = new MMAPAllocation(filename, error_reporter);
-  } else {
-    allocation_ = new FileCopyAllocation(filename, error_reporter);
-  }
-  if (!allocation_->valid() || !CheckModelIdentifier()) return;
-
-  model_ = ::tflite::GetModel(allocation_->base());
-}
-
 bool FlatBufferModel::CheckModelIdentifier() const {
   if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
     const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
@@ -85,21 +105,21 @@ bool FlatBufferModel::CheckModelIdentifier() const {
   return true;
 }
 
-FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
+FlatBufferModel::FlatBufferModel(const Model* model,
                                  ErrorReporter* error_reporter)
     : error_reporter_(error_reporter ? error_reporter
                                      : DefaultErrorReporter()) {
-  allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
-  if (!allocation_->valid()) return;
-
-  model_ = ::tflite::GetModel(allocation_->base());
+  model_ = model;
 }
 
-FlatBufferModel::FlatBufferModel(const Model* model,
+FlatBufferModel::FlatBufferModel(Allocation* allocation,
                                  ErrorReporter* error_reporter)
     : error_reporter_(error_reporter ? error_reporter
                                      : DefaultErrorReporter()) {
-  model_ = model;
+  allocation_ = allocation;
+  if (!allocation_->valid() || !CheckModelIdentifier()) return;
+
+  model_ = ::tflite::GetModel(allocation_->base());
 }
 
 FlatBufferModel::~FlatBufferModel() { delete allocation_; }
index 8dc1c79..38eea0e 100644 (file)
@@ -41,6 +41,17 @@ limitations under the License.
 
 namespace tflite {
 
+// Abstract interface that verifies whether a given model is legit.
+// It facilitates the use-case to verify and build a model without loading it
+// twice.
+class TfLiteVerifier {
+ public:
+  // Returns true if the model is legit.
+  virtual bool Verify(const char* data, int length,
+                      ErrorReporter* reporter) = 0;
+  virtual ~TfLiteVerifier() {}
+};
+
 // An RAII object that represents a read-only tflite model, copied from disk,
 // or mmapped. This uses flatbuffers as the serialization format.
 class FlatBufferModel {
@@ -50,6 +61,12 @@ class FlatBufferModel {
       const char* filename,
       ErrorReporter* error_reporter = DefaultErrorReporter());
 
+  // Verifies whether the content of the file is legit, then builds a model
+  // based on the file. Returns a nullptr in case of failure.
+  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
+      const char* filename, TfLiteVerifier* verifier = nullptr,
+      ErrorReporter* error_reporter = DefaultErrorReporter());
+
   // Builds a model based on a pre-loaded flatbuffer. The caller retains
   // ownership of the buffer and should keep it alive until the returned object
   // is destroyed. Returns a nullptr in case of failure.
@@ -82,23 +99,9 @@ class FlatBufferModel {
   bool CheckModelIdentifier() const;
 
  private:
-  // Loads a model from `filename`. If `mmap_file` is true then use mmap,
-  // otherwise make a copy of the model in a buffer.
-  //
-  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
-  // used.
-  explicit FlatBufferModel(
-      const char* filename, bool mmap_file = true,
-      ErrorReporter* error_reporter = DefaultErrorReporter(),
-      bool use_nnapi = false);
-
-  // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
-  // to remain alive and unchanged until the end of this flatbuffermodel's
-  // lifetime.
-  //
-  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
-  // used.
-  FlatBufferModel(const char* ptr, size_t num_bytes,
+  // Loads a model from a given allocation. FlatBufferModel will take over the
+  // ownership of `allocation`, and delete it in desctructor.
+  FlatBufferModel(Allocation* allocation,
                   ErrorReporter* error_reporter = DefaultErrorReporter());
 
   // Loads a model from Model flatbuffer. The `model` has to remain alive and
index 66f22fd..ae6c1ec 100644 (file)
@@ -209,6 +209,38 @@ TEST(BasicFlatBufferModel, TestNullModel) {
   ASSERT_EQ(interpreter.get(), nullptr);
 }
 
+// Mocks the verifier by setting the result in ctor.
+class FakeVerifier : public tflite::TfLiteVerifier {
+ public:
+  explicit FakeVerifier(bool result) : result_(result) {}
+  bool Verify(const char* data, int length,
+              tflite::ErrorReporter* reporter) override {
+    return result_;
+  }
+
+ private:
+  bool result_;
+};
+
+TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
+  FakeVerifier verifier(true);
+  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+      "tensorflow/contrib/lite/testdata/test_model.bin",
+      &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
+  FakeVerifier verifier(false);
+  ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
+      "tensorflow/contrib/lite/testdata/test_model.bin",
+      &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithNullVerifier) {
+  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+      "tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
+}
+
 struct TestErrorReporter : public ErrorReporter {
   int Report(const char* format, va_list args) override {
     calls++;