From 0b8ec0018f2f4780c87eefc412165e4f38d83341 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=B5=9C=EC=84=B1=EC=A7=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Tue, 5 Jun 2018 09:57:39 +0900 Subject: [PATCH] Add int32 initialization to tflite_run (#1550) This commit adds int32 initialization to tflite_run. -With gather operation, index should be within input coverage. Signed-off-by: SungJin Choi --- tools/tflite_run/src/tflite_run.cc | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/tools/tflite_run/src/tflite_run.cc b/tools/tflite_run/src/tflite_run.cc index 90e792e..96249fd 100644 --- a/tools/tflite_run/src/tflite_run.cc +++ b/tools/tflite_run/src/tflite_run.cc @@ -29,6 +29,7 @@ #include "support/tflite/Session.h" #include "support/tflite/InterpreterSession.h" #include "support/tflite/NNAPISession.h" +#include "util/tensor/IndexIterator.h" #include #include @@ -118,13 +119,30 @@ int main(const int argc, char **argv) for (const auto &o : interpreter->inputs()) { TfLiteTensor *tensor = interpreter->tensor(o); - - const int seed = 1; /* TODO Add an option for seed value */ - RandomGenerator randgen{seed, 0.0f, 0.2f}; - const float *end = reinterpret_cast(tensor->data.raw_const + tensor->bytes); - for (float *ptr = tensor->data.f; ptr < end; ptr++) + if (tensor->type == kTfLiteInt32) + { + // Generate singed 32-bit integer (s32) input + auto tensor_view = nnfw::support::tflite::TensorView::make(*interpreter, o); + + int32_t value = 0; + + nnfw::util::tensor::iterate(tensor_view.shape()) + << [&](const nnfw::util::tensor::Index &ind) { + // TODO Generate random values + // Gather operation: index should be within input coverage. + tensor_view.at(ind) = value; + value++; + }; + } + else { - *ptr = randgen(); + const int seed = 1; /* TODO Add an option for seed value */ + RandomGenerator randgen{seed, 0.0f, 0.2f}; + const float *end = reinterpret_cast(tensor->data.raw_const + tensor->bytes); + for (float *ptr = tensor->data.f; ptr < end; ptr++) + { + *ptr = randgen(); + } } } } -- 2.7.4