mir::Operation::Output *DataFormatSwitcher::insertTransposeBefore(mir::Operation::Output *out)
{
+ mir::Operation::Output *new_out;
if (_target_format == mir::DataFormat::NHWC)
- return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
- ->getOutput(0); // NCHW -> NHWC
+ new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
+ ->getOutput(0); // NCHW -> NHWC
else
- return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
- ->getOutput(0); // NHWC -> NCHW
+ new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
+ ->getOutput(0); // NHWC -> NCHW
+ if (out->getType().isQuantized())
+ new_out->setQuantization(out->getType().getQuantization());
+ return new_out;
}
mir::Operation::Output *DataFormatSwitcher::insertTransposeAfter(mir::Operation::Output *out)
{
+ mir::Operation::Output *new_out;
if (_target_format == mir::DataFormat::NHWC)
- return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
- ->getOutput(0); // NHWC -> NCHW
+ new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
+ ->getOutput(0); // NHWC -> NCHW
else
- return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
- ->getOutput(0); // NCHW -> NHWC
+ new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
+ ->getOutput(0); // NCHW -> NHWC
+ if (out->getType().isQuantized())
+ new_out->setQuantization(out->getType().getQuantization());
+ return new_out;
}
void DataFormatSwitcher::switchAvgPool2D(mir::ops::AvgPool2DOp *op)
if (op->getDataFormat() == _target_format)
return;
- assert(op->getNumInputs() == 2);
+ assert(op->getNumInputs() >= 2);
auto *input = op->getInput(0);
auto *kernel = op->getInput(1);
auto *trans_in = insertTransposeBefore(input);
- auto new_dw_conv = _graph->create<mir::ops::Conv2DOp>(trans_in, kernel, attributes);
+ mir::Operation *new_conv;
+ if (op->getNumInputs() == 2)
+ new_conv = _graph->create<mir::ops::Conv2DOp>(trans_in, kernel, attributes);
+ else
+ {
+ auto bias = op->getInput(2);
+ new_conv = _graph->create<mir::ops::Conv2DOp>(trans_in, kernel, bias, attributes);
+ }
+
+ if (op->getOutput(0)->getType().isQuantized())
+ new_conv->getOutput(0)->setQuantization(op->getOutput(0)->getType().getQuantization());
- auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0));
+ auto *trans_out = insertTransposeAfter(new_conv->getOutput(0));
_graph->replaceNode(op, trans_out->getNode());
}
if (op->getDataFormat() == _target_format)
return;
- assert(op->getNumInputs() == 2);
+ assert(op->getNumInputs() >= 2);
auto *input = op->getInput(0);
auto *kernel = op->getInput(1);
auto *trans_in = insertTransposeBefore(input);
- auto new_dw_conv = _graph->create<mir::ops::DepthwiseConv2DOp>(trans_in, kernel, attributes);
+ mir::Operation *new_dw_conv;
+ if (op->getNumInputs() == 2)
+ new_dw_conv = _graph->create<mir::ops::DepthwiseConv2DOp>(trans_in, kernel, attributes);
+ else
+ {
+ auto bias = op->getInput(2);
+ new_dw_conv = _graph->create<mir::ops::DepthwiseConv2DOp>(trans_in, kernel, bias, attributes);
+ }
+
+ if (op->getOutput(0)->getType().isQuantized())
+ new_dw_conv->getOutput(0)->setQuantization(op->getOutput(0)->getType().getQuantization());
auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0));