#include "ConcatV2Canonicalizer.h"
+#include "LogHelper.h"
+
+#include "Annotations/ConcatData.h"
+#include "Annotations/ShapeInferenceData.h"
+
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
#include "Dialect/TFNodeVisitor.h"
#include "Dialect/TFNodeImpl.h"
-#include <stdexcept>
+#include <moco/Log.h>
+
+#include <stdex/Memory.h>
namespace
{
+using namespace moco::tf;
+
bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
{
- (void)graph;
- (void)node;
+ LOGGER(l);
+
+ /**
+ * @note This will replace TFConcatV2 node with (series of) Canonical
+ * TensorConcat. Below diagram is an example of three inputs
+ *
+ * Before
+ * A --- TFConcatV2 -- C
+ * B --/
+ * N --/
+ * X --/
+ * After
+ * A --- TFConcatV2
+ * B --/
+ * N --/
+ * X --/
+ * A --- TensorConcat -- TensorConcat -- C
+ * B --/ /
+ * N -----------------/
+ *
+ * Where
+ * A : first value of TFConcatV2
+ * B : second value of TFConcatV2
+ * N : third or N'th value of TFConcatV2
+ * X : axis node of TFConcatV2
+ * C : a node that uses TFConcatV2 as an input
+ * TFConcatV2 is disconnected from C
+ * To simplify the diagram in 'After', A, B, N are drawn
+ * multiple times but they are same nodes.
+ */
+
+ const int num_values = node->num_values();
+ assert(num_values >= 2);
+
+ // get axis value
+ auto concat_data = node->annot<ConcatData>();
+ assert(concat_data != nullptr);
+ auto axis_value = concat_data->axis();
+
+ auto shapedata = node->annot<ShapeInferenceData>();
+ auto node_rank = shapedata->rank();
+
+ uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value;
+
+ INFO(l) << "canonicalize_concat axis(" << axis_absolute << "), value(" << axis_value << "), rank("
+ << node_rank << ")";
+
+ auto concat_node = graph->nodes()->create<loco::TensorConcat>();
+ concat_node->lhs(node->values(0));
+ concat_node->rhs(node->values(1));
+ concat_node->axis(axis_absolute);
+
+ loco::TensorConcat *last_concat = concat_node;
+ for (int ni = 2; ni < num_values; ++ni)
+ {
+ auto concat_node_next = graph->nodes()->create<loco::TensorConcat>();
+
+ concat_node_next->lhs(last_concat);
+ concat_node_next->rhs(node->values(ni));
+ concat_node_next->axis(axis_absolute);
+
+ // update last concat node
+ last_concat = concat_node_next;
+ }
- throw std::runtime_error("NYI ConcatV2Canonicalizer");
+ // replace node
+ replace(node).with(last_concat);
- return false;
+ return true;
}
} // namespace