#include "mir/ops/Conv2DOp.h"
#include "mir/ops/Deconv2DOp.h"
#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/DivOp.h"
#include "mir/ops/FullyConnectedOp.h"
#include "mir/ops/MaxPool2DOp.h"
#include "mir/ops/MulOp.h"
#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
#include "mir/ops/SoftmaxOp.h"
+#include "mir/ops/SubOp.h"
#include "mir/ShapeRange.h"
}
throw std::runtime_error("Unsupported data type");
}
+
+loco::Node *createBroadcastIfNeeded(loco::Node *node, const mir::Shape &shape,
+ const mir::Shape &out_shape)
+{
+ auto graph = node->graph();
+
+ if (shape == out_shape)
+ return node; // not needed
+
+ int32_t out_rank = out_shape.rank();
+ int32_t rank_diff = out_rank - shape.rank();
+ // Create Broadcast
+ auto *broadcast = graph->nodes()->create<loco::TensorBroadcast>();
+ // Create Reshape for equal ranks
+ if (shape.rank() != out_rank)
+ {
+ auto *reshape = graph->nodes()->create<loco::FixedReshape>();
+ reshape->input(node);
+ reshape->rank(out_rank);
+ broadcast->input(reshape);
+ // Set reshape dims
+ for (int32_t dim = 0; dim < out_rank; dim++)
+ {
+ if (dim < rank_diff)
+ reshape->dim(dim) = 1;
+ else
+ reshape->dim(dim) = shape.dim(dim - rank_diff);
+ }
+ }
+ else
+ {
+ broadcast->input(node);
+ }
+ // Flag if no one dim isn't equal
+ bool compatible_shapes = true;
+ for (int32_t dim = 0; dim < out_rank; dim++)
+ {
+ // Set broadcast mapping
+ if (dim < rank_diff || (shape.dim(dim - rank_diff) == 1 && out_shape.dim(dim) != 1))
+ broadcast->mapping()->dim(dim) = out_shape.dim(dim);
+ // Check compatibility
+ if (dim >= rank_diff && shape.dim(dim - rank_diff) != 1 &&
+ shape.dim(dim - rank_diff) != out_shape.dim(dim))
+ compatible_shapes = false;
+ }
+ // Check compatibility
+ if (!compatible_shapes)
+ throw std::runtime_error("Not compatible shapes for broadcasting!");
+
+ return broadcast;
+}
+
+template <typename NodeType>
+NodeType *createEltwiseBinary(const mir::ops::BinaryElementwiseOp &op, loco::Node *lhs,
+ loco::Node *rhs)
+{
+ auto graph = lhs->graph();
+
+ const auto &lhs_shape = op.getInput(0)->getProducer()->getShape();
+ const auto &rhs_shape = op.getInput(1)->getProducer()->getShape();
+ const auto &out_shape = op.getOutputShape(0);
+ // Create Broadcast if it's needed
+ auto lhs_node = createBroadcastIfNeeded(lhs, lhs_shape, out_shape);
+ auto rhs_node = createBroadcastIfNeeded(rhs, rhs_shape, out_shape);
+ // Create Node
+ auto result = graph->nodes()->create<NodeType>();
+ result->lhs(lhs_node);
+ result->rhs(rhs_node);
+ return result;
+}
} // namespace
void Transformer::visit(mir::ops::AddOp &op)
// Get Input
auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
-
- auto result = _loco_graph->nodes()->create<loco::EltwiseAdd>();
- result->lhs(lhs);
- result->rhs(rhs);
-
+ auto result = createEltwiseBinary<loco::EltwiseAdd>(op, lhs, rhs);
// Not set Shape
// Add to map
_mir2loco_map.emplace(&op, result);
_mir2loco_map.emplace(&op, feature_dec);
}
+void Transformer::visit(mir::ops::DivOp &op)
+{
+ // Get Input
+ auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
+ auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
+ auto result = createEltwiseBinary<loco::EltwiseDiv>(op, lhs, rhs);
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, result);
+}
+
void Transformer::visit(mir::ops::FullyConnectedOp &op)
{
auto input = op.getInput(0)->getProducer()->getNode();
// Get Input
auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
-
- auto result = _loco_graph->nodes()->create<loco::EltwiseMul>();
- result->lhs(lhs);
- result->rhs(rhs);
-
+ auto result = createEltwiseBinary<loco::EltwiseMul>(op, lhs, rhs);
// Not set Shape
// Add to map
_mir2loco_map.emplace(&op, result);
_mir2loco_map.emplace(&op, softmax_node);
}
+void Transformer::visit(mir::ops::SubOp &op)
+{
+ // Get Input
+ auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
+ auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
+ auto result = createEltwiseBinary<loco::EltwiseSub>(op, lhs, rhs);
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, result);
+}
+
void Transformer::visit_fallback(mir::Operation &op) { throw std::runtime_error("NYI operation"); }
std::unique_ptr<loco::Graph> Transformer::transform(mir::Graph *mir_graph)
{
mir::Graph mir_graph;
- mir::Shape input_shape{5, 6, 7, 3};
-
- auto *input1 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
- auto *input2 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+ auto *input1 = mir_graph.create<mir::ops::InputOp>(mir::Shape{5, 6, 7, 3})->getOutput(0);
+ auto *input2 = mir_graph.create<mir::ops::InputOp>(mir::Shape{5, 1, 7, 3})->getOutput(0);
auto *add = mir_graph.create<mir::ops::AddOp>(input1, input2)->getOutput(0);
mir_graph.create<mir::ops::OutputOp>(add);
input1->setName("x1");
mir2loco::Transformer transformer;
auto loco_graph = transformer.transform(&mir_graph);
- auto *pull1_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
- auto *pull2_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(1));
- auto *add_node = dynamic_cast<loco::EltwiseAdd *>(loco_graph->nodes()->at(2));
- auto *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(3));
-
- ASSERT_NE(pull1_node, nullptr);
- ASSERT_NE(pull2_node, nullptr);
+ // Pull
+ auto inputs = loco_graph->inputs();
+ ASSERT_EQ(inputs->size(), 2);
+ loco::Pull *pull_node0 = loco::pull_node(loco_graph.get(), 0);
+ ASSERT_NE(pull_node0, nullptr);
+ loco::Pull *pull_node1 = loco::pull_node(loco_graph.get(), 1);
+ ASSERT_NE(pull_node1, nullptr);
+ // Add
+ auto pull_uses = loco::succs(pull_node0);
+ ASSERT_EQ(pull_uses.size(), 1);
+ loco::EltwiseAdd *add_node = dynamic_cast<loco::EltwiseAdd *>(*pull_uses.begin());
ASSERT_NE(add_node, nullptr);
- ASSERT_NE(push_node, nullptr);
-
- ASSERT_EQ(add_node->lhs(), pull1_node);
- ASSERT_EQ(add_node->rhs(), pull2_node);
- ASSERT_EQ(push_node->from(), add_node);
-
- // Shape check
- ASSERT_EQ(pull1_node->rank(), 4);
- ASSERT_EQ(pull1_node->dim(0), 5);
- ASSERT_EQ(pull1_node->dim(1), 6);
- ASSERT_EQ(pull1_node->dim(2), 7);
- ASSERT_EQ(pull1_node->dim(3), 3);
-
- ASSERT_EQ(pull2_node->rank(), 4);
- ASSERT_EQ(pull2_node->dim(0), 5);
- ASSERT_EQ(pull2_node->dim(1), 6);
- ASSERT_EQ(pull2_node->dim(2), 7);
- ASSERT_EQ(pull2_node->dim(3), 3);
+ ASSERT_EQ(add_node->lhs(), pull_node0);
+ // TensorBroadcast
+ loco::TensorBroadcast *broadcast_node = dynamic_cast<loco::TensorBroadcast *>(add_node->rhs());
+ ASSERT_NE(broadcast_node, nullptr);
+ ASSERT_EQ(broadcast_node->input(), pull_node1);
+ // Check params
+ ASSERT_TRUE(broadcast_node->mapping()->defined(1));
+ ASSERT_EQ(broadcast_node->mapping()->dim(1), 6);
}
TEST_F(TestTransformer_mir2loco, Conv2D_Test)
{
mir::Graph mir_graph;
- mir::Shape input_shape{5, 6, 7, 3};
-
- auto *input1 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
- auto *input2 = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+ auto *input1 = mir_graph.create<mir::ops::InputOp>(mir::Shape{5, 6, 7, 13})->getOutput(0);
+ auto *input2 = mir_graph.create<mir::ops::InputOp>(mir::Shape{13})->getOutput(0);
auto *add = mir_graph.create<mir::ops::MulOp>(input1, input2)->getOutput(0);
mir_graph.create<mir::ops::OutputOp>(add);
input1->setName("x1");
mir2loco::Transformer transformer;
auto loco_graph = transformer.transform(&mir_graph);
- auto *pull1_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
- auto *pull2_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(1));
- auto *add_node = dynamic_cast<loco::EltwiseMul *>(loco_graph->nodes()->at(2));
- auto *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(3));
-
- ASSERT_NE(pull1_node, nullptr);
- ASSERT_NE(pull2_node, nullptr);
- ASSERT_NE(add_node, nullptr);
- ASSERT_NE(push_node, nullptr);
-
- ASSERT_EQ(add_node->lhs(), pull1_node);
- ASSERT_EQ(add_node->rhs(), pull2_node);
- ASSERT_EQ(push_node->from(), add_node);
-
- // Shape check
- ASSERT_EQ(pull1_node->rank(), 4);
- ASSERT_EQ(pull1_node->dim(0), 5);
- ASSERT_EQ(pull1_node->dim(1), 6);
- ASSERT_EQ(pull1_node->dim(2), 7);
- ASSERT_EQ(pull1_node->dim(3), 3);
-
- ASSERT_EQ(pull2_node->rank(), 4);
- ASSERT_EQ(pull2_node->dim(0), 5);
- ASSERT_EQ(pull2_node->dim(1), 6);
- ASSERT_EQ(pull2_node->dim(2), 7);
- ASSERT_EQ(pull2_node->dim(3), 3);
+ // Pulls
+ auto inputs = loco_graph->inputs();
+ ASSERT_EQ(inputs->size(), 2);
+ loco::Pull *pull_node0 = loco::pull_node(loco_graph.get(), 0);
+ ASSERT_NE(pull_node0, nullptr);
+ loco::Pull *pull_node1 = loco::pull_node(loco_graph.get(), 1);
+ ASSERT_NE(pull_node1, nullptr);
+ // Mul
+ auto pull0_uses = loco::succs(pull_node0);
+ ASSERT_EQ(pull0_uses.size(), 1);
+ loco::EltwiseMul *mul_node = dynamic_cast<loco::EltwiseMul *>(*pull0_uses.begin());
+ ASSERT_NE(mul_node, nullptr);
+ // Broadcast
+ loco::TensorBroadcast *broadcast_node = dynamic_cast<loco::TensorBroadcast *>(mul_node->rhs());
+ ASSERT_NE(broadcast_node, nullptr);
+ ASSERT_EQ(mul_node->lhs(), pull_node0);
+ ASSERT_EQ(mul_node->rhs(), broadcast_node);
+ loco::FixedReshape *reshape_node = dynamic_cast<loco::FixedReshape *>(broadcast_node->input());
+ ASSERT_NE(reshape_node, nullptr);
+ ASSERT_EQ(reshape_node->input(), pull_node1);
+ ASSERT_EQ(reshape_node->rank(), 4);
+ ASSERT_EQ(reshape_node->dim(0), 1);
+ ASSERT_EQ(reshape_node->dim(1), 1);
+ ASSERT_EQ(reshape_node->dim(2), 1);
+ ASSERT_EQ(reshape_node->dim(3), 13);
+ // Params checks
+ ASSERT_EQ(pull_node0->rank(), 4);
+ ASSERT_EQ(pull_node0->dim(0), 5);
+ ASSERT_EQ(pull_node0->dim(1), 6);
+ ASSERT_EQ(pull_node0->dim(2), 7);
+ ASSERT_EQ(pull_node0->dim(3), 13);
+
+ ASSERT_EQ(pull_node1->rank(), 1);
+ ASSERT_EQ(pull_node1->dim(0), 13);
+
+ ASSERT_TRUE(broadcast_node->mapping()->defined(0));
+ ASSERT_EQ(broadcast_node->mapping()->dim(0), 5);
+ ASSERT_TRUE(broadcast_node->mapping()->defined(1));
+ ASSERT_EQ(broadcast_node->mapping()->dim(1), 6);
+ ASSERT_TRUE(broadcast_node->mapping()->defined(2));
+ ASSERT_EQ(broadcast_node->mapping()->dim(2), 7);
}
TEST_F(TestTransformer_mir2loco, DepthwiseConv2D_Test)