From: Sergei Barannikov/AI Tools Lab /SRR/Engineer/Samsung Electronics
Date: Mon, 28 Jan 2019 17:32:26 +0000 (+0300)
Subject: [nnc] Prepare the interpreter backend for the future changes related to adding of...
X-Git-Tag: nncc_backup~914
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ef53c37f7504e17de4764fe1fe6706475c653a91;p=platform%2Fcore%2Fml%2Fnnfw.git

[nnc] Prepare the interpreter backend for the future changes related to adding of the Tensor class in the ModelIR (#2934)

* Add `getInputTensors` and `setOutputTensors` methods in place of the `var` method.
* Reformat the code according to coding style.

Signed-off-by: Sergei Barannikov
---

diff --git a/contrib/nnc/include/passes/interpreter/Interpreter.h b/contrib/nnc/include/passes/interpreter/Interpreter.h
index dcc499c..f70d56b 100644
--- a/contrib/nnc/include/passes/interpreter/Interpreter.h
+++ b/contrib/nnc/include/passes/interpreter/Interpreter.h
@@ -17,26 +17,26 @@
 #ifndef _NNC_BACKEND_INTERPRETER_CORE_INTERPRETER_
 #define _NNC_BACKEND_INTERPRETER_CORE_INTERPRETER_
 
-#include
-#include
-#include
-#include
-#include
-
 #include "core/modelIR/Visitor.h"
 #include "core/modelIR/Operation.h"
+#include "core/modelIR/TensorVariant.h"
+#include
+#include
+#include
 
-#include "core/modelIR/Tensor.h"
-
-namespace nnc
-{
-namespace mir
-{
+namespace nnc {
+namespace mir {
 
 class NNInterpreter : public IVisitor {
 public:
   explicit NNInterpreter() = default;
 
+  ~NNInterpreter() override = default;
+
+  void setInput(const std::string& name, const TensorVariant& data);
+
+  TensorVariant getResult(IODescriptor tensor);
+
   void visit(ops::BatchNormOp& op) override;
   void visit(ops::BiasAddOp& op) override;
   void visit(ops::CappedReluOp& op) override;
@@ -69,18 +69,20 @@ public:
   void visit(ops::TanhOp& op) override;
   void visit(ops::TransposeOp& op) override;
 
-  void setInput(const std::string &name, const TensorVariant& data);
-  TensorVariant getResult(IODescriptor tensor);
   void dump(Operation& op, bool all = false);
 
-  ~NNInterpreter() override = default;
-
 private:
-  std::vector<TensorVariant> &var(size_t id);
+  /// @brief Gets the computed inputs for the operation.
+  std::vector<std::reference_wrapper<const TensorVariant>> getInputTensors(const Operation& op);
 
-private:
-  std::map<size_t, std::vector<TensorVariant>> vars;
-  std::unordered_map<std::string, TensorVariant> data;
+  /// @brief Saves the computed outputs for the operation.
+  void setOutputTensors(const Operation& op, std::vector<TensorVariant>&& outputs);
+
+  /// @brief Mapping of graph named inputs to their values.
+  std::unordered_map<std::string, TensorVariant> _inputTensors;
+
+  /// @brief Mapping of operations to their computed results.
+  std::unordered_map<size_t, std::vector<TensorVariant>> _opResults;
 };
 
 } // namespace mir
diff --git a/contrib/nnc/passes/interpreter/Interpreter.cpp b/contrib/nnc/passes/interpreter/Interpreter.cpp
index fe54c50..80286d4 100644
--- a/contrib/nnc/passes/interpreter/Interpreter.cpp
+++ b/contrib/nnc/passes/interpreter/Interpreter.cpp
@@ -78,9 +78,19 @@ namespace nnc {
 
 using namespace nnc::mir;
 
-std::vector<TensorVariant> &NNInterpreter::var(size_t id) { return vars[id]; }
+std::vector<std::reference_wrapper<const TensorVariant>>
+NNInterpreter::getInputTensors(const Operation& op) {
+  std::vector<std::reference_wrapper<const TensorVariant>> tensors;
+  for (IODescriptor ir_tensor : op.getPrevNodes())
+    tensors.emplace_back(_opResults.at(ir_tensor.op->getId()).at(ir_tensor.index));
+  return tensors;
+}
+
+void NNInterpreter::setOutputTensors(const Operation& op, std::vector<TensorVariant>&& outputs) {
+  _opResults.emplace(op.getId(), std::move(outputs));
+}
 
-static void dumpIndex (Index ndx) {
+static void dumpIndex(Index ndx) {
   for (int i = 0; i < ndx.rank(); i++) {
     std::cout << (i ? "," : "(") << ndx.at(i);
   }
@@ -88,19 +98,19 @@ static void dumpIndex (Index ndx) {
 }
 
 #if(0)
-  #define DUMP(x, y) dump(x, (y))
+#define DUMP(x, y) dump(x, (y))
 #else
-  #define DUMP(x, y)
+#define DUMP(x, y)
 #endif
 
 void NNInterpreter::dump(Operation& op, bool all) {
   // TODO: in theory there could be several outputs from the given 'op'.
-  TensorVariant tensor = var(op.getId())[0];
+  TensorVariant tensor = _opResults.at(op.getId()).at(0);
   auto shape = tensor.getShape();
   std::cout << "Tensor '" <<
-    (op.getNextNodes().size() ? op.getNextNodes()[0]->getName() : "output") <<
-    "' DType = " << (int)tensor.getDataType() << ", ElementSize = " <<
-    tensor.getElementSize() << ", Shape" << shape;
+            (op.getNextNodes().size() ? op.getNextNodes()[0]->getName() : "output") <<
+            "' DType = " << (int)tensor.getDataType() << ", ElementSize = " <<
+            tensor.getElementSize() << ", Shape" << shape;
   std::cout << " ElementsNumber " << shape.numElements() << "\n";
   static bool do_it = false;
   if (do_it || all) {
@@ -118,208 +128,183 @@ void NNInterpreter::dump(Operation& op, bool all) {
   }
 }
 
-void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) {
-// TODO: our tests are failed with fe enable exception
-// feenableexcept(FE_INVALID | FE_OVERFLOW);
-//                |
-//                FE_DIVBYZERO |
-//                FE_OVERFLOW |
-//                FE_UNDERFLOW);
-// feenableexcept(FE_ALL_EXCEPT);
+void NNInterpreter::setInput(const std::string& name, const TensorVariant& t) {
+  _inputTensors.emplace(name, t);
+}
 
-  data.emplace(name, t);
+TensorVariant NNInterpreter::getResult(IODescriptor tensor) {
+  return _opResults.at(tensor.op->getId()).at(tensor.index);
 }
 
 void NNInterpreter::visit(ops::InputOp& op) {
-  (void)op;
-  auto it = data.find(op.getName());
-  if( it == data.end() )
-  {
+  auto it = _inputTensors.find(op.getName());
+  if (it == _inputTensors.end())
     throw std::runtime_error("Can't find data for node \"" + op.getName() + ". Input data was not set correctly?");
-  }
-  var(op.getId()) = {it->second};
+  setOutputTensors(op, {it->second});
 }
 
 void NNInterpreter::visit(ops::ConstantOp& op) {
-  assert(data.find(op.getName()) == data.end());
-  var(op.getId()) = {op.getValue()};
-}
-
-TensorVariant NNInterpreter::getResult(IODescriptor tensor) {
-  return vars.at(tensor.op->getId()).at(tensor.index);
+  assert(_inputTensors.find(op.getName()) == _inputTensors.end());
+  setOutputTensors(op, {op.getValue()});
 }
 
 void NNInterpreter::visit(ops::ConcatOp& op) {
-  auto &operands = op.getPrevNodes();
-  std::vector<TensorVariant> ins;
-  for (auto &in : operands)
-  {
-    ins.push_back(var(in.op->getId())[in.index]);
-  }
-  var(op.getId()) = Concat(ins, op.getOutputShape(0), op.getAxis())();
+  auto inputs = getInputTensors(op);
+  auto outputs = Concat(inputs, op.getOutputShape(0), op.getAxis())();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::Conv2DOp& op) {
-  auto input = op.getPrevNodes()[0];
-  auto kernel = op.getPrevNodes()[1];
-  auto input_tensor = var(input.op->getId())[input.index];
-  auto kernel_tensor = var(kernel.op->getId())[kernel.index];
-  var(op.getId()) = Conv2D(input_tensor, kernel_tensor, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Conv2D(inputs[0], inputs[1], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, true);
 }
 
 void NNInterpreter::visit(ops::ReshapeOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto input = var(operand.op->getId())[operand.index];
-  var(op.getId()) = Reshape(input, op.getOutputShape(0))();
+  auto inputs = getInputTensors(op);
+  auto outputs = Reshape(inputs[0], op.getOutputShape(0))();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ReluOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  Tensor input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(
-    op.getOutputShape(0), [&input](const Index &id) { return std::max(input.at(id), 0.0f); })();
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0),
+                      [&input](const Index& id) { return std::max(input.at(id), 0.0f); })();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SigmoidOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  Tensor input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(op.getOutputShape(0), [&input](const Index& id) {
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input](const Index& id) {
     return 1.f / (1.f + std::exp(-input.at(id)));
   })();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::SoftmaxOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto input = var(operand.op->getId())[operand.index];
-  var(op.getId()) = Softmax(op.getInputShape(0), input, op.getAxis())();
+  auto inputs = getInputTensors(op);
+  auto outputs = Softmax(op.getInputShape(0), inputs[0], op.getAxis())();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::PoolOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto input = var(operand.op->getId())[operand.index];
-  var(op.getId()) = Pool(input, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Pool(inputs[0], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::FullyConnectedOp& op) {
-  auto operand1 = op.getPrevNodes()[0];
-  auto operand2 = op.getPrevNodes()[1];
-  TensorVariant input1 = var(operand1.op->getId())[operand1.index];
-  TensorVariant input2 = var(operand2.op->getId())[operand2.index];
-  var(op.getId()) = FullyConnected(input1, input2, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = FullyConnected(inputs[0], inputs[1], op)();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::GemmOp& op) {
-  auto operand_a = op.getPrevNodes()[0];
-  auto operand_b = op.getPrevNodes()[1];
-  auto operand_c = op.getPrevNodes()[2];
-  const TensorVariant input_a = var(operand_a.op->getId())[operand_a.index];
-  const TensorVariant input_b = var(operand_b.op->getId())[operand_b.index];
-  const TensorVariant input_c = var(operand_c.op->getId())[operand_c.index];
-  var(op.getId()) = Gemm(input_a, input_b, input_c, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Gemm(inputs[0], inputs[1], inputs[2], op)();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::CappedReluOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  Tensor input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(op.getOutputShape(0), [&input, &op](const Index &id) {
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input, &op](const Index& id) {
    return std::min(std::max(input.at(id), 0.0f), op.getCap());
   })();
+  setOutputTensors(op, std::move(outputs));
 }
 
-void NNInterpreter::visit(ops::DepthwiseConv2DOp& op){
-  auto input = op.getPrevNodes()[0];
-  auto kernel = op.getPrevNodes()[1];
-  auto input_tensor(var(input.op->getId())[input.index]);
-  auto kernel_tensor(var(kernel.op->getId())[kernel.index]);
-  var(op.getId()) = DepthwiseConv2D(input_tensor, kernel_tensor, op)();
+void NNInterpreter::visit(ops::DepthwiseConv2DOp& op) {
+  auto inputs = getInputTensors(op);
+  auto outputs = DepthwiseConv2D(inputs[0], inputs[1], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, true);
 }
 
 void NNInterpreter::visit(ops::BiasAddOp& op) {
-  auto operand1 = op.getPrevNodes()[0];
-  auto operand2 = op.getPrevNodes()[1];
-  auto input1 = var(operand1.op->getId())[operand1.index];
-  auto input2 = var(operand2.op->getId())[operand2.index];
-  var(op.getId()) = BiasAdd(input1, input2)();
+  auto inputs = getInputTensors(op);
+  auto outputs = BiasAdd(inputs[0], inputs[1])();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::BatchNormOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  TensorVariant input(var(operand.op->getId())[operand.index]);
-  // TODO implement this
-  var(op.getId()) = BatchNorm(input, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = BatchNorm(inputs[0], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ScaleOp& op) {
-  auto operand1 = op.getPrevNodes()[0];
-  auto operand2 = op.getPrevNodes()[1];
-  auto input1 = var(operand1.op->getId())[operand1.index];
-  auto input2 = var(operand2.op->getId())[operand2.index];
-  var(op.getId()) = Scale(input1, input2)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Scale(inputs[0], inputs[1])();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
-
 void NNInterpreter::visit(ops::SliceOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto input = Tensor(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(op.getOutputShape(0), [&input, &op](const Index& id) {
+  auto inputs = getInputTensors(op);
+  auto input = Tensor(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input, &op](const Index& id) {
     Index idx = nnc::shift(id, op.getStarts());
     return input.at(idx);
   })();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::DropoutOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  TensorVariant input(var(operand.op->getId())[operand.index]);
+  auto inputs = getInputTensors(op);
+  TensorVariant input(inputs[0]);
   // TODO implement this
-  var(op.getId()) = Dropout(input, op)();
+  auto outputs = Dropout(input, op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::TanhOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  Tensor input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(op.getOutputShape(0), [&input](const Index &id) {
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input](const Index& id) {
     return std::tanh(input.at(id));
   })();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::ElementwiseOp& op) {
-  auto operands = op.getPrevNodes();
+  auto inputs = getInputTensors(op);
+
   std::vector<Tensor<float>> ins;
   // Reserve space for tensor variants to avoid reference invalidation when pushing into vector
   std::vector<TensorVariant> broadcasted{};
   broadcasted.reserve(op.getNumInputs());
-  for (auto &in : operands) {
-    auto& tmp = var(in.op->getId())[in.index];
+  for (auto in : inputs) {
     if (op.getBroadcast()) {
-      broadcasted.emplace_back(tmp, op.getOutputShape(0));
+      broadcasted.emplace_back(in, op.getOutputShape(0));
       ins.emplace_back(broadcasted.back());
     } else {
-      ins.emplace_back(tmp);
+      ins.emplace_back(in);
     }
   }
-  float (*func)(float,float); // Another dirty hack
+  float (* func)(float, float); // Another dirty hack
   switch (op.getOpType()) {
     case ops::ElementwiseOp::OpType::add:
      func = [](float a, float b) { return a + b; };
      break;
    case ops::ElementwiseOp::OpType::mul:
-      func = [](float a, float b) { return a * b;};
+      func = [](float a, float b) { return a * b; };
      break;
    case ops::ElementwiseOp::OpType::max:
-      func = [](float a, float b) { return std::max(a,b);};
+      func = [](float a, float b) { return std::max(a, b); };
      break;
    case ops::ElementwiseOp::OpType::div:
      func = [](float a, float b) { return a / b; };
@@ -330,136 +315,119 @@ void NNInterpreter::visit(ops::ElementwiseOp& op) {
     default:
       assert(false && "Unsupported Optype");
   }
-  var(op.getId()) = Fill(
-    op.getOutputShape(0),
-    [&func, &ins](const Index& id) {
-      float acc = ins[0].at(id);
-      for (size_t i = 1; i < ins.size(); i++)
-        acc = func(acc, ins[i].at(id));
-      return acc;
-    })();
+  auto outputs = Fill(op.getOutputShape(0), [&func, &ins](const Index& id) {
+    float acc = ins[0].at(id);
+    for (size_t i = 1; i < ins.size(); i++)
+      acc = func(acc, ins[i].at(id));
+    return acc;
+  })();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::DeConv2DOp& op) {
-  auto input = op.getPrevNodes()[0];
-  auto kernel = op.getPrevNodes()[1];
-  auto input_tensor = var(input.op->getId())[input.index];
-  auto kernel_tensor = var(kernel.op->getId())[kernel.index];
-  var(op.getId()) = DeConv2D(input_tensor, kernel_tensor, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = DeConv2D(inputs[0], inputs[1], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::EluOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  Tensor input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(op.getOutputShape(0), [&input, &op](const Index &id) {
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input, &op](const Index& id) {
     if (input.at(id) >= 0)
       return input.at(id);
     else
-      return op.getAlpha()*(expf(input.at(id))-1);
+      return op.getAlpha() * (expf(input.at(id)) - 1);
   })();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SqueezeOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto& input = var(operand.op->getId())[operand.index];
-  //Squeeze is just a special case of reshape
-  var(op.getId()) = Reshape(input, op.getOutputShape(0))();
+  auto inputs = getInputTensors(op);
+  // Squeeze is just a special case of reshape.
+  auto outputs = Reshape(inputs[0], op.getOutputShape(0))();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::PadOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto& input = var(operand.op->getId())[operand.index];
-  var(op.getId()) = Pad(input, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Pad(inputs[0], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SqrtOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto input = Tensor(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(op.getOutputShape(0), [&input](const Index id) {
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input](const Index id) {
     return sqrt(input.at(id));
   })();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::ResizeOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  Tensor input(var(operand.op->getId())[operand.index]);
-  switch (op.getMode()) {
-    case ops::ResizeOp::ResizeMethod::nearestNeighbor: {
-      auto scales = op.getScales();
-      var(op.getId()) = Fill(op.getOutputShape(0), [&scales, &input](const Index& id) {
-        Index in_idx;
-        in_idx.resize(4);
-        for (int i = 0; i < input.getShape().rank(); i++) {
-          in_idx.at(i) = static_cast<int>(floorf(id.at(i) / scales[i]));
-        }
-        return input.at(in_idx);
-      })();
-      break;
+  auto inputs = getInputTensors(op);
+  Tensor input(inputs[0]);
+  assert(op.getMode() == ops::ResizeOp::ResizeMethod::nearestNeighbor);
+  auto scales = op.getScales();
+  auto outputs = Fill(op.getOutputShape(0), [&scales, &input](const Index& id) {
+    Index in_idx;
+    in_idx.resize(4);
+    for (int i = 0; i < input.getShape().rank(); i++) {
+      in_idx.at(i) = static_cast<int>(floorf(id.at(i) / scales[i]));
     }
-    default:
-      assert(false && "Not supported Optype");
-  }
+    return input.at(in_idx);
+  })();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ReduceFOp& op) {
+  auto inputs = getInputTensors(op);
+
+  assert(op.getFuncType() == ops::ReduceFOp::FuncType::mean);
   // should always be an integer in a float
-  const float reduction_area =
-    static_cast<float>(op.getInputShape(0).numElements() / op.getOutputShape(0).numElements());
-
-  auto operand = op.getPrevNodes()[0];
-  auto& input = var(operand.op->getId())[operand.index];
-
-  std::function<float(float, float)> func;
-  switch (op.getFuncType()) {
-    case ops::ReduceFOp::FuncType::mean: {
-      func = [](float running_sum, float item) { return running_sum + item; };
-      var(op.getId()) = ReduceN(op.getInputShape(0),
-                                op.getOutputShape(0), input, op.getReductionDims(), func)();
-      Tensor out_t = Tensor(var(op.getId())[0]); // for numerical stability
-      var(op.getId()) = Fill(op.getOutputShape(0),
-                             [&out_t, reduction_area](const Index& id) {
-                               return out_t.at(id) / reduction_area;
-                             })();
-    }
-    break;
-    default:
-      assert(false && "Not Implemented");
-  }
+  const float reduction_area = op.getInputShape(0).numElements() /
+                               op.getOutputShape(0).numElements();
+
+  auto tmp = ReduceN(op.getInputShape(0), op.getOutputShape(0), inputs[0],
+                     op.getReductionDims(),
+                     [](float running_sum, float item) { return running_sum + item; })();
+  Tensor out_t(tmp[0]); // for numerical stability
+  auto outputs = Fill(op.getOutputShape(0), [&out_t, reduction_area](const Index& id) {
+    return out_t.at(id) / reduction_area;
+  })();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::TransposeOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  auto& input = var(operand.op->getId())[operand.index];
-  var(op.getId()) = Transpose(input, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Transpose(inputs[0], op)();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::GatherOp& op) {
-  auto data_descr = op.getPrevNodes()[0];
-  auto indices_descr = op.getPrevNodes()[1];
-  const auto& data = var(data_descr.op->getId())[data_descr.index];
-  const auto& indices = var(indices_descr.op->getId())[indices_descr.index];
-  var(op.getId()) = Gather(data, indices, op)();
+  auto inputs = getInputTensors(op);
+  auto outputs = Gather(inputs[0], inputs[1], op)();
+  setOutputTensors(op, std::move(outputs));
 }
 
 void NNInterpreter::visit(ops::LeakyReluOp& op) {
-  auto operand = op.getPrevNodes()[0];
+  auto inputs = getInputTensors(op);
   float alpha = op.getAlpha();
-  Tensor input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = Fill(
-    op.getOutputShape(0), [&input, alpha](const Index& id) {
-      float val = input.at(id);
-      return val > 0.0f ? val : val * alpha;
-    })();
-
+  Tensor input(inputs[0]);
+  auto outputs = Fill(op.getOutputShape(0), [&input, alpha](const Index& id) {
+    float val = input.at(id);
+    return val > 0.0f ? val : val * alpha;
+  })();
+  setOutputTensors(op, std::move(outputs));
   DUMP(op, false);
 }
diff --git a/contrib/nnc/passes/interpreter/ops/Concat.h b/contrib/nnc/passes/interpreter/ops/Concat.h
index 998d363..d87d088 100644
--- a/contrib/nnc/passes/interpreter/ops/Concat.h
+++ b/contrib/nnc/passes/interpreter/ops/Concat.h
@@ -25,16 +25,15 @@ namespace nnc
 template <typename T> class Concat : public Fill<T> {
 public:
-  explicit Concat(const std::vector<mir::TensorVariant> &inputs, const mir::Shape &outputShape,
-                  int32_t axis)
-    : Fill<T>(outputShape, getSingleFunction(inputs, axis))
-  {
-  }
+  Concat(const std::vector<std::reference_wrapper<const mir::TensorVariant>>& inputs,
+         const mir::Shape& outputShape,
+         int32_t axis)
+      : Fill<T>(outputShape, getSingleFunction(inputs, axis)) {}
 
 private:
-  const std::function<T(const mir::Index&)> getSingleFunction(const std::vector<mir::TensorVariant> &inputs,
-                                                              int32_t axis)
-  {
+  const std::function<T(const mir::Index&)>
+  getSingleFunction(const std::vector<std::reference_wrapper<const mir::TensorVariant>>& inputs,
+                    int32_t axis) {
     std::vector<Tensor<T>> inputAccessors;
     for (auto &in : inputs) {