#include "mir2loco.h"
#include "mir/ops/AddOp.h"
+#include "mir/ops/AvgPool2DOp.h"
#include "mir/ops/ConcatOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/Conv2DOp.h"
#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/MaxPool2DOp.h"
#include "mir/ops/MulOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
return std::move(res);
}
-void setupPad(const std::vector<int32_t> &padding_before, const std::vector<int32_t> &padding_after,
- loco::Pad<2> *pad)
+void setupPad(const std::vector<std::int32_t> &padding_before,
+ const std::vector<std::int32_t> &padding_after, loco::Pad<2> *pad)
{
- if (padding_before.size() != 2 || padding_after.size() != 2)
- throw std::runtime_error("Support only 2D paddings!");
-
+ assert(padding_before.size() == 2 && padding_after.size() == 2);
pad->left(padding_before.at(0));
pad->top(padding_before.at(1));
pad->right(padding_after.at(0));
void setupWindow(const mir::Shape &window_shape, loco::Window<2> *window)
{
- if (window_shape.rank() != 2)
- throw std::runtime_error("Support only 2D window size!");
-
+ assert(window_shape.rank() == 2);
window->horizontal(window_shape.dim(0));
window->vertical(window_shape.dim(1));
}
-void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride)
+void setupWindow(const std::vector<std::int32_t> &window_size, loco::Window<2> *window)
{
- if (stride_shape.rank() != 2)
- throw std::runtime_error("Support only 2D strides!");
+ assert(window_size.size() == 2);
+ window->horizontal(window_size[0]);
+ window->vertical(window_size[1]);
+}
+void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride)
+{
+ assert(stride_shape.rank() == 2);
stride->horizontal(stride_shape.dim(0));
stride->vertical(stride_shape.dim(1));
}
+void setupStride(const std::vector<std::int32_t> &strides, loco::Stride<2> *stride)
+{
+ assert(strides.size() == 2);
+ stride->horizontal(strides[0]);
+ stride->vertical(strides[1]);
+}
+
loco::FeatureEncode *createNHWCFeatureEncode(loco::Graph *graph)
{
auto encode_node = graph->nodes()->create<loco::FeatureEncode>();
_mir2loco_map.emplace(&op, result);
}
+void Transformer::visit(mir::ops::AvgPool2DOp &op)
+{
+ // Get Input
+ auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+ assert(loco_it != _mir2loco_map.end()); // can't find the input
+ // FeatureEncode
+ auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+ encode_node->input(loco_it->second);
+
+ auto avg_pool_node = _loco_graph->nodes()->create<loco::AvgPool2D>();
+ // Set Input
+ avg_pool_node->ifm(encode_node);
+ // Set convention (like tensorflow)
+ avg_pool_node->convention(op.getIncludePad() ? loco::AvgPool2D::Convention::Full
+ : loco::AvgPool2D::Convention::Valid);
+
+ setupWindow(op.getWindowSize(), avg_pool_node->window());
+ setupStride(op.getStrides(), avg_pool_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), avg_pool_node->pad());
+
+ // FeatureDecode
+ auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+ // Set Input
+ decode_node->input(avg_pool_node);
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, decode_node);
+}
+
void Transformer::visit(mir::ops::ConcatOp &op)
{
if (op.getNumInputs() < 2)
_mir2loco_map.emplace(&op, pull_node);
}
+void Transformer::visit(mir::ops::MaxPool2DOp &op)
+{
+ // Get Input
+ auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+ assert(loco_it != _mir2loco_map.end()); // can't find the input
+ // FeatureEncode
+ auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+ encode_node->input(loco_it->second);
+
+ auto max_pool_node = _loco_graph->nodes()->create<loco::MaxPool2D>();
+ // Set Input
+ max_pool_node->ifm(encode_node);
+
+ setupWindow(op.getWindowSize(), max_pool_node->window());
+ setupStride(op.getStrides(), max_pool_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), max_pool_node->pad());
+
+ // FeatureDecode
+ auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+ // Set Input
+ decode_node->input(max_pool_node);
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, decode_node);
+}
+
void Transformer::visit(mir::ops::MulOp &op)
{
// Get Input
#include "mir2loco.h"
#include "mir/ops/AddOp.h"
+#include "mir/ops/AvgPool2DOp.h"
#include "mir/ops/ConcatOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/Conv2DOp.h"
#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/MaxPool2DOp.h"
#include "mir/ops/MulOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
auto *pool = mir_graph
- .create<mir::ops::PoolOp>(
- input, mir::ops::PoolOp::PoolingType::AVG, mir::Shape{2, 3},
- mir::Shape{4, 5}, std::vector<std::int32_t>{5, 9},
- std::vector<std::int32_t>{7, 4}, mir::ops::PoolOp::BorderType::EMPTY)
+ .create<mir::ops::AvgPool2DOp>(
+ input, std::vector<std::int32_t>{2, 3}, std::vector<std::int32_t>{4, 5},
+ std::vector<std::int32_t>{5, 9}, std::vector<std::int32_t>{7, 4}, true,
+ mir::DataFormat::NHWC)
->getOutput(0);
mir_graph.create<mir::ops::OutputOp>(pool);
input->setName("x");
ASSERT_EQ(decode_node->input(), pool_node);
ASSERT_EQ(push_node->from(), decode_node);
// Check params
- ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Valid);
+ ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Full);
ASSERT_EQ(pool_node->pad()->left(), 5);
ASSERT_EQ(pool_node->pad()->top(), 9);
ASSERT_EQ(pool_node->pad()->right(), 7);
mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
auto *pool = mir_graph
- .create<mir::ops::PoolOp>(
- input, mir::ops::PoolOp::PoolingType::MAX, mir::Shape{2, 3},
- mir::Shape{4, 5}, std::vector<std::int32_t>{5, 9},
- std::vector<std::int32_t>{7, 4}, mir::ops::PoolOp::BorderType::EMPTY)
+ .create<mir::ops::MaxPool2DOp>(
+ input, std::vector<std::int32_t>{2, 3}, std::vector<std::int32_t>{4, 5},
+ std::vector<std::int32_t>{5, 9}, std::vector<std::int32_t>{7, 4},
+ mir::DataFormat::NHWC)
->getOutput(0);
mir_graph.create<mir::ops::OutputOp>(pool);
input->setName("x");