[moco-tf] Add Pad canonicalizer (#8030)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Mon, 14 Oct 2019 01:52:06 +0000 (10:52 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 14 Oct 2019 01:52:06 +0000 (10:52 +0900)
* [moco-tf] Add Pad canonicalizer

This commit adds Pad canonicalizer to moco-tf.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* apply commnets.

* apply comment.

* Simplify method to get rank of padding.

* fix wrong assert.

* add comments in canonicalization.

* make comment more clear.

compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp [new file with mode: 0644]
compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h [new file with mode: 0644]
compiler/moco-tf/src/Canonicalizer.cpp

diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
new file mode 100644 (file)
index 0000000..7f9db61
--- /dev/null
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PadCanonicalizer.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
+
+#include "loco/Service/TypeInference.h"
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+bool canonicalize_pad(loco::Graph *graph, moco::tf::TFPad *node)
+{
+  /**
+   * @note This will replace TFPad node with Canonical TensorConstantPad
+   *
+   *       Before
+   *                 input --- TFPad -- C
+   *                 paddings --/
+   *       After
+   *                 paddings  ------- TFPad --
+   *                                  /
+   *                 input ----------- TensorConstantPad -- C
+   *                 ConstGen --------/
+   *       Where
+   *                 input : input of TFPad
+   *                 paddings : paddings of TFPad. it becomes TensorConstantPad's attribute.
+   *                 C : a node that uses TFPad as an input. TFPad is disconnected from C.
+   *                 ConstGen : constant value of Pad. TFPad has zero value by default.
+   */
+
+  auto pad_node = graph->nodes()->create<loco::TensorConstantPad>();
+
+  auto constant_node = graph->nodes()->create<loco::ConstGen>();
+
+  auto input_node = node->input();
+  // TODO: support other dtype.
+  assert(loco::dtype_get(input_node) == loco::DataType::FLOAT32);
+  constant_node->dtype(loco::DataType::FLOAT32);
+  constant_node->rank(0);
+  constant_node->size<loco::DataType::FLOAT32>(1);
+  constant_node->at<loco::DataType::FLOAT32>(0) = 0.0f;
+
+  auto const_paddings_node = dynamic_cast<loco::ConstGen *>(node->paddings());
+  // TODO: support S64 type.
+  assert(const_paddings_node->dtype() == loco::DataType::S32);
+  assert(const_paddings_node->rank() == 2);
+  assert(const_paddings_node->dim(1).value() == 2);
+
+  auto padding = pad_node->padding();
+  uint32_t padding_rank = const_paddings_node->dim(0).value();
+  padding->rank(padding_rank);
+
+  for (uint32_t i = 0; i < padding_rank; i++)
+  {
+    padding->front(i) = const_paddings_node->at<loco::DataType::S32>(i << 1);
+    padding->back(i) = const_paddings_node->at<loco::DataType::S32>((i << 1) + 1);
+  }
+
+  // update connections
+  pad_node->input(input_node);
+  pad_node->constant(constant_node);
+
+  // replace node
+  replace(node).with(pad_node);
+
+  return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool PadCanonicalizer::transform(TFPad *node) const
+{
+  return canonicalize_pad(node->graph(), node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h
new file mode 100644 (file)
index 0000000..7909da0
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_PAD_CANONICALIZER_H__
+#define __MOCO_TF_PAD_CANONICALIZER_H__
+
+#include "Transform.h"
+#include "SimpleNodeTransform.h"
+#include "Dialect/TFNodes.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief  Convert TFPad to Canonical TensorConstantPad
+ */
+class PadCanonicalizer final : public SimpleNodeTransform<TFPad>
+{
+public:
+  const char *name(void) const final { return "PadCanonicalizer"; }
+
+public:
+  bool transform(TFPad *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_PAD_CANONICALIZER_H__
index c81aad0..7c2ea0c 100644 (file)
@@ -34,6 +34,7 @@
 #include "Canonicalization/MaxPoolCanonicalizer.h"
 #include "Canonicalization/MeanCanonicalizer.h"
 #include "Canonicalization/MulCanonicalizer.h"
+#include "Canonicalization/PadCanonicalizer.h"
 #include "Canonicalization/RealDivCanonicalizer.h"
 #include "Canonicalization/ReluCanonicalizer.h"
 #include "Canonicalization/Relu6Canonicalizer.h"
@@ -105,6 +106,7 @@ void Canonicalizer::canonicalize(loco::Graph *g) const
   phase.emplace_back(stdex::make_unique<MaxPoolCanonicalizer>());
   phase.emplace_back(stdex::make_unique<MeanCanonicalizer>());
   phase.emplace_back(stdex::make_unique<MulCanonicalizer>());
+  phase.emplace_back(stdex::make_unique<PadCanonicalizer>());
   phase.emplace_back(stdex::make_unique<RealDivCanonicalizer>());
   phase.emplace_back(stdex::make_unique<ReluCanonicalizer>());
   phase.emplace_back(stdex::make_unique<Relu6Canonicalizer>());