From ae2bfc073871aac1524bb3ba2b8384d55b347f4b Mon Sep 17 00:00:00 2001
From: Sergei Barannikov/Engineer/AI Tools Lab /SRR/Samsung Electronics
Date: Wed, 16 Oct 2019 13:53:24 +0300
Subject: [PATCH] [nnc] Switch to `getInputProducer` method (#8189)

Replace calls to `getInput()->getProducer()` with calls to `getInputProducer()`.

Signed-off-by: Sergei Barannikov
---
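Note (below the `---`, so not part of the applied patch): every hunk performs the
same mechanical substitution, replacing the two-step chain
`op->getInput(i)->getProducer()` with the single accessor `op->getInputProducer(i)`.
For orientation, here is a minimal sketch of the relationship this assumes between
the mir classes. It is inferred from the call sites in this patch only; the real
declarations in `compiler/mir` (constness, overloads, storage) may differ:

    #include <cstddef>
    #include <vector>

    // Hypothetical, stripped-down mirror of the mir interfaces involved.
    class Operation
    {
    public:
      class Output
      {
        // ... the value produced by an operation; getNode()/getIndex() live here.
      };

      class Input
      {
      public:
        explicit Input(Output *producer) : _producer(producer) {}
        // An input slot merely points at the Output that feeds it.
        Output *getProducer() const { return _producer; }

      private:
        Output *_producer;
      };

      // Old call chain: two hops through the Input slot.
      Input *getInput(std::size_t index) { return &_inputs[index]; }

      // New shorthand this patch switches to: one call, Input not exposed.
      Output *getInputProducer(std::size_t index) { return getInput(index)->getProducer(); }

    private:
      std::vector<Input> _inputs;
    };

Besides being shorter at each call site, the shorthand keeps callers away from the
intermediate Input object; the loops below that formerly wrote `input.getProducer()`
now range directly over `Operation::Output *` values.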
 .../acl_soft_backend/AclCppOpGenerator.cpp         | 34 +++++++++---------
 compiler/nnc/backends/interpreter/Interpreter.cpp  |  5 ++-
 .../backends/interpreter/InterpreterBackend.cpp    |  4 +--
 .../nnc/backends/soft_backend/ModelAnalyzer.cpp    |  8 ++---
 .../nnc/passes/optimizations/CombineTransposes.cpp |  4 +--
 .../nnc/passes/optimizations/FuseArithmeticOps.cpp |  8 ++---
 .../nnc/passes/optimizations/OptimizationUtils.cpp |  6 ++--
 compiler/nnc/passes/optimizations/SinkRelu.cpp     |  2 +-
 .../nnc/passes/optimizations/SinkTranspose.cpp     |  2 +-
 .../passes/transformations/DataFormatSwitcher.cpp  | 16 ++++-----
 .../unittests/optimizations/CombineTransposes.cpp  | 10 +++---
 compiler/nnc/unittests/optimizations/SinkTest.cpp  |  2 +-
 .../nnc/unittests/soft_backend/CPPOperations.cpp   |  2 +-
 .../nnc/unittests/transformations/Switcher.cpp     | 40 +++++++++++-----------
 14 files changed, 71 insertions(+), 72 deletions(-)

diff --git a/compiler/nnc/backends/acl_soft_backend/AclCppOpGenerator.cpp b/compiler/nnc/backends/acl_soft_backend/AclCppOpGenerator.cpp
index e3e45db..c400976 100644
--- a/compiler/nnc/backends/acl_soft_backend/AclCppOpGenerator.cpp
+++ b/compiler/nnc/backends/acl_soft_backend/AclCppOpGenerator.cpp
@@ -102,8 +102,8 @@ void AclCppOpGenerator::visit(ops::ConcatOp &op)
   auto inputs_var = _constrBlock->var("std::vector<arm_compute::ICLTensor*>", prefix + "_inputs");
   auto inputs = inputs_var->use();
 
-  for (const auto &ir_input : ir_inputs)
-    _constrBlock->call("push_back", {AF::ref(AF::id(tensorName(ir_input.getProducer())))}, inputs);
+  for (const Operation::Output *ir_input : ir_inputs)
+    _constrBlock->call("push_back", {AF::ref(AF::id(tensorName(ir_input)))}, inputs);
 
   auto layer = genLayer("arm_compute::CLConcatenateLayer", prefix,
                         {inputs, AF::ref(out), AF::lit(axis_name)});
@@ -125,7 +125,7 @@ void AclCppOpGenerator::visit(ops::DepthwiseConv2DOp &op)
 void AclCppOpGenerator::visit(ops::SoftmaxOp &op)
 {
   assert(op.getNumInputs() == 1);
-  const auto *ir_input = op.getInput(0)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
   const auto *ir_output = op.getOutput(0);
 
   auto in = AF::id(tensorName(ir_input));
@@ -262,8 +262,8 @@ void AclCppOpGenerator::visit(ops::MaxPool2DOp &op)
 void AclCppOpGenerator::visit(ops::FullyConnectedOp &op)
 {
   assert(op.getNumInputs() == 2);
-  const auto *ir_input = op.getInput(0)->getProducer();
-  const auto *ir_weights = op.getInput(1)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
+  const auto *ir_weights = op.getInputProducer(1);
   const auto *ir_output = op.getOutput(0);
 
   auto ir_weights_op = dynamic_cast<const ops::ConstantOp *>(ir_weights->getNode());
@@ -350,7 +350,7 @@ void AclCppOpGenerator::visit(ops::ReluOp &op) { genActivation(op, "RELU"); }
 void AclCppOpGenerator::visit(ops::ReshapeOp &op)
 {
   assert(op.getNumInputs() == 1);
-  const auto *ir_input = op.getInput(0)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
   const auto *ir_output = op.getOutput(0);
 
   // Get the id of the input tensor in the generated artifact.
@@ -401,7 +401,7 @@ void AclCppOpGenerator::visit(ops::EluOp & /*op*/)
 void AclCppOpGenerator::visit(ops::PadOp &op)
 {
   assert(op.getNumInputs() == 1);
-  const auto *ir_input = op.getInput(0)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
   const auto *ir_output = op.getOutput(0);
 
   // Get the id of the input tensor.
@@ -438,7 +438,7 @@ template <typename Op>
 void AclCppOpGenerator::genPooling(Op &op, const std::string &pooling_type, bool exclude_padding)
 {
   assert(op.getNumInputs() == 1);
-  const auto *ir_input = op.getInput(0)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
   const auto *ir_output = op.getOutput(0);
 
   string in_name = tensorName(ir_input);
@@ -492,8 +492,8 @@ void AclCppOpGenerator::genPooling(Op &op, const std::string &pooling_type, bool
 template <typename Op>
 void AclCppOpGenerator::genConvolution(Op &op, const string &acl_func_name, const string &suffix)
 {
-  const auto *ir_input = op.getInput(0)->getProducer();
-  const auto *ir_weights = op.getInput(1)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
+  const auto *ir_weights = op.getInputProducer(1);
   const auto *ir_output = op.getOutput(0);
 
   auto ir_weights_op = dynamic_cast<const ops::ConstantOp *>(ir_weights->getNode());
@@ -574,7 +574,7 @@ void AclCppOpGenerator::genActivation(const Operation &op, const std::string &ac
                                       float a, float b)
 {
   assert(op.getNumInputs() == 1);
-  const auto *ir_input = op.getInput(0)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
   const auto *ir_output = op.getOutput(0);
 
   // Get the id of the input tensor.
@@ -777,7 +777,7 @@ void AclCppOpGenerator::genNamed(Graph *graph)
     const auto *output_op = outputs[0];
     auto f = _artifactClass->func(true, "arm_compute::CLTensor&", "getOutput");
     auto b = f->getBlock();
-    auto id = AF::id(tensorName(output_op->getInput(0)->getProducer()));
+    auto id = AF::id(tensorName(output_op->getInputProducer(0)));
     b->ret(id);
   }
 }
@@ -919,7 +919,7 @@ void AclCppOpGenerator::genTranspose(const std::shared_ptr<ArtifactId> &inp
 void AclCppOpGenerator::visit(mir::ops::TransposeOp &op)
 {
   assert(op.getNumInputs() == 1);
-  const auto *ir_input = op.getInput(0)->getProducer();
+  const auto *ir_input = op.getInputProducer(0);
   const auto *ir_output = op.getOutput(0);
 
   // Get the input node tensor id in the DOM.
@@ -956,8 +956,8 @@ void AclCppOpGenerator::visit(mir::ops::OutputOp & /*op*/)
 void AclCppOpGenerator::visit(mir::ops::AddOp &op)
 {
   assert(op.getNumInputs() == 2);
-  const auto *ir_lhs = op.getInput(0)->getProducer();
-  const auto *ir_rhs = op.getInput(1)->getProducer();
+  const auto *ir_lhs = op.getInputProducer(0);
+  const auto *ir_rhs = op.getInputProducer(1);
   const auto *ir_output = op.getOutput(0);
 
   // Create the output tensor in the DOM and obtain its identifier.
@@ -978,8 +978,8 @@ void AclCppOpGenerator::visit(mir::ops::MaxOp &) { throw AclCppException("NYI");
 void AclCppOpGenerator::visit(mir::ops::MulOp &op)
 {
   assert(op.getNumInputs() == 2);
-  const auto *ir_lhs = op.getInput(0)->getProducer();
-  const auto *ir_rhs = op.getInput(1)->getProducer();
+  const auto *ir_lhs = op.getInputProducer(0);
+  const auto *ir_rhs = op.getInputProducer(1);
   const auto *ir_output = op.getOutput(0);
 
   // Create the output tensor in the DOM and obtain its identifier.
diff --git a/compiler/nnc/backends/interpreter/Interpreter.cpp b/compiler/nnc/backends/interpreter/Interpreter.cpp
index 04375ab..6551334 100644
--- a/compiler/nnc/backends/interpreter/Interpreter.cpp
+++ b/compiler/nnc/backends/interpreter/Interpreter.cpp
@@ -52,10 +52,9 @@ std::vector<std::reference_wrapper<const TensorVariant>>
 NNInterpreter::getInputTensors(const Operation &op)
 {
   std::vector<std::reference_wrapper<const TensorVariant>> tensors;
-  for (const auto &input : op.getInputs())
+  for (const Operation::Output *input : op.getInputs())
   {
-    const auto *producer = input.getProducer();
-    tensors.emplace_back(_opResults.at(producer->getNode()).at(producer->getIndex()));
+    tensors.emplace_back(_opResults.at(input->getNode()).at(input->getIndex()));
   }
   return tensors;
 }
diff --git a/compiler/nnc/backends/interpreter/InterpreterBackend.cpp b/compiler/nnc/backends/interpreter/InterpreterBackend.cpp
index 4fc5a1a..9746e42 100644
--- a/compiler/nnc/backends/interpreter/InterpreterBackend.cpp
+++ b/compiler/nnc/backends/interpreter/InterpreterBackend.cpp
@@ -135,8 +135,8 @@ void InterpreterBackend::run(mir::Graph *graph)
 
   for (const auto *output_op : graph->getOutputs())
   {
-    const auto &tensor = interpreter.getResult(output_op->getInput(0)->getProducer());
-    const auto &output_name = output_op->getInput(0)->getProducer()->getName();
+    const auto &tensor = interpreter.getResult(output_op->getInputProducer(0));
+    const auto &output_name = output_op->getInputProducer(0)->getName();
 
 #ifdef NNC_HDF5_SUPPORTED
     writeTensorToHDF5File(tensor, output_name, cli::artifactDir);
diff --git a/compiler/nnc/backends/soft_backend/ModelAnalyzer.cpp b/compiler/nnc/backends/soft_backend/ModelAnalyzer.cpp
index b4c025a..02f5fe7 100644
--- a/compiler/nnc/backends/soft_backend/ModelAnalyzer.cpp
+++ b/compiler/nnc/backends/soft_backend/ModelAnalyzer.cpp
@@ -54,7 +54,7 @@ void ModelAnalyzer::appendOperationToInference(Operation *op, const string &func
   }
   else if (op->getType() == Operation::Type::output)
   {
-    assert(!op->getInput(0)->getProducer()->getName().empty());
+    assert(!op->getInputProducer(0)->getName().empty());
   }
   else
   {
@@ -69,10 +69,10 @@ void ModelAnalyzer::appendOperationToInference(Operation *op, const string &func
 
   // process operation inputs
   vector<size_t> node_input_tensors;
-  for (const auto &input : op->getInputs())
+  for (const Operation::Output *input : op->getInputs())
   {
-    size_t idx = input.getProducer()->getIndex();
-    const Operation *prev_op = input.getProducer()->getNode();
+    size_t idx = input->getIndex();
+    const Operation *prev_op = input->getNode();
     assert(_opToDescr.find(prev_op) != _opToDescr.end());
     auto call = dynamic_cast<const CallFunction *>(_opToDescr[prev_op]);
     assert(call);
diff --git a/compiler/nnc/passes/optimizations/CombineTransposes.cpp b/compiler/nnc/passes/optimizations/CombineTransposes.cpp
index 12aab94..2015ac6 100644
--- a/compiler/nnc/passes/optimizations/CombineTransposes.cpp
+++ b/compiler/nnc/passes/optimizations/CombineTransposes.cpp
@@ -76,7 +76,7 @@ nnc::PassData nnc::CombineTransposes::run(nnc::PassData data)
 
     if (!isIdentityTranspose(combined_axis_order))
     {
-      auto new_tr_op = g->create<ops::TransposeOp>(top_transpose->getInput(0)->getProducer(),
+      auto new_tr_op = g->create<ops::TransposeOp>(top_transpose->getInputProducer(0),
                                                    combined_axis_order);
 
       g->replaceNode(bottom_transpose, new_tr_op);
@@ -84,7 +84,7 @@ nnc::PassData nnc::CombineTransposes::run(nnc::PassData data)
     else
     {
       // Connect top input to all outputs of bottom
-      Operation *top = top_transpose->getInput(0)->getProducer()->getNode();
+      Operation *top = top_transpose->getInputProducer(0)->getNode();
       g->replaceNode(bottom_transpose, top);
     }
     deleted_nodes.emplace(bottom_transpose);
diff --git a/compiler/nnc/passes/optimizations/FuseArithmeticOps.cpp b/compiler/nnc/passes/optimizations/FuseArithmeticOps.cpp
index f601d05..1c7185e 100644
--- a/compiler/nnc/passes/optimizations/FuseArithmeticOps.cpp
+++ b/compiler/nnc/passes/optimizations/FuseArithmeticOps.cpp
@@ -45,11 +45,11 @@ using Edge = pair<Operation *, Operation *>;
  * This function used to get 'ConstantOp' with weights of 'AddOp', 'MulOp' or 'Conv2DOp'
  * For each of these ops weights stored in second input node
  */
-ops::ConstantOp *getSecondInputAsConst(const Operation *op)
+ops::ConstantOp *getSecondInputAsConst(Operation *op)
 {
   assert(op->getType() == OpType::add || op->getType() == OpType::mul ||
          op->getType() == OpType::conv2D);
-  return dynamic_cast<ops::ConstantOp *>(op->getInput(1)->getProducer()->getNode());
+  return dynamic_cast<ops::ConstantOp *>(op->getInputProducer(1)->getNode());
 }
 
 // This function finds successive operations of given types, with ConstantOp as second input
@@ -174,7 +174,7 @@ bool fuseSuccessiveOps(Graph *g)
 
     // Create new constant operation and copy first successive operation
    auto new_const_op = mergeConstantOps(g, const1_op, const2_op, edge.second->getType());
-    auto first_op_input = edge.first->getInput(0)->getProducer();
+    auto first_op_input = edge.first->getInputProducer(0);
    auto new_op = g->copyOpWithInputs(edge.first, {first_op_input, new_const_op->getOutput(0)});
 
    // Replace second successive operation with new one and remove old nodes
@@ -213,7 +213,7 @@ bool sinkAddThroughMul(Graph *g)
    assert(old_add_const_op && ols_mul_const_op);
 
    // Create new operations
-    auto old_add_input = old_add_op->getInput(0)->getProducer();
+    auto old_add_input = old_add_op->getInputProducer(0);
    auto new_mul_op =
        g->copyOpWithInputs(old_mul_op, {old_add_input, ols_mul_const_op->getOutput(0)});
    auto new_add_const_op = mergeConstantOps(g, old_add_const_op, ols_mul_const_op, OpType::mul);
diff --git a/compiler/nnc/passes/optimizations/OptimizationUtils.cpp b/compiler/nnc/passes/optimizations/OptimizationUtils.cpp
index 706ddca..46dc6b0 100644
--- a/compiler/nnc/passes/optimizations/OptimizationUtils.cpp
+++ b/compiler/nnc/passes/optimizations/OptimizationUtils.cpp
@@ -25,12 +25,12 @@ void swapAdjacent(mir::Graph *g, mir::Operation *top, mir::Operation *bottom)
   assert(top->getNumInputs() == bottom->getNumInputs() && top->getNumInputs() == 1 &&
          top->getNumInputs() == top->getNumOutputs() &&
          top->getNumInputs() == bottom->getNumOutputs() && "incompatible ops");
-  auto &ins = top->getInputs();
+  const auto &ins = top->getInputs();
   std::vector<mir::Operation::Output *> prods;
   prods.reserve(top->getNumInputs());
-  for (auto &in : ins)
+  for (mir::Operation::Output *in : ins)
   {
-    prods.emplace_back(in.getProducer());
+    prods.emplace_back(in);
   }
   mir::Operation *new_bottom = g->copyOpWithInputs(bottom, prods);
   prods.clear();
diff --git a/compiler/nnc/passes/optimizations/SinkRelu.cpp b/compiler/nnc/passes/optimizations/SinkRelu.cpp
index 09bdc02..bdad9bf 100644
--- a/compiler/nnc/passes/optimizations/SinkRelu.cpp
+++ b/compiler/nnc/passes/optimizations/SinkRelu.cpp
@@ -58,7 +58,7 @@ PassData SinkRelu::run(PassData data)
     pre_relu.reserve(relus.size());
     for (auto *r : relus)
     {
-      pre_relu.emplace_back(r->getInput(0)->getProducer());
+      pre_relu.emplace_back(r->getInputProducer(0));
     }
     // create replacement nodes
     auto new_concat = g->create<ops::ConcatOp>(pre_relu, concat->getAxis());
diff --git a/compiler/nnc/passes/optimizations/SinkTranspose.cpp b/compiler/nnc/passes/optimizations/SinkTranspose.cpp
index b1c7a0d..73fcfa5 100644
--- a/compiler/nnc/passes/optimizations/SinkTranspose.cpp
+++ b/compiler/nnc/passes/optimizations/SinkTranspose.cpp
@@ -63,7 +63,7 @@ PassData SinkTranspose::run(PassData data)
     prev_trans.reserve(trs.size());
     for (auto transpose : trs)
     {
-      prev_trans.emplace_back(transpose->getInput(0)->getProducer());
+      prev_trans.emplace_back(transpose->getInputProducer(0));
     }
     auto new_concat = g->create<ops::ConcatOp>(prev_trans, axis_order[concat->getAxis()]);
     auto new_transpose = g->create<ops::TransposeOp>(new_concat->getOutput(0), axis_order);
diff --git a/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp b/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp
index 1d88203..4943a09 100644
--- a/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp
+++ b/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp
@@ -109,7 +109,7 @@ void DataFormatSwitcher::switchAvgPool2D(mir::ops::AvgPool2DOp *op)
   if (op->getDataFormat() == _target_format)
     return;
 
-  auto *input = op->getInput(0)->getProducer();
+  auto *input = op->getInputProducer(0);
 
   mir::AvgPool2DOpAttributes attributes(op->getAttributes());
   attributes.data_format = _target_format;
@@ -129,8 +129,8 @@ void DataFormatSwitcher::switchConv2D(mir::ops::Conv2DOp *op)
     return;
 
   assert(op->getNumInputs() == 2);
-  auto *input = op->getInput(0)->getProducer();
-  auto *kernel = op->getInput(1)->getProducer();
+  auto *input = op->getInputProducer(0);
+  auto *kernel = op->getInputProducer(1);
 
   mir::Conv2DOpAttributes attributes(op->getAttributes());
   attributes.data_format = _target_format;
@@ -150,8 +150,8 @@ void DataFormatSwitcher::switchDeConv2D(mir::ops::DeConv2DOp *op)
     return;
 
   assert(op->getNumInputs() == 2);
-  auto *input = op->getInput(0)->getProducer();
-  auto *kernel = op->getInput(1)->getProducer();
+  auto *input = op->getInputProducer(0);
+  auto *kernel = op->getInputProducer(1);
 
   auto *trans_in = insertTransposeBefore(input);
 
@@ -183,8 +183,8 @@ void DataFormatSwitcher::switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op)
     return;
 
   assert(op->getNumInputs() == 2);
-  auto *input = op->getInput(0)->getProducer();
-  auto *kernel = op->getInput(1)->getProducer();
+  auto *input = op->getInputProducer(0);
+  auto *kernel = op->getInputProducer(1);
 
   mir::Conv2DOpAttributes attributes(op->getAttributes());
   attributes.data_format = _target_format;
@@ -203,7 +203,7 @@ void DataFormatSwitcher::switchMaxPool2D(mir::ops::MaxPool2DOp *op)
   if (op->getDataFormat() == _target_format)
     return;
 
-  auto *input = op->getInput(0)->getProducer();
+  auto *input = op->getInputProducer(0);
 
   mir::MaxPool2DOpAttributes attributes(op->getAttributes());
   attributes.data_format = _target_format;
diff --git a/compiler/nnc/unittests/optimizations/CombineTransposes.cpp b/compiler/nnc/unittests/optimizations/CombineTransposes.cpp
index c6b1464..08f6cb8 100644
--- a/compiler/nnc/unittests/optimizations/CombineTransposes.cpp
+++ b/compiler/nnc/unittests/optimizations/CombineTransposes.cpp
@@ -111,8 +111,8 @@ TEST(OptPass, combineTransposesBush)
   pass.run(&g);
   g.accept(&d);
   ASSERT_EQ("i_0.b_4.", ss.str());
-  ASSERT_EQ(elw->getInput(0)->getProducer()->getNode()->getType(), mir::Operation::Type::input);
-  ASSERT_EQ(elw->getInput(1)->getProducer()->getNode()->getType(), mir::Operation::Type::input);
+  ASSERT_EQ(elw->getInputProducer(0)->getNode()->getType(), mir::Operation::Type::input);
+  ASSERT_EQ(elw->getInputProducer(1)->getNode()->getType(), mir::Operation::Type::input);
 }
 
 TEST(OptPass, combineTransposesOpOrder)
@@ -139,10 +139,10 @@ TEST(OptPass, combineTransposesOpOrder)
   int n2 = in2->getId();
   CombineTransposes pass;
   pass.run(&g);
-  ASSERT_EQ(g.getOutputs()[0]->getInput(0)->getProducer()->getNode()->getType(),
+  ASSERT_EQ(g.getOutputs()[0]->getInputProducer(0)->getNode()->getType(),
             mir::Operation::Type::add);
   // Order is preserved
-  ASSERT_EQ(n1, elw->getInput(0)->getProducer()->getNode()->getId());
-  ASSERT_EQ(n2, elw->getInput(1)->getProducer()->getNode()->getId());
+  ASSERT_EQ(n1, elw->getInputProducer(0)->getNode()->getId());
+  ASSERT_EQ(n2, elw->getInputProducer(1)->getNode()->getId());
 }
 } // unnamed namespace
diff --git a/compiler/nnc/unittests/optimizations/SinkTest.cpp b/compiler/nnc/unittests/optimizations/SinkTest.cpp
index 5687f3e..e394651 100644
--- a/compiler/nnc/unittests/optimizations/SinkTest.cpp
+++ b/compiler/nnc/unittests/optimizations/SinkTest.cpp
@@ -37,7 +37,7 @@ namespace
 Operation *getPrev(Operation *op)
 {
   assert(op->getNumInputs() == 1);
-  return op->getInput(0)->getProducer()->getNode();
+  return op->getInputProducer(0)->getNode();
 }
 
 Operation *getNext(Operation *op)
diff --git a/compiler/nnc/unittests/soft_backend/CPPOperations.cpp b/compiler/nnc/unittests/soft_backend/CPPOperations.cpp
index 8e0d05e..c161941 100644
--- a/compiler/nnc/unittests/soft_backend/CPPOperations.cpp
+++ b/compiler/nnc/unittests/soft_backend/CPPOperations.cpp
@@ -231,7 +231,7 @@ mir::TensorVariant getReferenceTensor(mir::Graph &g,
     interpreter.setInput("x" + to_string(i), *input_ntensors[i]);
   g.accept(&interpreter);
   const auto *output_op = g.getOutputs()[0];
-  return interpreter.getResult(output_op->getInput(0)->getProducer());
+  return interpreter.getResult(output_op->getInputProducer(0));
 };
 
 /**
diff --git a/compiler/nnc/unittests/transformations/Switcher.cpp b/compiler/nnc/unittests/transformations/Switcher.cpp
index 7be4835..7860959 100644
--- a/compiler/nnc/unittests/transformations/Switcher.cpp
+++ b/compiler/nnc/unittests/transformations/Switcher.cpp
@@ -44,10 +44,10 @@ TEST(TRANSFORMATIONS, Switcher_Conv2D_NCHW2NHWC)
 
   switcher.run(&g);
 
-  auto *trans_out = output->getInput(0)->getProducer()->getNode();
-  auto *conv_ = trans_out->getInput(0)->getProducer()->getNode();
-  auto *trans_in = conv_->getInput(0)->getProducer()->getNode();
-  auto *input_ = trans_in->getInput(0)->getProducer()->getNode();
+  auto *trans_out = output->getInputProducer(0)->getNode();
+  auto *conv_ = trans_out->getInputProducer(0)->getNode();
+  auto *trans_in = conv_->getInputProducer(0)->getNode();
+  auto *input_ = trans_in->getInputProducer(0)->getNode();
 
   ASSERT_EQ(trans_out->getType(), mir::Operation::Type::transpose);
   ASSERT_NE(conv_, conv);
@@ -89,10 +89,10 @@ TEST(TRANSFORMATIONS, Switcher_DWConv2D_NHWC2NCHW)
 
   switcher.run(&g);
 
-  auto *trans_out = output->getInput(0)->getProducer()->getNode();
-  auto *dw_conv_ = trans_out->getInput(0)->getProducer()->getNode();
-  auto *trans_in = dw_conv_->getInput(0)->getProducer()->getNode();
-  auto *input_ = trans_in->getInput(0)->getProducer()->getNode();
+  auto *trans_out = output->getInputProducer(0)->getNode();
+  auto *dw_conv_ = trans_out->getInputProducer(0)->getNode();
+  auto *trans_in = dw_conv_->getInputProducer(0)->getNode();
+  auto *input_ = trans_in->getInputProducer(0)->getNode();
 
   ASSERT_EQ(trans_out->getType(), mir::Operation::Type::transpose);
   ASSERT_NE(dw_conv_, dw_conv);
@@ -135,10 +135,10 @@ TEST(TRANSFORMATIONS, Switcher_DeConv2D_NHWC2NCHW)
 
   switcher.run(&g);
 
-  auto *trans_out = output->getInput(0)->getProducer()->getNode();
-  auto *deconv_ = trans_out->getInput(0)->getProducer()->getNode();
-  auto *trans_in = deconv_->getInput(0)->getProducer()->getNode();
-  auto *input_ = trans_in->getInput(0)->getProducer()->getNode();
+  auto *trans_out = output->getInputProducer(0)->getNode();
+  auto *deconv_ = trans_out->getInputProducer(0)->getNode();
+  auto *trans_in = deconv_->getInputProducer(0)->getNode();
+  auto *input_ = trans_in->getInputProducer(0)->getNode();
 
   ASSERT_EQ(trans_out->getType(), mir::Operation::Type::transpose);
   ASSERT_NE(deconv_, deconv);
@@ -179,10 +179,10 @@ TEST(TRANSFORMATIONS, Switcher_AvgPool2D_NHWC2NCHW)
 
   switcher.run(&g);
 
-  auto *trans_out = output->getInput(0)->getProducer()->getNode();
-  auto *avg_pool_ = trans_out->getInput(0)->getProducer()->getNode();
-  auto *trans_in = avg_pool_->getInput(0)->getProducer()->getNode();
-  auto *input_ = trans_in->getInput(0)->getProducer()->getNode();
+  auto *trans_out = output->getInputProducer(0)->getNode();
+  auto *avg_pool_ = trans_out->getInputProducer(0)->getNode();
+  auto *trans_in = avg_pool_->getInputProducer(0)->getNode();
+  auto *input_ = trans_in->getInputProducer(0)->getNode();
 
   ASSERT_EQ(trans_out->getType(), mir::Operation::Type::transpose);
   ASSERT_NE(avg_pool_, avg_pool);
@@ -227,10 +227,10 @@ TEST(TRANSFORMATIONS, Switcher_MaxPool2D_NCHW2NHWC)
 
   switcher.run(&g);
 
-  auto *trans_out = output->getInput(0)->getProducer()->getNode();
-  auto *max_pool_ = trans_out->getInput(0)->getProducer()->getNode();
-  auto *trans_in = max_pool_->getInput(0)->getProducer()->getNode();
-  auto *input_ = trans_in->getInput(0)->getProducer()->getNode();
+  auto *trans_out = output->getInputProducer(0)->getNode();
+  auto *max_pool_ = trans_out->getInputProducer(0)->getNode();
+  auto *trans_in = max_pool_->getInputProducer(0)->getNode();
+  auto *input_ = trans_in->getInputProducer(0)->getNode();
 
   ASSERT_EQ(trans_out->getType(), mir::Operation::Type::transpose);
   ASSERT_NE(max_pool_, max_pool);
-- 
2.7.4