From 8be015a581758b9ecd7840a1224bd4667b369a68 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 16 May 2019 10:54:15 +0900 Subject: [PATCH] [moco] import Const attributes and value data (#3487) This will enable import for Const attributes and data where there are same number of const value data. Add test code for this case. Signed-off-by: SaeHie Park --- contrib/moco/lib/frontend/tf/src/Op/Const.cpp | 75 ++++++++++++++- contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp | 106 +++++++++++++++++++++ 2 files changed, 178 insertions(+), 3 deletions(-) create mode 100644 contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp diff --git a/contrib/moco/lib/frontend/tf/src/Op/Const.cpp b/contrib/moco/lib/frontend/tf/src/Op/Const.cpp index d98e0f9..e460529 100644 --- a/contrib/moco/lib/frontend/tf/src/Op/Const.cpp +++ b/contrib/moco/lib/frontend/tf/src/Op/Const.cpp @@ -25,6 +25,31 @@ #include #include +namespace +{ + +void read_value_float32(loco::ConstGen *const_node, int num_elements, + const tensorflow::TensorProto &input_tensor) +{ + const_node->size(num_elements); + + int32_t input_elements = input_tensor.float_val_size(); + + if (input_elements == num_elements) + { + for (int32_t i = 0; i < input_elements; i++) + { + const_node->at(i) = input_tensor.float_val(i); + } + } + else + { + throw std::runtime_error("Error: Invalid Const values"); + } +} + +} // namespace + namespace moco { namespace tf @@ -60,11 +85,55 @@ void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte // Create a "ConstGen" node for Const auto const_node = graph->nodes()->create(); - // TODO fill attributes + // set dtype + auto dtype = as_loco_datatype(get_datatype_attr(node, "dtype")); + const_node->dtype(dtype); + + // import shape and value + const auto &input_tensor = get_tensor_attr(node, "value"); + const auto &input_shape = input_tensor.tensor_shape(); + const auto &input_dims = input_shape.dim(); + assert(input_shape.dim_size() <= 6); + const_node->rank(input_shape.dim_size()); + int index = 0; + bool zero_sized_shape = false; + for (auto &d : input_dims) + { + if (d.size() > std::numeric_limits::max()) + throw std::runtime_error("Shape element overflows"); + if (d.size() == 0) + zero_sized_shape = true; + + const_node->dim(index++) = loco::make_dimension(d.size()); + } + + int num_elements = 1; + if (zero_sized_shape) + { + const_node->rank(0); + num_elements = 0; + } + else + { + for (int d = 0; d < const_node->rank(); d++) + { + num_elements *= const_node->dim(d).value(); + } + } + + switch (dtype) + { + case loco::DataType::FLOAT32: + read_value_float32(const_node, num_elements, input_tensor); + break; + + // TODO support other types + + default: + throw std::runtime_error{"Error: Unsupported data type for " + node.name()}; + } nodes->enroll(node.name(), const_node); - - throw std::runtime_error{"NYI"}; } } // namespace tf diff --git a/contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp b/contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp new file mode 100644 index 0000000..5afff9b --- /dev/null +++ b/contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TestHelper.h" + +#include + +#include + +#include + +#include + +#include +#include + +using namespace moco::tf::test; + +namespace +{ +// Test case for "input_tensor.float_val_size() == num_elements" + +// clang-format off +const char *const_float_01_pbtxtdata = STRING_CONTENT( +node { + name: "const/float" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 3 + } + } + float_val: 1.1 + float_val: 2.2 + float_val: 3.3 + float_val: 4.4 + float_val: 5.5 + float_val: 6.6 + } + } + } +} +); +// clang-format on + +} // namespace + +TEST(TensorFlowFrontend, const_float_01) +{ + moco::tf::Frontend frontend; + moco::tf::ModelSignature signature; + + imemstream mempb(const_float_01_pbtxtdata, std::strlen(const_float_01_pbtxtdata)); + + signature.add_output("const/float"); + + std::unique_ptr graph = + frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text); + + loco::Graph::OutputContext *outputs = graph->outputs(); + ASSERT_EQ(outputs->size(), 1); + loco::GraphOutput *output = outputs->at(0); + loco::Push *push = output->node(); + + loco::Graph::NodeContext *nodes = graph->nodes(); + ASSERT_EQ(nodes->size(), 2); + loco::ConstGen *node0 = dynamic_cast(nodes->at(0)); + ASSERT_NE(node0, nullptr); + loco::Push *node1 = dynamic_cast(nodes->at(1)); + ASSERT_EQ(node1, push); + + ASSERT_EQ(node0->size(), 6); + ASSERT_EQ(node0->at(0), 1.1f); + ASSERT_EQ(node0->at(1), 2.2f); + ASSERT_EQ(node0->at(2), 3.3f); + ASSERT_EQ(node0->at(3), 4.4f); + ASSERT_EQ(node0->at(4), 5.5f); + ASSERT_EQ(node0->at(5), 6.6f); +} -- 2.7.4