[mir2loco] Support Broadcast on Elementwise operations (#7665)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Mon, 30 Sep 2019 15:34:17 +0000 (18:34 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 30 Sep 2019 15:34:17 +0000 (18:34 +0300)
* Supported new elementwise operations
* Implemented TensorBroadcast creation
* Fixed all functions for broadcasting

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/include/mir2loco.h
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index ff1faf0..71a58c7 100644 (file)
@@ -34,6 +34,7 @@ public:
   void visit(mir::ops::Conv2DOp &op) override;
   void visit(mir::ops::DeConv2DOp &op) override;
   void visit(mir::ops::DepthwiseConv2DOp &op) override;
+  void visit(mir::ops::DivOp &op) override;
   void visit(mir::ops::FullyConnectedOp &op) override;
   void visit(mir::ops::InputOp &op) override;
   void visit(mir::ops::MaxPool2DOp &op) override;
@@ -42,6 +43,7 @@ public:
   void visit(mir::ops::ReluOp &op) override;
   void visit(mir::ops::ReshapeOp &op) override;
   void visit(mir::ops::SoftmaxOp &op) override;
+  void visit(mir::ops::SubOp &op) override;
 
   void visit_fallback(mir::Operation &op) override;
 
index 87b86d7..250fd24 100644 (file)
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/Deconv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/DivOp.h"
 #include "mir/ops/FullyConnectedOp.h"
 #include "mir/ops/MaxPool2DOp.h"
 #include "mir/ops/MulOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
 #include "mir/ops/SoftmaxOp.h"
+#include "mir/ops/SubOp.h"
 
 #include "mir/ShapeRange.h"
 
@@ -223,6 +225,76 @@ loco::DataType convertDataType(mir::DataType data_type)
   }
   throw std::runtime_error("Unsupported data type");
 }
+
+loco::Node *createBroadcastIfNeeded(loco::Node *node, const mir::Shape &shape,
+                                    const mir::Shape &out_shape)
+{
+  auto graph = node->graph();
+
+  if (shape == out_shape)
+    return node; // not needed
+
+  int32_t out_rank = out_shape.rank();
+  int32_t rank_diff = out_rank - shape.rank();
+  // Create Broadcast
+  auto *broadcast = graph->nodes()->create<loco::TensorBroadcast>();
+  // Create Reshape for equal ranks
+  if (shape.rank() != out_rank)
+  {
+    auto *reshape = graph->nodes()->create<loco::FixedReshape>();
+    reshape->input(node);
+    reshape->rank(out_rank);
+    broadcast->input(reshape);
+    // Set reshape dims
+    for (int32_t dim = 0; dim < out_rank; dim++)
+    {
+      if (dim < rank_diff)
+        reshape->dim(dim) = 1;
+      else
+        reshape->dim(dim) = shape.dim(dim - rank_diff);
+    }
+  }
+  else
+  {
+    broadcast->input(node);
+  }
+  // Set to false if any corresponding dimensions are incompatible for broadcasting
+  bool compatible_shapes = true;
+  for (int32_t dim = 0; dim < out_rank; dim++)
+  {
+    // Set broadcast mapping
+    if (dim < rank_diff || (shape.dim(dim - rank_diff) == 1 && out_shape.dim(dim) != 1))
+      broadcast->mapping()->dim(dim) = out_shape.dim(dim);
+    // Check compatibility
+    if (dim >= rank_diff && shape.dim(dim - rank_diff) != 1 &&
+        shape.dim(dim - rank_diff) != out_shape.dim(dim))
+      compatible_shapes = false;
+  }
+  // Check compatibility
+  if (!compatible_shapes)
+    throw std::runtime_error("Not compatible shapes for broadcasting!");
+
+  return broadcast;
+}
+
+template <typename NodeType>
+NodeType *createEltwiseBinary(const mir::ops::BinaryElementwiseOp &op, loco::Node *lhs,
+                              loco::Node *rhs)
+{
+  auto graph = lhs->graph();
+
+  const auto &lhs_shape = op.getInput(0)->getProducer()->getShape();
+  const auto &rhs_shape = op.getInput(1)->getProducer()->getShape();
+  const auto &out_shape = op.getOutputShape(0);
+  // Create Broadcast if it's needed
+  auto lhs_node = createBroadcastIfNeeded(lhs, lhs_shape, out_shape);
+  auto rhs_node = createBroadcastIfNeeded(rhs, rhs_shape, out_shape);
+  // Create Node
+  auto result = graph->nodes()->create<NodeType>();
+  result->lhs(lhs_node);
+  result->rhs(rhs_node);
+  return result;
+}
 } // namespace
 
 void Transformer::visit(mir::ops::AddOp &op)
