From c04ed020ac2c9440b9d20f4c49b9f0229b07c236 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 2 Sep 2019 13:16:38 +0900 Subject: [PATCH] [exo-tflite] Change the name of input of TFLRelu (#7081) * [exo-tflite] Change the name of input of TFLRelu Input name follows TensorFlow input naming. (input -> features) Signed-off-by: Hyun Sik Yoon * modify input -> features --- compiler/exo-tflite/src/Dialect/IR/TFLNodes.h | 4 ++-- compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp | 2 +- .../exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp | 2 +- compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp | 2 +- compiler/exo-tflite/src/OperationExporter.cpp | 2 +- compiler/exo-tflite/src/TFLFormattedGraph.cpp | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index de3b0a5..31184a9 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -164,8 +164,8 @@ public: TFLRelu() = default; public: - loco::Node *input(void) const { return at(0)->node(); } - void input(loco::Node *node) { at(0)->node(node); } + loco::Node *features(void) const { return at(0)->node(); } + void features(loco::Node *node) { at(0)->node(node); } }; // TODO TFLRelu6 diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp index 90c98f3..6f0abc8 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp @@ -44,7 +44,7 @@ TEST(TFLReluTest, constructor) ASSERT_EQ(relu_node.dialect(), locoex::TFLDialect::get()); ASSERT_EQ(relu_node.opcode(), locoex::TFLOpcode::RELU); - ASSERT_EQ(relu_node.input(), nullptr); + ASSERT_EQ(relu_node.features(), nullptr); } // TODO TFLRelu6 diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp index de2e966..3bb24f9 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp @@ -33,7 +33,7 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu) auto pull_node = g->nodes()->create(); auto tfl_node = g->nodes()->create(); - tfl_node->input(pull_node); + tfl_node->features(pull_node); auto push_node = g->nodes()->create(); push_node->from(tfl_node); diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp index 180dbe2..0190eeb 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp @@ -32,7 +32,7 @@ TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu) auto pull_node = g->nodes()->create(); auto tfl_node = g->nodes()->create(); - tfl_node->input(pull_node); + tfl_node->features(pull_node); auto push_node = g->nodes()->create(); push_node->from(tfl_node); diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index 91fa2f8..a569052 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -115,7 +115,7 @@ private: void OperationExporter::visit(locoex::TFLRelu *node) { uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU); - std::vector inputs_vec{get_tensor_index(node->input())}; + std::vector inputs_vec{get_tensor_index(node->features())}; std::vector outputs_vec{get_tensor_index(static_cast(node))}; auto inputs = builder.CreateVector(inputs_vec); auto outputs = builder.CreateVector(outputs_vec); diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 0f29aa3..cb0a6f3 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -116,7 +116,7 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSummary &s) const { s.opname("TFL.RELU"); - s.args().append("input", tbl()->lookup(node->input())); + s.args().append("input", tbl()->lookup(node->features())); s.state(locop::NodeSummary::State::Complete); return true; } -- 2.7.4