[moco-tf] Introduce TFConcatV2 IR (#6321)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 7 Aug 2019 09:56:12 +0000 (18:56 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 7 Aug 2019 09:56:12 +0000 (18:56 +0900)
This will introduce TFConcatV2 IR and required codes for TensorFlow Concat node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Dialect/TFNodes.h
compiler/moco-tf/src/Dialect/TFNodes.lst
compiler/moco-tf/src/IR/TFConcatV2.h [new file with mode: 0644]
compiler/moco-tf/src/IR/TFConcatV2.test.cpp [new file with mode: 0644]
compiler/moco-tf/src/Transforms/FixPaddingTransform.cpp
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 0185a6f..50eebbe 100644 (file)
@@ -20,6 +20,7 @@
 #include "IR/TFAdd.h"
 #include "IR/TFAvgPool.h"
 #include "IR/TFBiasAdd.h"
+#include "IR/TFConcatV2.h"
 #include "IR/TFConst.h"
 #include "IR/TFConv2D.h"
 #include "IR/TFDepthwiseConv2dNative.h"
index 41dce05..5307e1e 100644 (file)
@@ -10,6 +10,7 @@
 TENSORFLOW_NODE(Add, TFAdd)
 TENSORFLOW_NODE(AvgPool, TFAvgPool)
 TENSORFLOW_NODE(BiasAdd, TFBiasAdd)
+TENSORFLOW_NODE(ConcatV2, TFConcatV2)
 TENSORFLOW_NODE(Const, TFConst)
 TENSORFLOW_NODE(Conv2D, TFConv2D)
 TENSORFLOW_NODE(DepthwiseConv2dNative, TFDepthwiseConv2dNative)
diff --git a/compiler/moco-tf/src/IR/TFConcatV2.h b/compiler/moco-tf/src/IR/TFConcatV2.h
new file mode 100644 (file)
index 0000000..2dd2214
--- /dev/null
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFCONCATV2_H__
+#define __MOCO_TF_IR_TFCONCATV2_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFConcatV2 corresponds to the following GraphDef
+/*
+node {
+  name: "Concat"
+  op: "ConcatV2"
+  input: "Input01"
+  input: "Input02"
+  input: "Axis"
+  attr {
+    key: "N"
+    value {
+      i: 2
+    }
+  }
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "Tidx"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+*/
+
+/**
+ * @note  As there is no VariableArityNode for now, we will import ConcatV2
+ *        as cascading of multiple TFConcatV2 nodes like loco::TensorConcat
+ */
+
+class TFConcatV2 final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::ConcatV2>>
+{
+public:
+  TFConcatV2() = default;
+
+public:
+  Node *lhs(void) const { return at(0)->node(); }
+  void lhs(Node *node) { at(0)->node(node); }
+
+  Node *rhs(void) const { return at(1)->node(); }
+  void rhs(Node *node) { at(1)->node(node); }
+
+public:
+  uint32_t axis(void) const { return _axis; }
+  void axis(uint32_t val) { _axis = val; }
+
+private:
+  // Axis
+  uint32_t _axis{0};
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFCONCATV2_H__
diff --git a/compiler/moco-tf/src/IR/TFConcatV2.test.cpp b/compiler/moco-tf/src/IR/TFConcatV2.test.cpp
new file mode 100644 (file)
index 0000000..85ea264
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFConcatV2.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFConcatV2Test, constructor)
+{
+  moco::tf::TFConcatV2 concatv2_node;
+
+  ASSERT_EQ(concatv2_node.dialect(), moco::tf::TFDialect::get());
+  ASSERT_EQ(concatv2_node.opcode(), moco::tf::TFOpcode::ConcatV2);
+
+  ASSERT_EQ(concatv2_node.lhs(), nullptr);
+  ASSERT_EQ(concatv2_node.rhs(), nullptr);
+  ASSERT_EQ(concatv2_node.axis(), 0);
+}
index e098e7b..f402683 100644 (file)
@@ -423,6 +423,12 @@ bool fix_padding(moco::tf::TFBiasAdd *node)
   return false;
 }
 
+bool fix_padding(moco::tf::TFConcatV2 *node)
+{
+  // Nothing to do with padding
+  return false;
+}
+
 bool fix_padding(moco::tf::TFConst *node)
 {
   // Nothing to do with padding
index 95459da..85c2c5c 100644 (file)
@@ -816,6 +816,70 @@ bool fix_shape(moco::tf::TFBiasAdd *node)
   return copy_shapedata(value, node);
 }
 
+bool fix_shape(moco::tf::TFConcatV2 *node)
+{
+  auto concat_data = node->annot<ConcatData>();
+  if (concat_data == nullptr)
+  {
+    // shape inference is already done for TFConcatV2
+    assert(node->annot<ShapeInferenceData>() != nullptr);
+    return false;
+  }
+  assert(node->annot<ShapeInferenceData>() == nullptr);
+
+  auto lhs = node->lhs();
+  auto rhs = node->rhs();
+  auto lhs_shapedata = lhs->annot<ShapeInferenceData>();
+  auto rhs_shapedata = rhs->annot<ShapeInferenceData>();
+  if (lhs_shapedata == nullptr || rhs_shapedata == nullptr)
+  {
+    // postpone as previous input node(s) hasn't been processed.
+    // this will return false as there was nothing changed, but the input of this
+    // node should be changed and from that this method should be called again.
+    // if not, the network may have some problem and the final output may not have
+    // the right shape value and we can identify with some validation at final stage.
+    return false;
+  }
+
+  uint32_t lhs_rank = lhs_shapedata->rank();
+  uint32_t rhs_rank = rhs_shapedata->rank();
+  assert(lhs_rank == rhs_rank);
+
+  int32_t axis_tf = concat_data->axis();
+  if (axis_tf < 0)
+  {
+    axis_tf = static_cast<int32_t>(lhs_rank) + axis_tf;
+  }
+  assert(0 <= axis_tf && axis_tf < static_cast<int32_t>(lhs_rank));
+  // clear annotation ConcatData
+  node->annot<ConcatData>(nullptr);
+
+  uint32_t axis_loco = static_cast<uint32_t>(axis_tf);
+  node->axis(axis_loco);
+
+  // Set ShapeInferenceData for TensorConcat
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  shape_data->rank(lhs_rank);
+  for (uint32_t index = 0; index < lhs_rank; ++index)
+  {
+    uint32_t lhs_dim = lhs_shapedata->dim(index).value();
+    uint32_t rhs_dim = rhs_shapedata->dim(index).value();
+    // "lhs_dim == rhs_dim" should hold when "index != axis_loco"
+    // or doesn't care when "index == axis_loco"
+    assert(index == axis_loco || lhs_dim == rhs_dim);
+
+    uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim;
+
+    if (lhs_shapedata->dim(index).known())
+      shape_data->dim(index) = new_dim;
+    else
+      shape_data->dim(index).unset();
+  }
+  node->annot(std::move(shape_data));
+
+  return true;
+}
+
 bool fix_shape(moco::tf::TFConst *node)
 {
   auto shapedata = node->annot<ShapeInferenceData>();