@@ -230,11 +302,7 @@ void Transformer::visit(mir::ops::AddOp &op)
   // Get Input
   auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
   auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
-
-  auto result = _loco_graph->nodes()->create<loco::EltwiseAdd>();
-  result->lhs(lhs);
-  result->rhs(rhs);
-
+  auto result = createEltwiseBinary<loco::EltwiseAdd>(op, lhs, rhs);
   // Not set Shape
   // Add to map
   _mir2loco_map.emplace(&op, result);
@@ -473,6 +541,17 @@ void Transformer::visit(mir::ops::DepthwiseConv2DOp &op)
   _mir2loco_map.emplace(&op, feature_dec);
 }
 
+void Transformer::visit(mir::ops::DivOp &op)
+{
+  // Get Input
+  auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
+  auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
+  auto result = createEltwiseBinary<loco::EltwiseDiv>(op, lhs, rhs);
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, result);
+}
+
 void Transformer::visit(mir::ops::FullyConnectedOp &op)
 {
   auto input = op.getInput(0)->getProducer()->getNode();
@@ -554,11 +633,7 @@ void Transformer::visit(mir::ops::MulOp &op)
   // Get Input
   auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
   auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
-
-  auto result = _loco_graph->nodes()->create<loco::EltwiseMul>();
-  result->lhs(lhs);
-  result->rhs(rhs);
-
+  auto result = createEltwiseBinary<loco::EltwiseMul>(op, lhs, rhs);
   // Not set Shape
   // Add to map
   _mir2loco_map.emplace(&op, result);
@@ -624,6 +699,17 @@ void Transformer::visit(mir::ops::SoftmaxOp &op)
   _mir2loco_map.emplace(&op, softmax_node);
 }
 
+void Transformer::visit(mir::ops::SubOp &op)
+{
+  // Get Input
+  auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
+  auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
+  auto result = createEltwiseBinary<loco::EltwiseSub>(op, lhs, rhs);
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, result);
+}
+
 void Transformer::visit_fallback(mir::Operation &op) { throw std::runtime_error("NYI operation"); }
 
 std::unique_ptr<loco::Graph> Transformer::transform(mir::Graph *mir_graph)
