From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Fri, 19 Jul 2019 04:22:30 +0000 (+0900) Subject: [moco/tf] Record Graph-Level Input/Output Shape (#4357) X-Git-Tag: nncc_backup~1 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a282892331fec1592976b1da485f23800c1f5492;p=platform%2Fcore%2Fml%2Fnnfw.git [moco/tf] Record Graph-Level Input/Output Shape (#4357) moco.tf frontend now records graph-level input/output shape. Signed-off-by: Jonghyun Park --- diff --git a/compiler/moco-tf/src/Frontend.cpp b/compiler/moco-tf/src/Frontend.cpp index 45a741c..633d883 100644 --- a/compiler/moco-tf/src/Frontend.cpp +++ b/compiler/moco-tf/src/Frontend.cpp @@ -101,6 +101,20 @@ void load_tf(std::istream *stream, moco::tf::Frontend::FileType type, } // namespace +// TODO Find a proper place for this function +#include "Annotations/ShapeInferenceData.h" + +namespace +{ + +loco::TensorShape tensor_shape(loco::Node *node) +{ + assert(node->annot() != nullptr); + return node->annot()->tensor_shape(); +} + +} // namespace + namespace moco { namespace tf @@ -161,6 +175,21 @@ std::unique_ptr Frontend::import(const ModelSignature &signature, // Transform TFNodes tfoptimizier.optimize(graph.get()); + // Fill graph-level input/output shape + // + // ASSUMPTION! All the shapes are known at this point + for (uint32_t n = 0; n < graph->inputs()->size(); ++n) + { + auto input = graph->inputs()->at(n); + input->shape(stdex::make_unique(tensor_shape(input->node()))); + } + + for (uint32_t n = 0; n < graph->outputs()->size(); ++n) + { + auto output = graph->outputs()->at(n); + output->shape(stdex::make_unique(tensor_shape(output->node()))); + } + // Convert graph to hold only Canonical dialect Canonicalizer canonicalizer; diff --git a/compiler/moco-tf/src/Frontend.test.cpp b/compiler/moco-tf/src/Frontend.test.cpp new file mode 100644 index 0000000..57a9eb7 --- /dev/null +++ b/compiler/moco-tf/src/Frontend.test.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moco/tf/Frontend.h" + +#include "TestHelper.h" + +#include + +#include + +namespace +{ + +// clang-format off +const char *pbtxt_000 = STRING_CONTENT( +node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "dtype" + value { type: DT_FLOAT } + } + attr { + key: "shape" + value { + shape { + dim { size: 4 } + } + } + } +} +node { + name: "Identity" + op: "Identity" + input: "Placeholder" + attr { + key: "T" + value { type: DT_FLOAT } + } +} +); +// clang-format on + +} // namespace + +TEST(FrontendTests, testcase_000) +{ + moco::tf::Frontend frontend; + moco::tf::ModelSignature signature; + + signature.add_input(moco::tf::TensorName("Placeholder", 0)); + signature.add_output(moco::tf::TensorName("Identity", 0)); + + std::stringstream ss{pbtxt_000}; + + auto graph = frontend.load(signature, &ss, moco::tf::Frontend::FileType::Text); + + ASSERT_EQ(graph->inputs()->size(), 1); + ASSERT_EQ(graph->inputs()->at(0)->name(), "Placeholder"); + ASSERT_NE(graph->inputs()->at(0)->shape(), nullptr); + ASSERT_EQ(graph->inputs()->at(0)->shape()->rank(), 1); + ASSERT_EQ(graph->inputs()->at(0)->shape()->dim(0), 4); + + ASSERT_EQ(graph->outputs()->size(), 1); + ASSERT_EQ(graph->outputs()->at(0)->name(), "Identity"); + ASSERT_NE(graph->outputs()->at(0)->shape(), nullptr); + ASSERT_EQ(graph->outputs()->at(0)->shape()->rank(), 1); + ASSERT_EQ(graph->outputs()->at(0)->shape()->dim(0), 4); +}