[moco/tf] Add shapeinfdata feature get,set (#3769)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 14 Jun 2019 04:44:31 +0000 (13:44 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 14 Jun 2019 04:44:31 +0000 (13:44 +0900)
* [moco/tf] Add shapeinfdata feature get,set

This will add getter and setter with FeatureShape for ShapeInferenceData

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* remove unused, fix comment

* fix typo

contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp [new file with mode: 0644]
contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.h
contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp [new file with mode: 0644]

diff --git a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp
new file mode 100644 (file)
index 0000000..25053e2
--- /dev/null
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInferenceData.h"
+
+namespace moco
+{
+namespace tf
+{
+
+loco::FeatureShape ShapeInferenceData::feature_shape(void) const
+{
+  loco::FeatureShape shape;
+
+  assert(rank() == 4);
+  if (rank() != 4)
+    return shape;
+
+  shape.count() = dim(0);
+  shape.height() = dim(1);
+  shape.width() = dim(2);
+  shape.depth() = dim(3);
+
+  return shape;
+}
+
+void ShapeInferenceData::feature_shape(const loco::FeatureShape &shape)
+{
+  rank(4);
+  dim(0) = shape.count();
+  dim(1) = shape.height();
+  dim(2) = shape.width();
+  dim(3) = shape.depth();
+}
+
+} // namespace tf
+} // namespace moco
index f502d0b..8ea7dad 100644 (file)
@@ -19,6 +19,8 @@
 
 #include <loco.h>
 
+#include <cassert>
+
 namespace moco
 {
 namespace tf
@@ -26,12 +28,19 @@ namespace tf
 
 /**
  * @brief ShapeInferenceData provides shape inference data tracking from the start(input)
-*/
+ *
+ * @note  For Feature, NHWC is used for shape layout
+ */
 class ShapeInferenceData : public loco::NodeAnnotation,
                            public loco::NodeMixin<loco::NodeTrait::TensorShape>
 {
 public:
   ~ShapeInferenceData(){};
+
+public:
+  loco::FeatureShape feature_shape(void) const;
+
+  void feature_shape(const loco::FeatureShape &shape);
 };
 
 } // namespace tf
diff --git a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp
new file mode 100644 (file)
index 0000000..fccf29f
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInferenceData.h"
+
+#include <gtest/gtest.h>
+
+TEST(TensorFlowFrontend, shapeinferencedata_feature_set)
+{
+  loco::FeatureShape feature;
+
+  feature.count() = 1;
+  feature.height() = 2;
+  feature.width() = 3;
+  feature.depth() = 4;
+
+  moco::tf::ShapeInferenceData shapedata;
+
+  shapedata.feature_shape(feature);
+
+  ASSERT_EQ(shapedata.rank(), 4);
+  ASSERT_EQ(shapedata.dim(0), 1);
+  ASSERT_EQ(shapedata.dim(1), 2);
+  ASSERT_EQ(shapedata.dim(2), 3);
+  ASSERT_EQ(shapedata.dim(3), 4);
+}
+
+TEST(TensorFlowFrontend, shapeinferencedata_feature_get)
+{
+  moco::tf::ShapeInferenceData shapedata;
+
+  shapedata.rank(4);
+  shapedata.dim(0) = 1;
+  shapedata.dim(1) = 2;
+  shapedata.dim(2) = 3;
+  shapedata.dim(3) = 4;
+
+  loco::FeatureShape feature = shapedata.feature_shape();
+
+  ASSERT_EQ(feature.count(), 1);
+  ASSERT_EQ(feature.height(), 2);
+  ASSERT_EQ(feature.width(), 3);
+  ASSERT_EQ(feature.depth(), 4);
+}