--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_NODESHAPE_H__
+#define __MOCO_TF_NODESHAPE_H__
+
+#include <loco.h>
+
+#include <memory.h>
+
+namespace loco_tobe
+{
+
+// Temporary BiasShape that loco will provide in near future
+using BiasShape = loco::Dimension;
+
+} // namespace loco_tobe
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief NodeShape provides shape information of a node
+ */
+class NodeShape final
+{
+public:
+ NodeShape(const loco::TensorShape &shape);
+ NodeShape(const loco::FeatureShape &shape);
+ NodeShape(const loco::FilterShape &shape);
+ NodeShape(const loco_tobe::BiasShape &shape);
+
+ ~NodeShape() = default;
+
+public:
+ loco::Domain domain(void) const { return _domain; }
+
+ loco::TensorShape tensor_shape(void) const;
+ loco::FeatureShape feature_shape(void) const;
+ loco::FilterShape filter_shape(void) const;
+ loco_tobe::BiasShape bias_shape(void) const;
+
+private:
+ loco::Domain _domain{loco::Domain::Unknown};
+
+ loco::TensorShape _tensor;
+};
+
+/**
+ * @brief node_shape() will return NodeShape object for a node
+ *
+ * @note NodeShape will be availale if node has shape information exist
+ * or return nullptr if not, i.e. shape inference was not executed or
+ * there was some missing information.
+ */
+std::unique_ptr<NodeShape> node_shape(loco::Node *node);
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_NODESHAPE_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <moco/tf/NodeShape.h>
+
+#include "Annotations/ShapeInferenceData.h"
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+NodeShape::NodeShape(const loco::TensorShape &shape) : _domain{loco::Domain::Tensor}
+{
+ _tensor = shape;
+}
+
+NodeShape::NodeShape(const loco::FeatureShape &shape) : _domain{loco::Domain::Feature}
+{
+ _tensor.rank(4);
+ _tensor.dim(0) = shape.count();
+ _tensor.dim(1) = shape.height();
+ _tensor.dim(2) = shape.width();
+ _tensor.dim(3) = shape.depth();
+}
+
+NodeShape::NodeShape(const loco::FilterShape &shape) : _domain{loco::Domain::Filter}
+{
+ _tensor.rank(4);
+ _tensor.dim(0) = shape.count();
+ _tensor.dim(1) = shape.height();
+ _tensor.dim(2) = shape.width();
+ _tensor.dim(3) = shape.depth();
+}
+
+NodeShape::NodeShape(const loco_tobe::BiasShape &shape) : _domain{loco::Domain::Bias}
+{
+ _tensor.rank(1);
+ _tensor.dim(0) = shape;
+}
+
+loco::TensorShape NodeShape::tensor_shape(void) const
+{
+ assert(_domain == loco::Domain::Tensor);
+
+ return _tensor;
+}
+
+loco::FeatureShape NodeShape::feature_shape(void) const
+{
+ assert(_domain == loco::Domain::Feature);
+
+ loco::FeatureShape shape;
+
+ shape.count() = _tensor.dim(0);
+ shape.height() = _tensor.dim(1);
+ shape.width() = _tensor.dim(2);
+ shape.depth() = _tensor.dim(3);
+
+ return shape;
+}
+
+loco::FilterShape NodeShape::filter_shape(void) const
+{
+ assert(_domain == loco::Domain::Filter);
+
+ loco::FilterShape shape;
+
+ shape.count() = _tensor.dim(0);
+ shape.height() = _tensor.dim(1);
+ shape.width() = _tensor.dim(2);
+ shape.depth() = _tensor.dim(3);
+
+ return shape;
+}
+
+loco_tobe::BiasShape NodeShape::bias_shape(void) const
+{
+ assert(_domain == loco::Domain::Bias);
+
+ loco_tobe::BiasShape shape;
+
+ shape = _tensor.dim(0);
+
+ return shape;
+}
+
+std::unique_ptr<NodeShape> node_shape(loco::Node *node)
+{
+ assert(node != nullptr);
+
+ auto shapedata = node->annot<ShapeInferenceData>();
+ if (shapedata != nullptr)
+ {
+ switch (shapedata->domain())
+ {
+ case loco::Domain::Tensor:
+ {
+ loco::TensorShape shape = shapedata->tensor_shape();
+ std::unique_ptr<NodeShape> node_shape = stdex::make_unique<NodeShape>(shape);
+ return std::move(node_shape);
+ }
+ break;
+
+ case loco::Domain::Feature:
+ {
+ loco::FeatureShape shape = shapedata->feature_shape();
+ std::unique_ptr<NodeShape> node_shape = stdex::make_unique<NodeShape>(shape);
+ return std::move(node_shape);
+ }
+ break;
+
+ case loco::Domain::Filter:
+ {
+ loco::FilterShape shape = shapedata->filter_shape();
+ std::unique_ptr<NodeShape> node_shape = stdex::make_unique<NodeShape>(shape);
+ return std::move(node_shape);
+ }
+ break;
+
+ case loco::Domain::Bias:
+ {
+ loco_tobe::BiasShape shape = shapedata->bias_shape();
+ std::unique_ptr<NodeShape> node_shape = stdex::make_unique<NodeShape>(shape);
+ return std::move(node_shape);
+ }
+ break;
+
+ default:
+ throw std::runtime_error("Not supported loco::Domain");
+ }
+ }
+
+ return nullptr;
+}
+
+} // namespace tf
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <moco/tf/NodeShape.h>
+
+#include "Annotations/ShapeInferenceData.h"
+
+#include <stdex/Memory.h>
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(TensorFlowFrontend, nodeshape_tensor_access)
+{
+ loco::TensorShape source;
+
+ source.rank(3);
+ source.dim(0) = 1;
+ source.dim(1) = 2;
+ source.dim(2) = 3;
+
+ moco::tf::NodeShape nodeshape(source);
+
+ loco::TensorShape tensor = nodeshape.tensor_shape();
+
+ ASSERT_EQ(tensor.rank(), 3);
+ ASSERT_EQ(tensor.dim(0), 1);
+ ASSERT_EQ(tensor.dim(1), 2);
+ ASSERT_EQ(tensor.dim(2), 3);
+}
+
+TEST(TensorFlowFrontend, nodeshape_feature_access)
+{
+ loco::FeatureShape source;
+
+ source.count() = 1;
+ source.height() = 2;
+ source.width() = 3;
+ source.depth() = 4;
+
+ moco::tf::NodeShape nodeshape(source);
+
+ loco::FeatureShape feature = nodeshape.feature_shape();
+
+ ASSERT_EQ(feature.count(), 1);
+ ASSERT_EQ(feature.height(), 2);
+ ASSERT_EQ(feature.width(), 3);
+ ASSERT_EQ(feature.depth(), 4);
+}
+
+TEST(TensorFlowFrontend, nodeshape_filter_access)
+{
+ loco::FilterShape source;
+
+ source.count() = 1;
+ source.height() = 2;
+ source.width() = 3;
+ source.depth() = 4;
+
+ moco::tf::NodeShape nodeshape(source);
+
+ loco::FilterShape filter = nodeshape.filter_shape();
+
+ ASSERT_EQ(filter.count(), 1);
+ ASSERT_EQ(filter.height(), 2);
+ ASSERT_EQ(filter.width(), 3);
+ ASSERT_EQ(filter.depth(), 4);
+}
+
+TEST(TensorFlowFrontend, nodeshape_bias_access)
+{
+ loco_tobe::BiasShape source;
+
+ source = 3;
+
+ moco::tf::NodeShape nodeshape(source);
+
+ loco_tobe::BiasShape bias = nodeshape.bias_shape();
+
+ ASSERT_EQ(bias.value(), 3);
+}
+
+TEST(TensorFlowFrontend, featureshape_from_shapeinfdata)
+{
+ // Prepare a node and annotate ShapeInferenceData
+ loco::Graph graph;
+ auto some_node = graph.nodes()->create<loco::AvgPool2D>();
+ loco::FeatureShape fshape;
+
+ fshape.count() = 1;
+ fshape.height() = 2;
+ fshape.width() = 3;
+ fshape.depth() = 4;
+
+ auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
+ shapedata->feature_shape(fshape);
+ some_node->annot(std::move(shapedata));
+
+ // Get NodeShape and check the values
+ std::unique_ptr<moco::tf::NodeShape> nodeshape = moco::tf::node_shape(some_node);
+ ASSERT_NE(nodeshape.get(), nullptr);
+ ASSERT_EQ(nodeshape->domain(), loco::Domain::Feature);
+
+ auto read_shape = nodeshape->feature_shape();
+ ASSERT_EQ(read_shape.count(), 1);
+ ASSERT_EQ(read_shape.height(), 2);
+ ASSERT_EQ(read_shape.width(), 3);
+ ASSERT_EQ(read_shape.depth(), 4);
+}