[moco-tf] Canonicalize TFConcatV2 (#6703)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 20 Aug 2019 07:08:23 +0000 (16:08 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 20 Aug 2019 07:08:23 +0000 (16:08 +0900)
* [moco-tf] Canonicalize TFConcatV2

This will implement Canonicalize of TFConcatV2 to TensorConcat

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* fix log msg

compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp

index e3282f3..e3939ad 100644 (file)
 
 #include "ConcatV2Canonicalizer.h"
 
+#include "LogHelper.h"
+
+#include "Annotations/ConcatData.h"
+#include "Annotations/ShapeInferenceData.h"
+
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
 
-#include <stdexcept>
+#include <moco/Log.h>
+
+#include <stdex/Memory.h>
 
 namespace
 {
 
+using namespace moco::tf;
+
 bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
 {
-  (void)graph;
-  (void)node;
+  LOGGER(l);
+
+  /**
+   * @note This will replace TFConcatV2 node with (series of) Canonical
+   *       TensorConcat. Below diagram is an example of three inputs
+   *
+   *       Before
+   *                 A --- TFConcatV2 -- C
+   *                 B --/
+   *                 N --/
+   *                 X --/
+   *       After
+   *                 A --- TFConcatV2
+   *                 B --/
+   *                 N --/
+   *                 X --/
+   *                 A --- TensorConcat -- TensorConcat -- C
+   *                 B --/               /
+   *                 N -----------------/
+   *
+   *       Where
+   *                 A : first value of TFConcatV2
+   *                 B : second value of TFConcatV2
+   *                 N : third or N'th value of TFConcatV2
+   *                 X : axis node of TFConcatV2
+   *                 C : a node that uses TFConcatV2 as an input
+   *                 TFConcatV2 is disconnected from C
+   *                 To simplify the diagram in 'After', A, B, N are drawn
+   *                 multiple times but they are same nodes.
+   */
+
+  const int num_values = node->num_values();
+  assert(num_values >= 2);
+
+  // get axis value
+  auto concat_data = node->annot<ConcatData>();
+  assert(concat_data != nullptr);
+  auto axis_value = concat_data->axis();
+
+  auto shapedata = node->annot<ShapeInferenceData>();
+  auto node_rank = shapedata->rank();
+
+  uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value;
+
+  INFO(l) << "canonicalize_concat axis(" << axis_absolute << "), value(" << axis_value << "), rank("
+          << node_rank << ")";
+
+  auto concat_node = graph->nodes()->create<loco::TensorConcat>();
+  concat_node->lhs(node->values(0));
+  concat_node->rhs(node->values(1));
+  concat_node->axis(axis_absolute);
+
+  loco::TensorConcat *last_concat = concat_node;
+  for (int ni = 2; ni < num_values; ++ni)
+  {
+    auto concat_node_next = graph->nodes()->create<loco::TensorConcat>();
+
+    concat_node_next->lhs(last_concat);
+    concat_node_next->rhs(node->values(ni));
+    concat_node_next->axis(axis_absolute);
+
+    // update last concat node
+    last_concat = concat_node_next;
+  }
 
-  throw std::runtime_error("NYI ConcatV2Canonicalizer");
+  // replace node
+  replace(node).with(last_concat);
 
-  return false;
+  return true;
 }
 
 } // namespace