index 7ed0820..c4a52c4 100644 (file)
@@ -318,10 +318,8 @@ TEST_F(TestTransformer_mir2loco, Add_Test)
 {
   mir::Graph mir_graph;
 
-  mir::Shape input_shape{5, 6, 7, 3};
-
-  auto *input1 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
-  auto *input2 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+  auto *input1 = mir_graph.create<mir::ops::InputOp>(mir::Shape{5, 6, 7, 3})->getOutput(0);
+  auto *input2 = mir_graph.create<mir::ops::InputOp>(mir::Shape{5, 1, 7, 3})->getOutput(0);
   auto *add = mir_graph.create<mir::ops::AddOp>(input1, input2)->getOutput(0);
   mir_graph.create<mir::ops::OutputOp>(add);
   input1->setName("x1");
@@ -331,32 +329,26 @@ TEST_F(TestTransformer_mir2loco, Add_Test)
   mir2loco::Transformer transformer;
   auto loco_graph = transformer.transform(&mir_graph);
 
-  auto *pull1_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
-  auto *pull2_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(1));
-  auto *add_node = dynamic_cast<loco::EltwiseAdd *>(loco_graph->nodes()->at(2));
-  auto *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(3));
-
-  ASSERT_NE(pull1_node, nullptr);
-  ASSERT_NE(pull2_node, nullptr);
+  // Pull
+  auto inputs = loco_graph->inputs();
+  ASSERT_EQ(inputs->size(), 2);
+  loco::Pull *pull_node0 = loco::pull_node(loco_graph.get(), 0);
+  ASSERT_NE(pull_node0, nullptr);
+  loco::Pull *pull_node1 = loco::pull_node(loco_graph.get(), 1);
+  ASSERT_NE(pull_node1, nullptr);
+  // Add
+  auto pull_uses = loco::succs(pull_node0);
+  ASSERT_EQ(pull_uses.size(), 1);
+  loco::EltwiseAdd *add_node = dynamic_cast<loco::EltwiseAdd *>(*pull_uses.begin());
   ASSERT_NE(add_node, nullptr);
-  ASSERT_NE(push_node, nullptr);
-
-  ASSERT_EQ(add_node->lhs(), pull1_node);
-  ASSERT_EQ(add_node->rhs(), pull2_node);
-  ASSERT_EQ(push_node->from(), add_node);
-
-  // Shape check
-  ASSERT_EQ(pull1_node->rank(), 4);
-  ASSERT_EQ(pull1_node->dim(0), 5);
-  ASSERT_EQ(pull1_node->dim(1), 6);
-  ASSERT_EQ(pull1_node->dim(2), 7);
-  ASSERT_EQ(pull1_node->dim(3), 3);
-
-  ASSERT_EQ(pull2_node->rank(), 4);
-  ASSERT_EQ(pull2_node->dim(0), 5);
-  ASSERT_EQ(pull2_node->dim(1), 6);
-  ASSERT_EQ(pull2_node->dim(2), 7);
-  ASSERT_EQ(pull2_node->dim(3), 3);
+  ASSERT_EQ(add_node->lhs(), pull_node0);
+  // TensorBroadcast
+  loco::TensorBroadcast *broadcast_node = dynamic_cast<loco::TensorBroadcast *>(add_node->rhs());
+  ASSERT_NE(broadcast_node, nullptr);
+  ASSERT_EQ(broadcast_node->input(), pull_node1);
+  // Check params
+  ASSERT_TRUE(broadcast_node->mapping()->defined(1));
+  ASSERT_EQ(broadcast_node->mapping()->dim(1), 6);
 }
 
 TEST_F(TestTransformer_mir2loco, Conv2D_Test)
