return reinterpret_cast<jlong>(error_reporter);
}
+// Verifies whether the model is a flatbuffer file.
+class JNIFlatBufferVerifier : public tflite::TfLiteVerifier {
+ public:
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override {
+ if (!VerifyModel(data, length)) {
+ reporter->Report("The model is not a valid Flatbuffer file");
+ return false;
+ }
+ return true;
+ }
+};
+
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
if (error_reporter == nullptr) return 0;
const char* path = env->GetStringUTFChars(model_file, nullptr);
- {
- tflite::FileCopyAllocation allocation(path, nullptr);
- if (!VerifyModel(allocation.base(), allocation.bytes())) {
- throwException(env, kIllegalArgumentException,
- "Contents of %s is not a valid flatbuffer model", path);
- env->ReleaseStringUTFChars(model_file, path);
- return 0;
- }
- }
+ std::unique_ptr<tflite::TfLiteVerifier> verifier;
+ verifier.reset(new JNIFlatBufferVerifier());
- auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
+ auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
+ path, verifier.get(), error_reporter);
if (!model) {
throwException(env, kIllegalArgumentException,
"Contents of %s does not encode a valid TensorFlowLite "
private static final String MODEL_WITH_CUSTOM_OP_PATH =
"tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite";
+ private static final String NONEXISTING_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin";
+
@Test
public void testConstructor() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model");
+ assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+ }
+ }
+
+ @Test
+ public void testConstructorWithNonexistingModel() {
+ try {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+ assertThat(e).hasMessageThat().contains("Could not open");
}
}
const char* kEmptyTensorName = "";
+// Loads a model from `filename`. If `mmap_file` is true then use mmap,
+// otherwise make a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
+ bool mmap_file,
+ ErrorReporter* error_reporter,
+ bool use_nnapi) {
+ std::unique_ptr<Allocation> allocation;
+ if (mmap_file) {
+ if (use_nnapi && NNAPIExists())
+ allocation.reset(new NNAPIAllocation(filename, error_reporter));
+ else
+ allocation.reset(new MMAPAllocation(filename, error_reporter));
+ } else {
+ allocation.reset(new FileCopyAllocation(filename, error_reporter));
+ }
+ return allocation;
+}
+
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model;
- model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
- /*use_nnapi=*/true));
+ auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+ error_reporter, /*use_nnapi=*/true);
+ model.reset(new FlatBufferModel(allocation.release(), error_reporter));
+ if (!model->initialized()) model.reset();
+ return model;
+}
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
+ const char* filename, TfLiteVerifier* verifier,
+ ErrorReporter* error_reporter) {
+ std::unique_ptr<FlatBufferModel> model;
+ auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+ error_reporter, /*use_nnapi=*/true);
+ if (verifier &&
+ !verifier->Verify(static_cast<const char*>(allocation->base()),
+ allocation->bytes(), error_reporter)) {
+ return model;
+ }
+ model.reset(new FlatBufferModel(allocation.release(), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model;
- model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
+ Allocation* allocation =
+ new MemoryAllocation(buffer, buffer_size, error_reporter);
+ model.reset(new FlatBufferModel(allocation, error_reporter));
if (!model->initialized()) model.reset();
return model;
}
return model;
}
-FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
- ErrorReporter* error_reporter, bool use_nnapi)
- : error_reporter_(error_reporter ? error_reporter
- : DefaultErrorReporter()) {
- if (mmap_file) {
- if (use_nnapi && NNAPIExists())
- allocation_ = new NNAPIAllocation(filename, error_reporter);
- else
- allocation_ = new MMAPAllocation(filename, error_reporter);
- } else {
- allocation_ = new FileCopyAllocation(filename, error_reporter);
- }
- if (!allocation_->valid() || !CheckModelIdentifier()) return;
-
- model_ = ::tflite::GetModel(allocation_->base());
-}
-
bool FlatBufferModel::CheckModelIdentifier() const {
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
return true;
}
-FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
+FlatBufferModel::FlatBufferModel(const Model* model,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
- allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
- if (!allocation_->valid()) return;
-
- model_ = ::tflite::GetModel(allocation_->base());
+ model_ = model;
}
-FlatBufferModel::FlatBufferModel(const Model* model,
+FlatBufferModel::FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
- model_ = model;
+ allocation_ = allocation;
+ if (!allocation_->valid() || !CheckModelIdentifier()) return;
+
+ model_ = ::tflite::GetModel(allocation_->base());
}
FlatBufferModel::~FlatBufferModel() { delete allocation_; }
namespace tflite {
+// Abstract interface that verifies whether a given model is legit.
+// It facilitates the use-case to verify and build a model without loading it
+// twice.
+class TfLiteVerifier {
+ public:
+ // Returns true if the model is legit.
+ virtual bool Verify(const char* data, int length,
+ ErrorReporter* reporter) = 0;
+ virtual ~TfLiteVerifier() {}
+};
+
// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ // Verifies whether the content of the file is legit, then builds a model
+ // based on the file. Returns a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
+ const char* filename, TfLiteVerifier* verifier = nullptr,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
// Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Returns a nullptr in case of failure.
bool CheckModelIdentifier() const;
private:
- // Loads a model from `filename`. If `mmap_file` is true then use mmap,
- // otherwise make a copy of the model in a buffer.
- //
- // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
- // used.
- explicit FlatBufferModel(
- const char* filename, bool mmap_file = true,
- ErrorReporter* error_reporter = DefaultErrorReporter(),
- bool use_nnapi = false);
-
- // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
- // to remain alive and unchanged until the end of this flatbuffermodel's
- // lifetime.
- //
- // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
- // used.
- FlatBufferModel(const char* ptr, size_t num_bytes,
+ // Loads a model from a given allocation. FlatBufferModel will take over the
+ // ownership of `allocation`, and delete it in desctructor.
+ FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Loads a model from Model flatbuffer. The `model` has to remain alive and
ASSERT_EQ(interpreter.get(), nullptr);
}
+// Mocks the verifier by setting the result in ctor.
+class FakeVerifier : public tflite::TfLiteVerifier {
+ public:
+ explicit FakeVerifier(bool result) : result_(result) {}
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override {
+ return result_;
+ }
+
+ private:
+ bool result_;
+};
+
+TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
+ FakeVerifier verifier(true);
+ ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin",
+ &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
+ FakeVerifier verifier(false);
+ ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin",
+ &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithNullVerifier) {
+ ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
+}
+
struct TestErrorReporter : public ErrorReporter {
int Report(const char* format, va_list args) override {
calls++;