{
namespace pass
{
-
-#define UNUSED(x) (void)x
-
void PermutationEliminationPass::callback(const operand::Index &inp_index, operand::Object &object)
{
if (_graph.getInputs().contains(inp_index))
void PermutationEliminationPass::eliminateInput(const operand::Index &inp_index,
operand::Object &object)
{
- UNUSED(inp_index;);
- UNUSED(object);
- /* TODO */
+ auto &model_inputs = _graph.getInputs();
+
+ // get uses of the model's given input
+ auto uses = object.getUses();
+
+ // input must be used just by permutation
+ if (uses.size() != 1)
+ {
+ return;
+ }
+
+ for (auto input_use : uses.list())
+ {
+ auto &perm_operation = _graph.operations().at(input_use);
+ auto perm_inputs = perm_operation.getInputs();
+
+ auto perm_outputs = perm_operation.getOutputs();
+
+ if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, true))
+ {
+ return;
+ }
+
+ assert(perm_inputs.at(0) == inp_index);
+
+ VERBOSE(PermutationEliminationPass::EliminateInput) << "remove NHWC_TO_NCHW permutation\n";
+
+ // set model's new input, which was output of permutation
+ model_inputs.replace(inp_index, perm_outputs.at(0));
+
+ // remove model's input, which is also input of permutation
+ _graph.removeOperand(inp_index);
+
+ // remove permutation operation
+ _graph.operations().remove(input_use);
+
+ VERBOSE(PermutationEliminationPass::EliminateInput)
+ << inp_index.value() << " is model's input and is removed. New input is "
+ << perm_outputs.at(0).value() << "\n"
+ << input_use.value() << " is removed permutation operation\n";
+ }
}
void PermutationEliminationPass::eliminateOutput(const operand::Index &out_index,
operand::Object &object)
{
- UNUSED(out_index);
- UNUSED(object);
- /* TODO */
+ auto &model_outputs = _graph.getOutputs();
+
+ // get defs of the model's given output
+ auto defs = object.getDef();
+
+ // output must use just permutation
+ if (defs.size() != 1)
+ {
+ return;
+ }
+
+ for (auto output_def : defs.list())
+ {
+ auto &perm_operation = _graph.operations().at(output_def);
+ auto perm_outputs = perm_operation.getOutputs();
+
+ auto perm_inputs = perm_operation.getInputs();
+ if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, false))
+ {
+ return;
+ }
+
+ assert(perm_outputs.at(0) == out_index);
+
+ VERBOSE(PermutationEliminationPass::EliminateOutput) << "remove NCHW_TO_NHWC permutation\n";
+
+ // Update operations' output that is used by permute operand
+ for (auto perm_input_index : perm_inputs)
+ {
+ auto &perm_input_operand = _graph.operands().at(perm_input_index);
+ perm_input_operand.removeUse(output_def);
+ }
+
+ // set model's new output, which was input of permutation
+ model_outputs.replace(out_index, perm_inputs.at(0));
+
+ // remove model's output, which is also output of permutation
+ _graph.removeOperand(out_index);
+
+ // remove permutation operation
+ _graph.operations().remove(output_def);
+
+ VERBOSE(PermutationEliminationPass::EliminateOutput)
+ << out_index.value() << " is model's output and is removed. New output is "
+ << perm_inputs.at(0).value() << "\n"
+ << output_def.value() << " is removed permutation operation\n";
+ }
}
bool PermutationEliminationPass::isPermuteLayerToEliminate(const operand::IndexSet &inp_indexes,
const operand::IndexSet &out_indexes,
bool is_for_model_input)
{
- UNUSED(inp_indexes);
- UNUSED(out_indexes);
- UNUSED(is_for_model_input);
- /* TODO */
- return true;
+ auto input_def_backends = _graph.operands().at(inp_indexes.at(0)).lower_info()->def_backends();
+ auto output_def_backends = _graph.operands().at(out_indexes.at(0)).lower_info()->def_backends();
+
+ auto input_layout = input_def_backends.getOnlyElement()->config()->getOperandLayout();
+ auto output_layout = output_def_backends.getOnlyElement()->config()->getOperandLayout();
+
+ if (input_def_backends.size() != 1 || output_def_backends.size() != 1)
+ {
+ return false;
+ }
+
+ // all operands' backend must be the same
+ for (auto index : inp_indexes)
+ {
+ auto op_backend_set = _graph.operands().at(index).lower_info()->def_backends();
+ if (op_backend_set.size() != 1 ||
+ input_layout != op_backend_set.getOnlyElement()->config()->getOperandLayout())
+ {
+ return false;
+ }
+ }
+ // all operands' backend must be the same
+ for (auto index : out_indexes)
+ {
+ auto op_backend_set = _graph.operands().at(index).lower_info()->def_backends();
+ if (op_backend_set.size() != 1 ||
+ output_layout != op_backend_set.getOnlyElement()->config()->getOperandLayout())
+ {
+ return false;
+ }
+ }
+
+ if (is_for_model_input)
+ {
+ // check if this is NHWC_TO_NCHW permutation: must have single input, which is model's input
+ return (inp_indexes.size() == 1 && input_layout == graph::operand::Layout::NHWC &&
+ output_layout == graph::operand::Layout::NCHW);
+ }
+
+ // check if this is NCHW_TO_NHWC permutation: must have single output, which is model's output
+ return (out_indexes.size() == 1 && input_layout == graph::operand::Layout::NCHW &&
+ output_layout == graph::operand::Layout::NHWC);
}
} // namespace pass