From ee9f1306a814858ebbc26c85d1baf827ab8abb04 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 24 Jun 2019 10:20:58 +0900 Subject: [PATCH] [moco/tf] Introduce NodeShape (#3926) * [moco/tf] Introduce NodeShape This will introduce NodeShape to access shape inference result(shape information) of each loco::Node Signed-off-by: SaeHie Park * apply comments * apply another comment --- .../lib/frontend/tf/include/moco/tf/NodeShape.h | 76 ++++++++++ contrib/moco/lib/frontend/tf/src/NodeShape.cpp | 155 +++++++++++++++++++++ .../moco/lib/frontend/tf/src/NodeShape.test.cpp | 122 ++++++++++++++++ 3 files changed, 353 insertions(+) create mode 100644 contrib/moco/lib/frontend/tf/include/moco/tf/NodeShape.h create mode 100644 contrib/moco/lib/frontend/tf/src/NodeShape.cpp create mode 100644 contrib/moco/lib/frontend/tf/src/NodeShape.test.cpp diff --git a/contrib/moco/lib/frontend/tf/include/moco/tf/NodeShape.h b/contrib/moco/lib/frontend/tf/include/moco/tf/NodeShape.h new file mode 100644 index 0000000..ceae8ae --- /dev/null +++ b/contrib/moco/lib/frontend/tf/include/moco/tf/NodeShape.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MOCO_TF_NODESHAPE_H__ +#define __MOCO_TF_NODESHAPE_H__ + +#include + +#include + +namespace loco_tobe +{ + +// Temporary BiasShape that loco will provide in near future +using BiasShape = loco::Dimension; + +} // namespace loco_tobe + +namespace moco +{ +namespace tf +{ + +/** + * @brief NodeShape provides shape information of a node + */ +class NodeShape final +{ +public: + NodeShape(const loco::TensorShape &shape); + NodeShape(const loco::FeatureShape &shape); + NodeShape(const loco::FilterShape &shape); + NodeShape(const loco_tobe::BiasShape &shape); + + ~NodeShape() = default; + +public: + loco::Domain domain(void) const { return _domain; } + + loco::TensorShape tensor_shape(void) const; + loco::FeatureShape feature_shape(void) const; + loco::FilterShape filter_shape(void) const; + loco_tobe::BiasShape bias_shape(void) const; + +private: + loco::Domain _domain{loco::Domain::Unknown}; + + loco::TensorShape _tensor; +}; + +/** + * @brief node_shape() will return NodeShape object for a node + * + * @note NodeShape will be availale if node has shape information exist + * or return nullptr if not, i.e. shape inference was not executed or + * there was some missing information. + */ +std::unique_ptr node_shape(loco::Node *node); + +} // namespace tf +} // namespace moco + +#endif // __MOCO_TF_NODESHAPE_H__ diff --git a/contrib/moco/lib/frontend/tf/src/NodeShape.cpp b/contrib/moco/lib/frontend/tf/src/NodeShape.cpp new file mode 100644 index 0000000..5952c00 --- /dev/null +++ b/contrib/moco/lib/frontend/tf/src/NodeShape.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "Annotations/ShapeInferenceData.h" + +#include + +#include + +namespace moco +{ +namespace tf +{ + +NodeShape::NodeShape(const loco::TensorShape &shape) : _domain{loco::Domain::Tensor} +{ + _tensor = shape; +} + +NodeShape::NodeShape(const loco::FeatureShape &shape) : _domain{loco::Domain::Feature} +{ + _tensor.rank(4); + _tensor.dim(0) = shape.count(); + _tensor.dim(1) = shape.height(); + _tensor.dim(2) = shape.width(); + _tensor.dim(3) = shape.depth(); +} + +NodeShape::NodeShape(const loco::FilterShape &shape) : _domain{loco::Domain::Filter} +{ + _tensor.rank(4); + _tensor.dim(0) = shape.count(); + _tensor.dim(1) = shape.height(); + _tensor.dim(2) = shape.width(); + _tensor.dim(3) = shape.depth(); +} + +NodeShape::NodeShape(const loco_tobe::BiasShape &shape) : _domain{loco::Domain::Bias} +{ + _tensor.rank(1); + _tensor.dim(0) = shape; +} + +loco::TensorShape NodeShape::tensor_shape(void) const +{ + assert(_domain == loco::Domain::Tensor); + + return _tensor; +} + +loco::FeatureShape NodeShape::feature_shape(void) const +{ + assert(_domain == loco::Domain::Feature); + + loco::FeatureShape shape; + + shape.count() = _tensor.dim(0); + shape.height() = _tensor.dim(1); + shape.width() = _tensor.dim(2); + shape.depth() = _tensor.dim(3); + + return shape; +} + +loco::FilterShape NodeShape::filter_shape(void) const +{ + assert(_domain == loco::Domain::Filter); + + loco::FilterShape shape; + + shape.count() = _tensor.dim(0); + shape.height() = _tensor.dim(1); + shape.width() = _tensor.dim(2); + shape.depth() = _tensor.dim(3); + + return shape; +} + +loco_tobe::BiasShape NodeShape::bias_shape(void) const +{ + assert(_domain == loco::Domain::Bias); + + loco_tobe::BiasShape shape; + + shape = _tensor.dim(0); + + return shape; +} + +std::unique_ptr node_shape(loco::Node *node) +{ + assert(node != nullptr); + + auto shapedata = node->annot(); + if (shapedata != nullptr) + { + switch (shapedata->domain()) + { + case loco::Domain::Tensor: + { + loco::TensorShape shape = shapedata->tensor_shape(); + std::unique_ptr node_shape = stdex::make_unique(shape); + return std::move(node_shape); + } + break; + + case loco::Domain::Feature: + { + loco::FeatureShape shape = shapedata->feature_shape(); + std::unique_ptr node_shape = stdex::make_unique(shape); + return std::move(node_shape); + } + break; + + case loco::Domain::Filter: + { + loco::FilterShape shape = shapedata->filter_shape(); + std::unique_ptr node_shape = stdex::make_unique(shape); + return std::move(node_shape); + } + break; + + case loco::Domain::Bias: + { + loco_tobe::BiasShape shape = shapedata->bias_shape(); + std::unique_ptr node_shape = stdex::make_unique(shape); + return std::move(node_shape); + } + break; + + default: + throw std::runtime_error("Not supported loco::Domain"); + } + } + + return nullptr; +} + +} // namespace tf +} // namespace moco diff --git a/contrib/moco/lib/frontend/tf/src/NodeShape.test.cpp b/contrib/moco/lib/frontend/tf/src/NodeShape.test.cpp new file mode 100644 index 0000000..d47f3f9 --- /dev/null +++ b/contrib/moco/lib/frontend/tf/src/NodeShape.test.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "Annotations/ShapeInferenceData.h" + +#include +#include + +#include + +TEST(TensorFlowFrontend, nodeshape_tensor_access) +{ + loco::TensorShape source; + + source.rank(3); + source.dim(0) = 1; + source.dim(1) = 2; + source.dim(2) = 3; + + moco::tf::NodeShape nodeshape(source); + + loco::TensorShape tensor = nodeshape.tensor_shape(); + + ASSERT_EQ(tensor.rank(), 3); + ASSERT_EQ(tensor.dim(0), 1); + ASSERT_EQ(tensor.dim(1), 2); + ASSERT_EQ(tensor.dim(2), 3); +} + +TEST(TensorFlowFrontend, nodeshape_feature_access) +{ + loco::FeatureShape source; + + source.count() = 1; + source.height() = 2; + source.width() = 3; + source.depth() = 4; + + moco::tf::NodeShape nodeshape(source); + + loco::FeatureShape feature = nodeshape.feature_shape(); + + ASSERT_EQ(feature.count(), 1); + ASSERT_EQ(feature.height(), 2); + ASSERT_EQ(feature.width(), 3); + ASSERT_EQ(feature.depth(), 4); +} + +TEST(TensorFlowFrontend, nodeshape_filter_access) +{ + loco::FilterShape source; + + source.count() = 1; + source.height() = 2; + source.width() = 3; + source.depth() = 4; + + moco::tf::NodeShape nodeshape(source); + + loco::FilterShape filter = nodeshape.filter_shape(); + + ASSERT_EQ(filter.count(), 1); + ASSERT_EQ(filter.height(), 2); + ASSERT_EQ(filter.width(), 3); + ASSERT_EQ(filter.depth(), 4); +} + +TEST(TensorFlowFrontend, nodeshape_bias_access) +{ + loco_tobe::BiasShape source; + + source = 3; + + moco::tf::NodeShape nodeshape(source); + + loco_tobe::BiasShape bias = nodeshape.bias_shape(); + + ASSERT_EQ(bias.value(), 3); +} + +TEST(TensorFlowFrontend, featureshape_from_shapeinfdata) +{ + // Prepare a node and annotate ShapeInferenceData + loco::Graph graph; + auto some_node = graph.nodes()->create(); + loco::FeatureShape fshape; + + fshape.count() = 1; + fshape.height() = 2; + fshape.width() = 3; + fshape.depth() = 4; + + auto shapedata = stdex::make_unique(); + shapedata->feature_shape(fshape); + some_node->annot(std::move(shapedata)); + + // Get NodeShape and check the values + std::unique_ptr nodeshape = moco::tf::node_shape(some_node); + ASSERT_NE(nodeshape.get(), nullptr); + ASSERT_EQ(nodeshape->domain(), loco::Domain::Feature); + + auto read_shape = nodeshape->feature_shape(); + ASSERT_EQ(read_shape.count(), 1); + ASSERT_EQ(read_shape.height(), 2); + ASSERT_EQ(read_shape.width(), 3); + ASSERT_EQ(read_shape.depth(), 4); +} -- 2.7.4