@@ -446,10 +438,8 @@ TEST_F(TestTransformer_mir2loco, Mul_Test)
 {
   mir::Graph mir_graph;
 
-  mir::Shape input_shape{5, 6, 7, 3};
-
-  auto *input1 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
-  auto *input2 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+  auto *input1 = mir_graph.create<mir::ops::InputOp>(mir::Shape{5, 6, 7, 13})->getOutput(0);
+  auto *input2 = mir_graph.create<mir::ops::InputOp>(mir::Shape{13})->getOutput(0);
   auto *add = mir_graph.create<mir::ops::MulOp>(input1, input2)->getOutput(0);
   mir_graph.create<mir::ops::OutputOp>(add);
   input1->setName("x1");
@@ -459,32 +449,47 @@ TEST_F(TestTransformer_mir2loco, Mul_Test)
   mir2loco::Transformer transformer;
   auto loco_graph = transformer.transform(&mir_graph);
 
-  auto *pull1_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
-  auto *pull2_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(1));
-  auto *add_node = dynamic_cast<loco::EltwiseMul *>(loco_graph->nodes()->at(2));
-  auto *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(3));
-
-  ASSERT_NE(pull1_node, nullptr);
-  ASSERT_NE(pull2_node, nullptr);
-  ASSERT_NE(add_node, nullptr);
-  ASSERT_NE(push_node, nullptr);
-
-  ASSERT_EQ(add_node->lhs(), pull1_node);
-  ASSERT_EQ(add_node->rhs(), pull2_node);
-  ASSERT_EQ(push_node->from(), add_node);
-
-  // Shape check
-  ASSERT_EQ(pull1_node->rank(), 4);
-  ASSERT_EQ(pull1_node->dim(0), 5);
-  ASSERT_EQ(pull1_node->dim(1), 6);
-  ASSERT_EQ(pull1_node->dim(2), 7);
-  ASSERT_EQ(pull1_node->dim(3), 3);
-
-  ASSERT_EQ(pull2_node->rank(), 4);
-  ASSERT_EQ(pull2_node->dim(0), 5);
-  ASSERT_EQ(pull2_node->dim(1), 6);
-  ASSERT_EQ(pull2_node->dim(2), 7);
-  ASSERT_EQ(pull2_node->dim(3), 3);
+  // Pulls
+  auto inputs = loco_graph->inputs();
+  ASSERT_EQ(inputs->size(), 2);
+  loco::Pull *pull_node0 = loco::pull_node(loco_graph.get(), 0);
+  ASSERT_NE(pull_node0, nullptr);
+  loco::Pull *pull_node1 = loco::pull_node(loco_graph.get(), 1);
+  ASSERT_NE(pull_node1, nullptr);
+  // Mul
+  auto pull0_uses = loco::succs(pull_node0);
+  ASSERT_EQ(pull0_uses.size(), 1);
+  loco::EltwiseMul *mul_node = dynamic_cast<loco::EltwiseMul *>(*pull0_uses.begin());
+  ASSERT_NE(mul_node, nullptr);
+  // Broadcast
+  loco::TensorBroadcast *broadcast_node = dynamic_cast<loco::TensorBroadcast *>(mul_node->rhs());
+  ASSERT_NE(broadcast_node, nullptr);
+  ASSERT_EQ(mul_node->lhs(), pull_node0);
+  ASSERT_EQ(mul_node->rhs(), broadcast_node);
+  loco::FixedReshape *reshape_node = dynamic_cast<loco::FixedReshape *>(broadcast_node->input());
+  ASSERT_NE(reshape_node, nullptr);
+  ASSERT_EQ(reshape_node->input(), pull_node1);
+  ASSERT_EQ(reshape_node->rank(), 4);
+  ASSERT_EQ(reshape_node->dim(0), 1);
+  ASSERT_EQ(reshape_node->dim(1), 1);
+  ASSERT_EQ(reshape_node->dim(2), 1);
+  ASSERT_EQ(reshape_node->dim(3), 13);
+  // Params checks
+  ASSERT_EQ(pull_node0->rank(), 4);
+  ASSERT_EQ(pull_node0->dim(0), 5);
+  ASSERT_EQ(pull_node0->dim(1), 6);
+  ASSERT_EQ(pull_node0->dim(2), 7);
+  ASSERT_EQ(pull_node0->dim(3), 13);
+
+  ASSERT_EQ(pull_node1->rank(), 1);
+  ASSERT_EQ(pull_node1->dim(0), 13);
+
+  ASSERT_TRUE(broadcast_node->mapping()->defined(0));
+  ASSERT_EQ(broadcast_node->mapping()->dim(0), 5);
+  ASSERT_TRUE(broadcast_node->mapping()->defined(1));
+  ASSERT_EQ(broadcast_node->mapping()->dim(1), 6);
+  ASSERT_TRUE(broadcast_node->mapping()->defined(2));
+  ASSERT_EQ(broadcast_node->mapping()->dim(2), 7);
 }
 
 TEST_F(TestTransformer_mir2loco, DepthwiseConv2D_Test)