#include "mir2loco.h"
+#include "mir/ops/PoolOp.h"
+
+#include <stdex/Memory.h>
+
namespace mir2loco
{
-
+namespace
+{
template <class NodeType> void setupShape(const mir::Shape &shape, NodeType *node)
{
node->rank(shape.rank());
}
}
+void setupPad(const std::vector<int32_t> &padding_before, const std::vector<int32_t> &padding_after,
+ loco::Pad<2> *pad)
+{
+ if (padding_before.size() != 2 || padding_after.size() != 2)
+ throw std::runtime_error("Support only 2D paddings!");
+
+ pad->left(padding_before.at(0));
+ pad->top(padding_before.at(1));
+ pad->right(padding_after.at(0));
+ pad->bottom(padding_after.at(1));
+}
+
+void setupWindow(const mir::Shape &window_shape, loco::Window<2> *window)
+{
+ if (window_shape.rank() != 2)
+ throw std::runtime_error("Support only 2D window size!");
+
+ window->horizontal(window_shape.dim(0));
+ window->vertical(window_shape.dim(1));
+}
+
+void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride)
+{
+ if (stride_shape.rank() != 2)
+ throw std::runtime_error("Support only 2D strides!");
+
+ stride->horizontal(stride_shape.dim(0));
+ stride->vertical(stride_shape.dim(1));
+}
+
+loco::FeatureEncode *createNHWCFeatureEncode(loco::Graph *graph)
+{
+ auto encode_node = graph->nodes()->create<loco::FeatureEncode>();
+
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+
+ encode_node->encoder(std::move(enc));
+ return encode_node;
+}
+
+loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph)
+{
+ auto decode_node = graph->nodes()->create<loco::FeatureDecode>();
+
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+
+ decode_node->decoder(std::move(dec));
+ return decode_node;
+}
+} // namespace
+
void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); }
void Transformer::visit(mir::ops::PadOp &op) { throw std::runtime_error("NYI"); }
-void Transformer::visit(mir::ops::PoolOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::PoolOp &op)
+{
+ loco::Node *pool_node;
+ // Get Input
+ auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+ assert(loco_it != _mir2loco_map.end()); // can't find the input
+ // FeatureEncode
+ auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+ encode_node->input(loco_it->second);
+
+ switch (op.getPoolingType())
+ {
+ case mir::ops::PoolOp::PoolingType::AVG:
+ {
+ auto avg_pool_node = _loco_graph->nodes()->create<loco::AvgPool2D>();
+ // Set Input
+ avg_pool_node->ifm(encode_node);
+ // Set convention (like tensorflow)
+ avg_pool_node->convention(loco::AvgPool2D::Convention::Valid);
+
+ setupWindow(op.getWindowShape(), avg_pool_node->window());
+ setupStride(op.getStrides(), avg_pool_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), avg_pool_node->pad());
+
+ pool_node = avg_pool_node;
+ break;
+ }
+ case mir::ops::PoolOp::PoolingType::MAX:
+ {
+ auto max_pool_node = _loco_graph->nodes()->create<loco::MaxPool2D>();
+ // Set Input
+ max_pool_node->ifm(encode_node);
+
+ setupWindow(op.getWindowShape(), max_pool_node->window());
+ setupStride(op.getStrides(), max_pool_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), max_pool_node->pad());
+
+ pool_node = max_pool_node;
+ break;
+ }
+ case mir::ops::PoolOp::PoolingType::MIN:
+ throw std::runtime_error("Min pooling not supported!");
+ default:
+ assert(false && "Unknown pooling type!");
+ }
+ // FeatureDecode
+ auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+ // Set Input
+ decode_node->input(pool_node);
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, decode_node);
+}
void Transformer::visit(mir::ops::ReduceOp &op) { throw std::runtime_error("NYI"); }
#include "mir2loco.h"
+#include "mir/ops/PoolOp.h"
+
#include <gtest/gtest.h>
class TestTransformer_mir2loco : public ::testing::Test
ASSERT_EQ(push_node->dim(2), 7);
ASSERT_EQ(push_node->dim(3), 8);
}
+
+TEST_F(TestTransformer_mir2loco, Avg_Pool_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
+ auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+ auto *pool = mir_graph.create<mir::ops::PoolOp>(
+ "pool", input->getOutput(0), mir::ops::PoolOp::PoolingType::AVG, mir::Shape{2, 3},
+ mir::Shape{4, 5}, std::vector<int32_t>{5, 9}, std::vector<int32_t>{7, 4},
+ mir::ops::PoolOp::BorderType::EMPTY);
+ auto *output = mir_graph.create<mir::ops::OutputOp>("output", pool->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+
+ loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+ loco::FeatureEncode *encode_node =
+ dynamic_cast<loco::FeatureEncode *>(loco_graph->nodes()->at(1));
+ loco::AvgPool2D *pool_node = dynamic_cast<loco::AvgPool2D *>(loco_graph->nodes()->at(2));
+ loco::FeatureDecode *decode_node =
+ dynamic_cast<loco::FeatureDecode *>(loco_graph->nodes()->at(3));
+ loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(4));
+
+ ASSERT_NE(pull_node, nullptr);
+ ASSERT_NE(encode_node, nullptr);
+ ASSERT_NE(pool_node, nullptr);
+ ASSERT_NE(decode_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(encode_node->input(), pull_node);
+ ASSERT_EQ(pool_node->ifm(), encode_node);
+ ASSERT_EQ(decode_node->input(), pool_node);
+ ASSERT_EQ(push_node->from(), decode_node);
+ // Check params
+ ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Valid);
+ ASSERT_EQ(pool_node->pad()->left(), 5);
+ ASSERT_EQ(pool_node->pad()->top(), 9);
+ ASSERT_EQ(pool_node->pad()->right(), 7);
+ ASSERT_EQ(pool_node->pad()->bottom(), 4);
+ ASSERT_EQ(pool_node->window()->horizontal(), 2);
+ ASSERT_EQ(pool_node->window()->vertical(), 3);
+ ASSERT_EQ(pool_node->stride()->horizontal(), 4);
+ ASSERT_EQ(pool_node->stride()->vertical(), 5);
+}
+
+TEST_F(TestTransformer_mir2loco, Max_Pool_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
+ auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+ auto *pool = mir_graph.create<mir::ops::PoolOp>(
+ "pool", input->getOutput(0), mir::ops::PoolOp::PoolingType::MAX, mir::Shape{2, 3},
+ mir::Shape{4, 5}, std::vector<int32_t>{5, 9}, std::vector<int32_t>{7, 4},
+ mir::ops::PoolOp::BorderType::EMPTY);
+ auto *output = mir_graph.create<mir::ops::OutputOp>("output", pool->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+
+ loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+ loco::FeatureEncode *encode_node =
+ dynamic_cast<loco::FeatureEncode *>(loco_graph->nodes()->at(1));
+ loco::MaxPool2D *pool_node = dynamic_cast<loco::MaxPool2D *>(loco_graph->nodes()->at(2));
+ loco::FeatureDecode *decode_node =
+ dynamic_cast<loco::FeatureDecode *>(loco_graph->nodes()->at(3));
+ loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(4));
+
+ ASSERT_NE(pull_node, nullptr);
+ ASSERT_NE(encode_node, nullptr);
+ ASSERT_NE(pool_node, nullptr);
+ ASSERT_NE(decode_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(encode_node->input(), pull_node);
+ ASSERT_EQ(pool_node->ifm(), encode_node);
+ ASSERT_EQ(decode_node->input(), pool_node);
+ ASSERT_EQ(push_node->from(), decode_node);
+ // Check params
+ ASSERT_EQ(pool_node->pad()->left(), 5);
+ ASSERT_EQ(pool_node->pad()->top(), 9);
+ ASSERT_EQ(pool_node->pad()->right(), 7);
+ ASSERT_EQ(pool_node->pad()->bottom(), 4);
+ ASSERT_EQ(pool_node->window()->horizontal(), 2);
+ ASSERT_EQ(pool_node->window()->vertical(), 3);
+ ASSERT_EQ(pool_node->stride()->horizontal(), 4);
+ ASSERT_EQ(pool_node->stride()->vertical(), 5);
+}