From 6eb841d8ca500c807f79e626e60be43e77d24559 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 26 Jul 2019 14:25:42 +0300 Subject: [PATCH] [mir2loco] Support Pooling operations (#5846) * Support Pooling operation and tests for 2 ops Signed-off-by: Pavel Iliutchenko --- compiler/mir2loco/CMakeLists.txt | 1 + compiler/mir2loco/include/mir2loco.h | 2 +- compiler/mir2loco/src/mir2loco.cpp | 122 +++++++++++++++++++++++++++++++- compiler/mir2loco/src/mir2loco.test.cpp | 89 +++++++++++++++++++++++ 4 files changed, 211 insertions(+), 3 deletions(-) diff --git a/compiler/mir2loco/CMakeLists.txt b/compiler/mir2loco/CMakeLists.txt index 9aaf723..b506559 100644 --- a/compiler/mir2loco/CMakeLists.txt +++ b/compiler/mir2loco/CMakeLists.txt @@ -7,6 +7,7 @@ target_include_directories(mir2loco PRIVATE src) target_include_directories(mir2loco PUBLIC include) target_link_libraries(mir2loco PUBLIC mir) target_link_libraries(mir2loco PUBLIC loco) +target_link_libraries(mir2loco PRIVATE stdex) nncc_find_package(GTest QUIET) diff --git a/compiler/mir2loco/include/mir2loco.h b/compiler/mir2loco/include/mir2loco.h index aeda08c..795c424 100644 --- a/compiler/mir2loco/include/mir2loco.h +++ b/compiler/mir2loco/include/mir2loco.h @@ -20,7 +20,7 @@ namespace mir2loco { -class Transformer : public mir::Visitor +class Transformer final : public mir::Visitor { public: Transformer() = default; diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index 07ca57c..9ae6a92 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -16,9 +16,14 @@ #include "mir2loco.h" +#include "mir/ops/PoolOp.h" + +#include + namespace mir2loco { - +namespace +{ template void setupShape(const mir::Shape &shape, NodeType *node) { node->rank(shape.rank()); @@ -28,6 +33,67 @@ template void setupShape(const mir::Shape &shape, NodeType *nod } } +void setupPad(const std::vector &padding_before, const std::vector &padding_after, + loco::Pad<2> *pad) +{ + if (padding_before.size() != 2 || padding_after.size() != 2) + throw std::runtime_error("Support only 2D paddings!"); + + pad->left(padding_before.at(0)); + pad->top(padding_before.at(1)); + pad->right(padding_after.at(0)); + pad->bottom(padding_after.at(1)); +} + +void setupWindow(const mir::Shape &window_shape, loco::Window<2> *window) +{ + if (window_shape.rank() != 2) + throw std::runtime_error("Support only 2D window size!"); + + window->horizontal(window_shape.dim(0)); + window->vertical(window_shape.dim(1)); +} + +void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride) +{ + if (stride_shape.rank() != 2) + throw std::runtime_error("Support only 2D strides!"); + + stride->horizontal(stride_shape.dim(0)); + stride->vertical(stride_shape.dim(1)); +} + +loco::FeatureEncode *createNHWCFeatureEncode(loco::Graph *graph) +{ + auto encode_node = graph->nodes()->create(); + + auto enc = stdex::make_unique>(); + + enc->perm()->axis(loco::FeatureAxis::Count) = 0; + enc->perm()->axis(loco::FeatureAxis::Height) = 1; + enc->perm()->axis(loco::FeatureAxis::Width) = 2; + enc->perm()->axis(loco::FeatureAxis::Depth) = 3; + + encode_node->encoder(std::move(enc)); + return encode_node; +} + +loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph) +{ + auto decode_node = graph->nodes()->create(); + + auto dec = stdex::make_unique>(); + + dec->perm()->axis(loco::FeatureAxis::Count) = 0; + dec->perm()->axis(loco::FeatureAxis::Height) = 1; + dec->perm()->axis(loco::FeatureAxis::Width) = 2; + dec->perm()->axis(loco::FeatureAxis::Depth) = 3; + + decode_node->decoder(std::move(dec)); + return decode_node; +} +} // namespace + void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); } void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); } @@ -97,7 +163,59 @@ void Transformer::visit(mir::ops::OutputOp &op) void Transformer::visit(mir::ops::PadOp &op) { throw std::runtime_error("NYI"); } -void Transformer::visit(mir::ops::PoolOp &op) { throw std::runtime_error("NYI"); } +void Transformer::visit(mir::ops::PoolOp &op) +{ + loco::Node *pool_node; + // Get Input + auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode()); + assert(loco_it != _mir2loco_map.end()); // can't find the input + // FeatureEncode + auto encode_node = createNHWCFeatureEncode(_loco_graph.get()); + encode_node->input(loco_it->second); + + switch (op.getPoolingType()) + { + case mir::ops::PoolOp::PoolingType::AVG: + { + auto avg_pool_node = _loco_graph->nodes()->create(); + // Set Input + avg_pool_node->ifm(encode_node); + // Set convention (like tensorflow) + avg_pool_node->convention(loco::AvgPool2D::Convention::Valid); + + setupWindow(op.getWindowShape(), avg_pool_node->window()); + setupStride(op.getStrides(), avg_pool_node->stride()); + setupPad(op.getPaddingBefore(), op.getPaddingAfter(), avg_pool_node->pad()); + + pool_node = avg_pool_node; + break; + } + case mir::ops::PoolOp::PoolingType::MAX: + { + auto max_pool_node = _loco_graph->nodes()->create(); + // Set Input + max_pool_node->ifm(encode_node); + + setupWindow(op.getWindowShape(), max_pool_node->window()); + setupStride(op.getStrides(), max_pool_node->stride()); + setupPad(op.getPaddingBefore(), op.getPaddingAfter(), max_pool_node->pad()); + + pool_node = max_pool_node; + break; + } + case mir::ops::PoolOp::PoolingType::MIN: + throw std::runtime_error("Min pooling not supported!"); + default: + assert(false && "Unknown pooling type!"); + } + // FeatureDecode + auto decode_node = createNHWCFeatureDecode(_loco_graph.get()); + // Set Input + decode_node->input(pool_node); + // Not set Shape + // Add to map + _mir2loco_map.emplace(&op, decode_node); +} void Transformer::visit(mir::ops::ReduceOp &op) { throw std::runtime_error("NYI"); } diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index c6e7263..a0485d2 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -16,6 +16,8 @@ #include "mir2loco.h" +#include "mir/ops/PoolOp.h" + #include class TestTransformer_mir2loco : public ::testing::Test @@ -53,3 +55,90 @@ TEST_F(TestTransformer_mir2loco, Input_Output_Test) ASSERT_EQ(push_node->dim(2), 7); ASSERT_EQ(push_node->dim(3), 8); } + +TEST_F(TestTransformer_mir2loco, Avg_Pool_Test) +{ + mir::Graph mir_graph; + + mir::Shape input_shape = mir::Shape({7, 7, 9, 9}); + auto *input = mir_graph.create("input", input_shape); + auto *pool = mir_graph.create( + "pool", input->getOutput(0), mir::ops::PoolOp::PoolingType::AVG, mir::Shape{2, 3}, + mir::Shape{4, 5}, std::vector{5, 9}, std::vector{7, 4}, + mir::ops::PoolOp::BorderType::EMPTY); + auto *output = mir_graph.create("output", pool->getOutput(0)); + + mir2loco::Transformer transformer; + auto loco_graph = transformer.transform(&mir_graph); + + loco::Pull *pull_node = dynamic_cast(loco_graph->nodes()->at(0)); + loco::FeatureEncode *encode_node = + dynamic_cast(loco_graph->nodes()->at(1)); + loco::AvgPool2D *pool_node = dynamic_cast(loco_graph->nodes()->at(2)); + loco::FeatureDecode *decode_node = + dynamic_cast(loco_graph->nodes()->at(3)); + loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(4)); + + ASSERT_NE(pull_node, nullptr); + ASSERT_NE(encode_node, nullptr); + ASSERT_NE(pool_node, nullptr); + ASSERT_NE(decode_node, nullptr); + ASSERT_NE(push_node, nullptr); + ASSERT_EQ(encode_node->input(), pull_node); + ASSERT_EQ(pool_node->ifm(), encode_node); + ASSERT_EQ(decode_node->input(), pool_node); + ASSERT_EQ(push_node->from(), decode_node); + // Check params + ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Valid); + ASSERT_EQ(pool_node->pad()->left(), 5); + ASSERT_EQ(pool_node->pad()->top(), 9); + ASSERT_EQ(pool_node->pad()->right(), 7); + ASSERT_EQ(pool_node->pad()->bottom(), 4); + ASSERT_EQ(pool_node->window()->horizontal(), 2); + ASSERT_EQ(pool_node->window()->vertical(), 3); + ASSERT_EQ(pool_node->stride()->horizontal(), 4); + ASSERT_EQ(pool_node->stride()->vertical(), 5); +} + +TEST_F(TestTransformer_mir2loco, Max_Pool_Test) +{ + mir::Graph mir_graph; + + mir::Shape input_shape = mir::Shape({7, 7, 9, 9}); + auto *input = mir_graph.create("input", input_shape); + auto *pool = mir_graph.create( + "pool", input->getOutput(0), mir::ops::PoolOp::PoolingType::MAX, mir::Shape{2, 3}, + mir::Shape{4, 5}, std::vector{5, 9}, std::vector{7, 4}, + mir::ops::PoolOp::BorderType::EMPTY); + auto *output = mir_graph.create("output", pool->getOutput(0)); + + mir2loco::Transformer transformer; + auto loco_graph = transformer.transform(&mir_graph); + + loco::Pull *pull_node = dynamic_cast(loco_graph->nodes()->at(0)); + loco::FeatureEncode *encode_node = + dynamic_cast(loco_graph->nodes()->at(1)); + loco::MaxPool2D *pool_node = dynamic_cast(loco_graph->nodes()->at(2)); + loco::FeatureDecode *decode_node = + dynamic_cast(loco_graph->nodes()->at(3)); + loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(4)); + + ASSERT_NE(pull_node, nullptr); + ASSERT_NE(encode_node, nullptr); + ASSERT_NE(pool_node, nullptr); + ASSERT_NE(decode_node, nullptr); + ASSERT_NE(push_node, nullptr); + ASSERT_EQ(encode_node->input(), pull_node); + ASSERT_EQ(pool_node->ifm(), encode_node); + ASSERT_EQ(decode_node->input(), pool_node); + ASSERT_EQ(push_node->from(), decode_node); + // Check params + ASSERT_EQ(pool_node->pad()->left(), 5); + ASSERT_EQ(pool_node->pad()->top(), 9); + ASSERT_EQ(pool_node->pad()->right(), 7); + ASSERT_EQ(pool_node->pad()->bottom(), 4); + ASSERT_EQ(pool_node->window()->horizontal(), 2); + ASSERT_EQ(pool_node->window()->vertical(), 3); + ASSERT_EQ(pool_node->stride()->horizontal(), 4); + ASSERT_EQ(pool_node->stride()->vertical(), 5); +} -- 2.7.4