From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Tue, 9 Jul 2019 00:30:15 +0000 (+0900) Subject: [exo.tflite] Properly propagate type information (#4144) X-Git-Tag: nncc_backup~148 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=69cb136492e1a0476d00ac3171f0d7d9bbbe535b;p=platform%2Fcore%2Fml%2Fnnfw.git [exo.tflite] Properly propagate type information (#4144) The current implementation of type inference fails to properly propagate type information. Signed-off-by: Jonghyun Park --- diff --git a/contrib/exo-tflite/src/TypeInference.cpp b/contrib/exo-tflite/src/TypeInference.cpp index f5bd4ab..b4e693f 100644 --- a/contrib/exo-tflite/src/TypeInference.cpp +++ b/contrib/exo-tflite/src/TypeInference.cpp @@ -170,6 +170,7 @@ public: { \ auto t = getOpResultType(node, _ctx); \ node->annot(stdex::make_unique(t)); \ + _ctx._node_to_type[node] = t; \ } NODE(ConstGen) NODE(Pull) diff --git a/contrib/exo-tflite/src/TypeInference.test.cpp b/contrib/exo-tflite/src/TypeInference.test.cpp new file mode 100644 index 0000000..3a922e8 --- /dev/null +++ b/contrib/exo-tflite/src/TypeInference.test.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TypeInference.h" + +#include +#include + +#include + +using stdex::make_unique; + +namespace +{ + +class Sequential +{ +public: + loco::Pull *addPullLayer(const loco::DataType &dtype = loco::DataType::FLOAT32) + { + loco::Pull *pull = _graph.nodes()->create(); + + auto graph_input = _graph.inputs()->create(); + graph_input->name("graph_input"); + graph_input->node(pull); + + pull->dtype(dtype); + setSampleShape(pull); + + return last(pull); + } + + loco::ReLU *addReLULayer(void) + { + loco::ReLU *relu = _graph.nodes()->create(); + + relu->input(_last); + + return last(relu); + } + + loco::Push *addPushLayer(void) + { + loco::Push *push = _graph.nodes()->create(); + + auto graph_output = _graph.outputs()->create(); + graph_output->name("graph_output"); + graph_output->node(push); + + push->from(_last); + + return last(push); + } + + loco::Graph *graph() { return &_graph; } + +private: + template uint32_t setSampleShape(T *op) + { + const uint32_t n = 1; + const uint32_t h = 100; + const uint32_t w = 100; + const uint32_t c = 3; + op->rank(4); + op->dim(0).set(n); + op->dim(1).set(c); + op->dim(2).set(h); + op->dim(3).set(w); + return n * h * w * c; + } + + template T *last(T *node) + { + _last = node; + return node; + } + +private: + loco::Graph _graph; + loco::Node *_last; +}; + +struct TypeInferenceTest : public Sequential, public ::testing::Test +{ + virtual ~TypeInferenceTest() = default; +}; + +} // namespace + +// TypeInference SHOULD PROPAGATE type information properly +TEST_F(TypeInferenceTest, Regression_0000) +{ + auto pull = addPullLayer(loco::DataType::S8); + auto relu = addReLULayer(); + auto push = addPushLayer(); + + TypeInference::run(graph()); + + ASSERT_EQ(TypeInference::get(relu), tflite::TensorType_INT8); + ASSERT_EQ(TypeInference::get(push), tflite::TensorType_INT8); +}