Initialize and reset for bool type in Diff.cpp (#3777)
author오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 3 Dec 2018 10:29:07 +0000 (19:29 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Mon, 3 Dec 2018 10:29:07 +0000 (19:29 +0900)
Define initializer, reseter and comparator for bool type tensor

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
libs/support/tflite/src/Diff.cpp

index e875571..8d7082e 100644 (file)
@@ -228,6 +228,14 @@ bool TfLiteInterpMatchApp::run(::tflite::Interpreter &interp, ::tflite::Interpre
     return compareSingleTensorView(expected, obtained, id);
   };
 
+  comparators[kTfLiteBool] = [this](int id, ::tflite::Interpreter &interp,
+                                    ::tflite::Interpreter &nnapi) {
+    const auto expected = nnfw::support::tflite::TensorView<bool>::make(interp, id);
+    const auto obtained = nnfw::support::tflite::TensorView<bool>::make(nnapi, id);
+
+    return compareSingleTensorView(expected, obtained, id);
+  };
+
   for (const auto &id : interp.outputs())
   {
     assert(interp.tensor(id)->type == nnapi.tensor(id)->type);
@@ -431,6 +439,58 @@ int RandomTestRunner::run(const nnfw::support::tflite::interp::Builder &builder)
            };
   };
 
+  initializers[kTfLiteBool] = [&](int id, Interpreter *tfl_interp, Interpreter *nnapi) {
+    assert(tfl_interp->tensor(id)->type == kTfLiteBool);
+    assert(nnapi->tensor(id)->type == kTfLiteBool);
+
+    auto tfl_interp_view = nnfw::support::tflite::TensorView<bool>::make(*tfl_interp, id);
+    auto nnapi_view = nnfw::support::tflite::TensorView<bool>::make(*nnapi, id);
+
+    assert(tfl_interp_view.shape() == nnapi_view.shape());
+
+    auto fp = static_cast<bool (RandomGenerator::*)(const ::nnfw::util::tensor::Shape &,
+                                                    const ::nnfw::util::tensor::Index &)>(
+        &RandomGenerator::generate<bool>);
+    const nnfw::util::tensor::Object<bool> data(tfl_interp_view.shape(),
+                                                std::bind(fp, _randgen, _1, _2));
+
+    assert(tfl_interp_view.shape() == data.shape());
+
+    nnfw::util::tensor::iterate(tfl_interp_view.shape())
+        << [&](const nnfw::util::tensor::Index &ind) {
+             const auto value = data.at(ind);
+
+             tfl_interp_view.at(ind) = value;
+             nnapi_view.at(ind) = value;
+           };
+  };
+
+  reseters[kTfLiteBool] = [&](int id, Interpreter *tfl_interp, Interpreter *nnapi) {
+    assert(tfl_interp->tensor(id)->type == kTfLiteBool);
+    assert(nnapi->tensor(id)->type == kTfLiteBool);
+
+    auto tfl_interp_view = nnfw::support::tflite::TensorView<bool>::make(*tfl_interp, id);
+    auto nnapi_view = nnfw::support::tflite::TensorView<bool>::make(*nnapi, id);
+
+    assert(tfl_interp_view.shape() == nnapi_view.shape());
+
+    auto fp = static_cast<bool (RandomGenerator::*)(const ::nnfw::util::tensor::Shape &,
+                                                    const ::nnfw::util::tensor::Index &)>(
+        &RandomGenerator::generate<bool>);
+    const nnfw::util::tensor::Object<bool> data(tfl_interp_view.shape(),
+                                                std::bind(fp, _randgen, _1, _2));
+
+    assert(tfl_interp_view.shape() == data.shape());
+
+    bool value = false;
+
+    nnfw::util::tensor::iterate(tfl_interp_view.shape())
+        << [&](const nnfw::util::tensor::Index &ind) {
+             tfl_interp_view.at(ind) = value;
+             nnapi_view.at(ind) = value;
+           };
+  };
+
   // Fill IFM with random numbers
   for (const auto id : tfl_interp->inputs())
   {