From fcabe71517a6c7451573b05ffe717e3facb5333d Mon Sep 17 00:00:00 2001 From: Andrey Tishchenko/AI Tools Lab /SRR/Staff Engineer/Samsung Electronics Date: Fri, 21 Dec 2018 13:45:26 +0300 Subject: [PATCH] [nnc] The first ONNX model resnet50 works on NNC interpreter (#2718) Several operators were fixed: BatchNormalization, Reshape, Gemm and Pooling. NNC can now convert the ONNX resnet50 network, run it in the interpreter, and produce output that matches the reference data. Signed-off-by: Andrew V. Tischenko a.tischenko@partner.samsung.com --- contrib/nnc/core/modelIR/operations/GemmOp.cpp | 49 +---- contrib/nnc/core/modelIR/operations/PoolOp.cpp | 4 + contrib/nnc/include/core/modelIR/Graph.h | 5 - .../nnc/include/core/modelIR/operations/GemmOp.h | 19 +- .../nnc/include/passes/interpreter/Interpreter.h | 4 +- contrib/nnc/passes/interpreter/Interpreter.cpp | 94 ++++++++-- .../nnc/passes/interpreter/interpreter_pass.cpp | 4 +- contrib/nnc/passes/interpreter/ops/Gemm.h | 62 ++++++- .../nnc/passes/onnx_frontend/ONNXImporterImpl.cpp | 81 ++++---- .../nnc/passes/onnx_frontend/ONNXImporterImpl.h | 11 +- contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp | 206 +++++++++++++++------ contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h | 6 +- contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp | 7 +- .../passes/soft_backend/code_snippets/cpp_gemm.def | 34 ++++ 14 files changed, 393 insertions(+), 193 deletions(-) create mode 100644 contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def diff --git a/contrib/nnc/core/modelIR/operations/GemmOp.cpp b/contrib/nnc/core/modelIR/operations/GemmOp.cpp index d0dd275..d629f45 100644 --- a/contrib/nnc/core/modelIR/operations/GemmOp.cpp +++ b/contrib/nnc/core/modelIR/operations/GemmOp.cpp @@ -21,51 +21,18 @@ namespace mir { namespace ops { void GemmOp::inferOutputShapes() { -//Input tensor A: The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero. -//Input tensor B: The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero. -//Input tensor C: The shape of C should be unidirectional broadcastable to (M, N). -//Output tensor Y: shape (M, N) or vector(N) if M == 1. 
- std::vector vector(2); // this vector will be used to create shapes of tensor A, B + auto shape_a = getInputShape(0); + auto shape_b = getInputShape(1); + assert((shape_a.rank() == shape_b.rank()) && (shape_a.rank() == 2)); + assert(shape_a.dim(1) == shape_b.dim(0) && "Multiplicable"); - // Flatten the shape by dim(0) - mir::Shape shape0 = {getInputShape(0).dim(0), - getInputShape(0).numElements() / getInputShape(0).dim(0)}; - assert(shape0.rank() == 2); - Shape shape_a(2); - if (_transA) { - shape_a.dim(0) = shape0.dim(shape0.rank() - 1); - shape_a.dim(1) = shape0.dim(shape0.rank() - 2); - } else { - shape_a.dim(0) = shape0.dim(shape0.rank() - 2); - shape_a.dim(1) = shape0.dim(shape0.rank() - 1); - } - - auto& shape1 = getInputShape(1); - // It must be a matrice - assert(shape1.rank() == 2); - Shape shape_b(2); - - if (_transB) { - shape_b.dim(0) = shape1.dim(shape1.rank() - 1); - shape_b.dim(1) = shape1.dim(shape1.rank() - 2); - } else { - shape_b.dim(0) = shape1.dim(shape1.rank() - 2); - shape_b.dim(1) = shape1.dim(shape1.rank() - 1); - } - - // Number of cols in tensor A must be equal to number of rows in tensor B - assert(shape_a.dim(1) == shape_b.dim(0)); Shape mult_a_b({shape_a.dim(0), shape_b.dim(1)}); - Shape shape_c = getInputShape(2); + auto shape_c = getInputShape(2); + assert((mult_a_b == shape_c) || + (((shape_c.rank() == 1)) && (mult_a_b.dim(0) == 1) && + (mult_a_b.dim(1) == shape_c.dim(0)))); - if (shape_c.rank() == 1){ - assert(mult_a_b.dim(0) == 1); - assert(mult_a_b.dim(1) == shape_c.dim(0)); - } else { - assert(shape_c.rank() == 2); - assert((mult_a_b.dim(0) == shape_c.dim(0)) && (mult_a_b.dim(1) == shape_c.dim(1))); - } setOutputShape(0, mult_a_b); } diff --git a/contrib/nnc/core/modelIR/operations/PoolOp.cpp b/contrib/nnc/core/modelIR/operations/PoolOp.cpp index 742019b..318dd47 100644 --- a/contrib/nnc/core/modelIR/operations/PoolOp.cpp +++ b/contrib/nnc/core/modelIR/operations/PoolOp.cpp @@ -61,6 +61,10 @@ void PoolOp::inferOutputShapes() { assert(false); } + for (int i = 0; i < output_shape.rank(); i++) { + assert(output_shape.dim(i) >= 0); + } + setOutputShape(0, output_shape); } diff --git a/contrib/nnc/include/core/modelIR/Graph.h b/contrib/nnc/include/core/modelIR/Graph.h index 3456c19..58b8ea9 100644 --- a/contrib/nnc/include/core/modelIR/Graph.h +++ b/contrib/nnc/include/core/modelIR/Graph.h @@ -89,10 +89,6 @@ class Graph { */ ops::VariableOp* replaceWithInputNode(const Operation* op); - void setConstants(std::set consts) { - _constants = consts; - } - /** * @brief Change graph inputs to nodes with names in newInputs * @param new_inputs names of nodes to be made into input nodes @@ -123,7 +119,6 @@ class Graph { _ops.push_back(op); } - void registerOp(ops::ConstantOp* op) { _constants.insert(op); _ops.push_back(op); diff --git a/contrib/nnc/include/core/modelIR/operations/GemmOp.h b/contrib/nnc/include/core/modelIR/operations/GemmOp.h index 4c10af9..bb90c3c 100644 --- a/contrib/nnc/include/core/modelIR/operations/GemmOp.h +++ b/contrib/nnc/include/core/modelIR/operations/GemmOp.h @@ -26,28 +26,13 @@ namespace ops { class GemmOp : public Operation { public: - GemmOp(const IODescriptor a, const IODescriptor b, const IODescriptor c, - bool transA, bool transB, float alpha, float beta) : - Operation(Type::gemmOp, {a, b, c}), - _a(a), _b(b), _c(c), _transA(transA),_transB(transB), _alpha(alpha), _beta(beta) { + GemmOp(IODescriptor arg, const IODescriptor b, const IODescriptor c) : + Operation(Type::gemmOp, {arg, b, c}) { inferOutputShapes(); } - bool getTransA() 
{return _transA;} - bool getTransB() {return _transB;} - float getAlpha() {return _alpha;} - float getBeta() {return _beta;} - private: void inferOutputShapes(); - - const IODescriptor _a; - const IODescriptor _b; - const IODescriptor _c; - bool _transA; - bool _transB; - float _alpha; - float _beta; }; } // namespace ops } // namespace mir diff --git a/contrib/nnc/include/passes/interpreter/Interpreter.h b/contrib/nnc/include/passes/interpreter/Interpreter.h index 267ee28..e6f10d6 100644 --- a/contrib/nnc/include/passes/interpreter/Interpreter.h +++ b/contrib/nnc/include/passes/interpreter/Interpreter.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "core/modelIR/Visitor.h" #include "core/modelIR/Operation.h" @@ -44,12 +45,12 @@ public: void visit(ops::Conv2DOp& op) override; void visit(ops::DeConv2DOp& op) override; void visit(ops::DepthwiseConv2DOp& op) override; + void visit(ops::GemmOp& op) override; void visit(ops::DropoutOp& op) override; void visit(ops::ElementwiseOp& op) override; void visit(ops::EluOp& op) override; void visit(ops::FullyConnectedOp& op) override; void visit(ops::GatherOp& op) override; - void visit(ops::GemmOp& op) override; void visit(ops::PadOp& op) override; void visit(ops::PoolOp& op) override; void visit(ops::ReduceFOp& op) override; @@ -68,6 +69,7 @@ public: void setInput(const std::string &name, const TensorVariant& data); std::vector &getResult(Operation* op); + void dump(Operation& op, bool all = false); ~NNInterpreter() override = default; diff --git a/contrib/nnc/passes/interpreter/Interpreter.cpp b/contrib/nnc/passes/interpreter/Interpreter.cpp index 6f5da6e..8e195d4 100644 --- a/contrib/nnc/passes/interpreter/Interpreter.cpp +++ b/contrib/nnc/passes/interpreter/Interpreter.cpp @@ -66,9 +66,11 @@ #include "ops/Pad.h" #include "ops/common.h" -#include #include #include +#include +#include +#include namespace nnc { @@ -76,32 +78,79 @@ using namespace nnc::mir; std::vector &NNInterpreter::var(size_t id) { return vars[id]; } -void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) { data.emplace(name, t); } +static void dumpIndex (Index ndx) { + for (int i = 0; i < ndx.rank(); i++) { + std::cout << (i ? "," : "(") << ndx.at(i); + } + std::cout << ")\t"; +} + +#if(0) + #define DUMP(x, y) dump(x, (y)) +#else + #define DUMP(x, y) +#endif + +void NNInterpreter::dump(Operation& op, bool all) { + // TODO: in theory there could be several outputs from the given 'op'. + TensorVariant tensor = var(op.getId())[0]; + std::cout << "Tensor '" << op.getName() << "' DType = " << (int)tensor.getDataType() << ", ElementSize = " << tensor.getElementSize() + << ", Shape = {"; + auto shape = tensor.getShape(); + for (int i = 0; i < shape.rank(); i++) { + std::cout << shape.dim(i) << (i == shape.rank() - 1 ? 
"} " : ", "); + } + std::cout << "ElementsNumber " << shape.numElements() << "\n"; + static bool do_it = false; + if (do_it || all) { + auto last_idx = shape.rank() - 1; + for (auto idx : ShapeRange(shape)) { + if (!idx.at(last_idx)) + std::cout << "\n"; + dumpIndex(idx); + if (tensor.getDataType() == DTYPE::FLOAT32) + std::cout << *(float_t*)tensor.at(idx) << "\t"; + else + std::cout << *(int32_t*)tensor.at(idx) << "\t"; + } + std::cout << "\n"; + } +} + +void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) { +// TODO: our tests are failed with fe enable exception +// feenableexcept(FE_INVALID | FE_OVERFLOW); +// | +// FE_DIVBYZERO | +// FE_OVERFLOW | +// FE_UNDERFLOW); +// feenableexcept(FE_ALL_EXCEPT); + + data.emplace(name, t); +} void NNInterpreter::visit(ops::VariableOp& op) { (void)op; auto it = data.find(op.getName()); if( it == data.end() ) { - throw std::runtime_error("Can't find data for node \"" + op.getName() + ". Input data was not set correctly?"); + throw std::runtime_error("Can't find data for node \"" + op.getName() + + ". Input data was not set correctly?"); } var(op.getId()) = {it->second}; } void NNInterpreter::visit(ops::ConstantOp& op) { + assert(data.find(op.getName()) == data.end()); var(op.getId()) = {op.getValue()}; } std::vector &NNInterpreter::getResult(Operation* op) { auto res = vars.find(op->getId()); if (res != vars.end()) - { return res->second; - } else - { - throw std::runtime_error("No such value"); - } + throw std::runtime_error("No such value: " + std::to_string(op->getId())); } void NNInterpreter::visit(ops::ConcatOp& op) { @@ -112,6 +161,7 @@ void NNInterpreter::visit(ops::ConcatOp& op) { ins.push_back(var(in.op->getId())[in.index]); } var(op.getId()) = Concat(ins, op.getOutputShape(0), op.getAxis())(); + DUMP(op, false); } void NNInterpreter::visit(ops::Conv2DOp& op) { @@ -123,6 +173,7 @@ void NNInterpreter::visit(ops::ReshapeOp& op) { auto operand = op.getPrevNodes()[0]; auto input = var(operand.op->getId())[operand.index]; var(op.getId()) = Reshape(input, op.getOutputShape(0))(); + DUMP(op, false); } void NNInterpreter::visit(ops::ReluOp& op) { @@ -130,6 +181,7 @@ void NNInterpreter::visit(ops::ReluOp& op) { Tensor input(var(operand.op->getId())[operand.index]); var(op.getId()) = Fill( op.getOutputShape(0), [&input](const Index &id) { return std::max(input.at(id), 0.0f); })(); + DUMP(op, false); } void NNInterpreter::visit(ops::SigmoidOp& op) { @@ -144,12 +196,14 @@ void NNInterpreter::visit(ops::SoftmaxOp& op) { auto operand = op.getPrevNodes()[0]; auto input = var(operand.op->getId())[operand.index]; var(op.getId()) = Softmax(op.getInputShape(0), input, op.getAxis())(); + DUMP(op, false); } void NNInterpreter::visit(ops::PoolOp& op) { auto operand = op.getPrevNodes()[0]; auto input = var(operand.op->getId())[operand.index]; var(op.getId()) = Pool(input, op)(); + DUMP(op, false); } void NNInterpreter::visit(ops::FullyConnectedOp& op) { @@ -159,9 +213,13 @@ void NNInterpreter::visit(ops::FullyConnectedOp& op) { } void NNInterpreter::visit(ops::GemmOp& op) { - auto operand = op.getPrevNodes()[0]; - TensorVariant input = var(operand.op->getId())[operand.index]; - var(op.getId()) = Gemm(input, op)(); + auto operand_a = op.getPrevNodes()[0]; + auto operand_b = op.getPrevNodes()[1]; + auto operand_c = op.getPrevNodes()[2]; + const TensorVariant input_a = var(operand_a.op->getId())[operand_a.index]; + const TensorVariant input_b = var(operand_b.op->getId())[operand_b.index]; + const TensorVariant input_c = 
var(operand_c.op->getId())[operand_c.index]; + var(op.getId()) = Gemm(input_a, input_b, input_c, op)(); } void NNInterpreter::visit(ops::CappedReluOp& op) { @@ -182,6 +240,7 @@ void NNInterpreter::visit(ops::BiasAddOp& op) { auto operand = op.getPrevNodes()[0]; auto input = var(operand.op->getId())[operand.index]; var(op.getId()) = BiasAdd(input, op.getWeights(), op.getOutputShape(0))(); + DUMP(op, false); } void NNInterpreter::visit(ops::BatchNormOp& op) { @@ -189,13 +248,15 @@ void NNInterpreter::visit(ops::BatchNormOp& op) { TensorVariant input(var(operand.op->getId())[operand.index]); // TODO implement this var(op.getId()) = BatchNorm(input, op)(); + DUMP(op, false); } void NNInterpreter::visit(ops::ScaleOp& op) { auto operand = op.getPrevNodes()[0]; TensorVariant input(var(operand.op->getId())[operand.index]); // TODO implement this - var(op.getId()) = Scale(input, op)(); + var(op.getId()) = Scale(input, op)(); + DUMP(op, false); } @@ -213,6 +274,7 @@ void NNInterpreter::visit(ops::DropoutOp& op) { TensorVariant input(var(operand.op->getId())[operand.index]); // TODO implement this var(op.getId()) = Dropout(input, op)(); + DUMP(op, false); } void NNInterpreter::visit(ops::TanhOp& op) { @@ -267,11 +329,13 @@ void NNInterpreter::visit(ops::ElementwiseOp& op) { acc = func(acc, ins[i].at(id)); return acc; })(); + DUMP(op, false); } void NNInterpreter::visit(ops::DeConv2DOp& op) { auto operand = op.getPrevNodes()[0]; var(op.getId()) = DeConv2D(var(operand.op->getId())[operand.index], op)(); + DUMP(op, false); } void NNInterpreter::visit(ops::EluOp& op) { @@ -283,6 +347,7 @@ void NNInterpreter::visit(ops::EluOp& op) { else return op.getAlpha()*(expf(input.at(id))-1); })(); + DUMP(op, false); } void NNInterpreter::visit(ops::SqueezeOp& op) { @@ -290,12 +355,14 @@ void NNInterpreter::visit(ops::SqueezeOp& op) { auto& input = var(operand.op->getId())[operand.index]; //Squeeze is just a special case of reshape var(op.getId()) = Reshape(input, op.getOutputShape(0))(); + DUMP(op, false); } void NNInterpreter::visit(ops::PadOp& op) { auto operand = op.getPrevNodes()[0]; auto& input = var(operand.op->getId())[operand.index]; var(op.getId()) = Pad(input, op)(); + DUMP(op, false); } void NNInterpreter::visit(ops::SqrtOp& op) { @@ -325,6 +392,7 @@ void NNInterpreter::visit(ops::ResizeOp& op) { default: assert(false && "Not supported Optype"); } + DUMP(op, false); } @@ -352,12 +420,14 @@ void NNInterpreter::visit(ops::ReduceFOp& op) { default: assert(false && "Not Implemented"); } + DUMP(op, false); } void NNInterpreter::visit(ops::TransposeOp& op) { auto operand = op.getPrevNodes()[0]; auto& input = var(operand.op->getId())[operand.index]; var(op.getId()) = Transpose(input, op)(); + DUMP(op, false); } void NNInterpreter::visit(ops::GatherOp& op) { diff --git a/contrib/nnc/passes/interpreter/interpreter_pass.cpp b/contrib/nnc/passes/interpreter/interpreter_pass.cpp index 58365cd..792e54c 100644 --- a/contrib/nnc/passes/interpreter/interpreter_pass.cpp +++ b/contrib/nnc/passes/interpreter/interpreter_pass.cpp @@ -105,7 +105,7 @@ PassData InterpreterPass::run(PassData data) { // Check nodes const auto& outputs = g->collectOutputs(); - + for (auto& out : outputs) { auto outputNode = interpreter.getResult(out); if (outputNode.empty()) { @@ -126,7 +126,7 @@ PassData InterpreterPass::run(PassData data) { #else std::cout << "Result <" << out_node->getName() << "> wasn't saved, due to lack of HDF5" << std::endl; - + #endif // NNC_HDF5_SUPPORTED if (is_several_outs) delete out_data; diff --git 
a/contrib/nnc/passes/interpreter/ops/Gemm.h b/contrib/nnc/passes/interpreter/ops/Gemm.h index e130626..0d94813 100644 --- a/contrib/nnc/passes/interpreter/ops/Gemm.h +++ b/contrib/nnc/passes/interpreter/ops/Gemm.h @@ -17,27 +17,71 @@ #ifndef _NNC_CORE_BACKEND_INTERPRETER_GEMM_ #define _NNC_CORE_BACKEND_INTERPRETER_GEMM_ -#include "core/modelIR/ShapeRange.h" #include "core/modelIR/operations/GemmOp.h" +#include "core/modelIR/ShapeRange.h" +#include "core/modelIR/TensorVariant.h" #include "OperationImpl.h" -namespace nnc -{ +namespace nnc { template -class Gemm : public OperationImpl -{ +class Gemm : public OperationImpl { public: - Gemm(const mir::TensorVariant &_input, const mir::ops::GemmOp &_op) : - _op(_op), _input(_input) {} + Gemm(const mir::TensorVariant& a, const mir::TensorVariant& b, const mir::TensorVariant& c, + mir::ops::GemmOp& op) : _op(op), _tensor_a(a), _tensor_b(b), _tensor_c(c) {} std::vector operator()() override { mir::TensorVariant res = OperationImpl::allocate_tensor(_op.getOutputShape(0)); + mir::Tensor accessor(res); + mir::ShapeRange out_range(res.getShape()); + +// mir::Tensor tensor_b(_b); + auto b_shape = _tensor_b.getShape(); + int32_t b_rank = b_shape.rank(); + + auto& in_shape = _tensor_a.getShape(); + int32_t in_rank = in_shape.rank(); + assert(in_shape.dim(in_rank - 1) == b_shape.dim(b_rank - 2)); + (void)in_rank; + + // First, we have to multply _input(which is alpha*tensorA) and tensor_b + auto len = b_shape.dim(b_rank - 2); + int32_t row; + int32_t col; + for (auto &out_idx : out_range) { + mir::Index t_idx = out_idx; + T& output_element = accessor.at(out_idx); + col = t_idx.at(-1); + row = t_idx.at(-2); + for (int32_t i = 0; i < len; ++i) { + t_idx.at(-1) = i; + T& in = _tensor_a.at(t_idx); + t_idx.at(-1) = col; + t_idx.at(-2) = i; + T& w = _tensor_b.at(t_idx); + t_idx.at(-2) = row; + output_element += w * in; + } + } + + // Now we have to add result of multiplication and (beta*tensor_c) + // We'd like to broadcast Tensor C to the output shape + assert(_op.getOutputShape(0).rank() == 2); + assert((_op.getOutputShape(0).rank() == _op.getInputShape(2).rank()) || + ((_op.getInputShape(2).rank() == 1) && (_op.getOutputShape(0).dim(0) == 1))); + + auto t = mir::TensorVariant (_tensor_c, _op.getOutputShape(0)); + mir::Tensor tensor_c(t); + for (auto idx : mir::ShapeRange(_op.getOutputShape(0))) { + accessor.at(idx) += tensor_c.at(idx); + } return {res}; } private: - const mir::ops::GemmOp& _op; - const mir::Tensor _input; + mir::ops::GemmOp& _op; + mir::Tensor _tensor_a; + mir::Tensor _tensor_b; + const mir::TensorVariant _tensor_c; }; } // namespace nnc diff --git a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp index d3e9e64..d9643c4 100644 --- a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp +++ b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp @@ -21,12 +21,14 @@ #include "core/modelIR/IrDotDumper.h" #include "core/modelIR/operations/ConstantOp.h" -#include "core/modelIR/Operation.h" -#include "core/modelIR/Shape.h" -#include "core/modelIR/TensorVariant.h" #include "core/modelIR/operations/Conv2DOp.h" #include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/TransposeOp.h" #include "core/modelIR/operations/VariableOp.h" +#include "core/modelIR/Operation.h" +#include "core/modelIR/Shape.h" +#include "core/modelIR/TensorUtil.h" +#include "core/modelIR/TensorVariant.h" #include "onnx/onnx_pb.h" #include "onnx/proto_utils.h" #include 
"passes/common_frontend/model_allocation.h" @@ -35,6 +37,8 @@ #include "ONNXImporterImpl.h" #include "ONNXPerfectHash.h" +#include "ONNXOpCreator.h" + namespace nnc { @@ -85,6 +89,7 @@ static mir::TensorVariant createTensor(const onnx::TensorProto* tensor) { size_t element_size; size_t buffer_size; const char* src_data; + auto shape = ShapeHelper::createShape(tensor->dims(), static_cast(tensor->dims_size())); if (tensor->float_data_size() != 0) { element_size = sizeof(float); @@ -100,12 +105,20 @@ static mir::TensorVariant createTensor(const onnx::TensorProto* tensor) { element_size = sizeof(int32_t); buffer_size = tensor->int32_data_size() * element_size; src_data = reinterpret_cast(tensor->int32_data().data()); - throw PassException("WARNING: We don't support int32 tensors yet, investigate\n"); + mir::DTYPE type = mir::DTYPE::INT32; } else if (tensor->int64_data_size() != 0) { - element_size = sizeof(int64_t); + // FIXME: we could lose the data here + type = mir::DTYPE::INT32; + element_size = sizeof(int32_t); buffer_size = tensor->int64_data_size() * element_size; - src_data = reinterpret_cast(tensor->int64_data().data()); - throw PassException("WARNING: We don't support int64 tensors yet, investigate\n"); + + auto src_data64 = reinterpret_cast(tensor->int64_data().data()); + std::shared_ptr shared_buffer (new char[buffer_size], std::default_delete()); + auto dst_data = reinterpret_cast(shared_buffer.get()); + for (int i = 0; i < tensor->int64_data_size(); i++) { + dst_data[i] = (int32_t)src_data64 [i]; + } + return mir::TensorVariant(shape, shared_buffer, type, element_size); } else if (tensor->raw_data().size() != 0) { switch ((tensor->data_type())) { case onnx::TensorProto_DataType_FLOAT: @@ -123,7 +136,6 @@ static mir::TensorVariant createTensor(const onnx::TensorProto* tensor) { std::shared_ptr data(new char[buffer_size], std::default_delete()); memcpy(data.get(), src_data, buffer_size); - auto shape = ShapeHelper::createShape(tensor->dims(), static_cast(tensor->dims_size())); return mir::TensorVariant(shape, data, type, element_size); } @@ -131,11 +143,7 @@ void ONNXImporterImpl::createGraphInputs() { auto& graph = _model->graph(); auto& initializer = graph.initializer(); auto& value_info = graph.value_info(); - auto init_size = graph.initializer_size(); - auto val_size = graph.value_info_size(); - auto inp_size = graph.input_size(); std::map onnx_tensors; - std::set constants; // Collect all initializers of the given graph for (int i = 0; i < graph.initializer_size(); i++) { @@ -153,10 +161,10 @@ void ONNXImporterImpl::createGraphInputs() { _inputTensors.insert(std::make_pair(name, createTensor(onnx_tensor))); auto constant = _graph->create(name, _inputTensors.at(name)); _tensorNameToPrevMirOp[name] = constant; - constants.insert(constant); } else { - // We're dealing with graph input + // We're dealing with graph input (assuming the picture only) auto onnx_input_shape = input.type().tensor_type().shape(); + assert(onnx_input_shape.dim_size() == 4); mir::Shape shape(4); for (int i = 0; i < onnx_input_shape.dim_size(); i++) { assert(onnx_input_shape.dim(i).has_dim_value()); @@ -167,31 +175,28 @@ void ONNXImporterImpl::createGraphInputs() { _tensorNameToPrevMirOp[name] = node; } } - if (!constants.empty()) - _graph->setConstants(constants); } -static void dumpShape(const mir::Shape& shape) { - std::cout << "{"; - for (int i = 0; i < shape.rank(); i++) { - std::cout << shape.dim(i) << (i == shape.rank() - 1 ? 
"} " : ", "); - } -} - -void ONNXImporterImpl::dump(const std::vector& ops, const onnx::NodeProto& onnx_node) { +void ONNXImporterImpl::dump(const std::vector& inputs, + const std::vector& ops, + const onnx::NodeProto& onnx_node) { for (auto op : ops) { - std::cout << onnx_node.op_type() << " '" << op->getName() << "' Input Shapes: "; - for (int i = 0; i < op->getNumInputs() ; i++) { - dumpShape(op->getInputShape(i)); - } - std::cout << " Output Shapes: "; - for (int i = 0; i < op->getNumOutputs() ; i++) { - dumpShape(op->getOutputShape(i)); + std::cout << onnx_node.op_type() << " '" << op->getName() << "'"; + if (inputs[0]->getNumInputs() > 0) { + std::cout << "Input Shape: "; + dumpShape(inputs[0]->getOutputShape(0)); } + std::cout << " Output Shape: "; + dumpShape(op->getOutputShape(0)); auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(onnx_node.op_type().c_str(), onnx_node.op_type().size()); switch (onnx_op_type->opCode) { case ONNXOpCode::opConv: { auto *conv = dynamic_cast(op); + if (conv == nullptr) { + assert(dynamic_cast(op) != nullptr); + conv = dynamic_cast(op->getPrevNodes()[0].op); + } + assert(conv); std::cout << " Weights tensor shape "; dumpShape(conv->getKernel().getShape()); std::cout << " Strides "; @@ -204,6 +209,11 @@ void ONNXImporterImpl::dump(const std::vector& ops, const onnx: case ONNXOpCode::opAveragePool: case ONNXOpCode::opMaxPool: { auto *pool = dynamic_cast(op); + if (pool == nullptr) { + assert(dynamic_cast(op) != nullptr); + pool = dynamic_cast(op->getPrevNodes()[0].op); + } + assert(pool); std::cout << " Kernel "; dumpShape(pool->getWindowShape()); std::cout << " Strides "; @@ -250,9 +260,6 @@ mir::Graph *ONNXImporterImpl::createIR() { auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(op_type, onnx_node.op_type().size()); switch (onnx_op_type->opCode) { - //case ONNXOpCode::opIdentity: - // TOD: We simply remove the operation because it does nothing. Is it OK? 
- // break; case ONNXOpCode::opConv: outputs = _opCreator.convertConv2D(input_nodes, onnx_node); break; @@ -283,7 +290,7 @@ mir::Graph *ONNXImporterImpl::createIR() { outputs = _opCreator.convertConcat(input_nodes, onnx_node); break; case ONNXOpCode::opReshape: - outputs = _opCreator.convertReshape(input_nodes[0], input_nodes[1]->getOutputShape(0)); + outputs = _opCreator.convertReshape(input_nodes); break; case ONNXOpCode::opRelu: outputs = _opCreator.convertRelu(input_nodes); @@ -310,7 +317,7 @@ mir::Graph *ONNXImporterImpl::createIR() { throw PassException("Invalid ONNXOpCode" + std::to_string((int)onnx_op_type->opCode)); } // Set outputs' names - for (int i = 0; i < outputs.size(); i++){ + for (int i = 0; i < outputs.size(); i++) { outputs[i]->setName(onnx_node.output(i)); auto result = _tensorNameToPrevMirOp.emplace(outputs[i]->getName(), outputs[i]); if(!result.second) @@ -319,7 +326,7 @@ mir::Graph *ONNXImporterImpl::createIR() { assert (outputs.size()); // FIXME: it should be done properly via the given graph outputs _graphOutputs.assign(outputs.begin(), outputs.end()); - dump(outputs, onnx_node); + dump(input_nodes, outputs, onnx_node); } // set graph outputs // TODO: it should be done with onnx graph outputs diff --git a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h index 3f0cd3f..7fd5683 100644 --- a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h +++ b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h @@ -39,9 +39,16 @@ public: void import() {}; mir::Graph *createIR() override; - void dump(const std::vector& op, const onnx::NodeProto& onnx_node); + void dump(const std::vector& inputs, const std::vector& op, + const onnx::NodeProto& onnx_node); -private: + static void dumpShape(mir::Shape shape) { + std::cout << "{"; + for (int i = 0; i < shape.rank(); i++) { + std::cout << shape.dim(i) << (i == shape.rank() - 1 ? 
"} " : ", "); + } + } + private: void createGraphInputs(); // This map maps onnx tensor names to MIR operations/nodes std::map _tensorNameToPrevMirOp; diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp index 987ff29..30750f0 100644 --- a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp +++ b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp @@ -39,11 +39,13 @@ #include "core/modelIR/operations/ScaleOp.h" #include "core/modelIR/operations/SigmoidOp.h" #include "core/modelIR/operations/SoftmaxOp.h" +#include "core/modelIR/operations/TransposeOp.h" #include "core/modelIR/operations/VariableOp.h" #include "core/modelIR/operations/ElementwiseOp.h" #include "passes/common_frontend/shape_helper.h" #include "pass/PassException.h" #include "ONNXOpCreator.h" +#include "ONNXImporterImpl.h" namespace nnc { @@ -80,16 +82,17 @@ static std::pair getFloatAttribute(const onnx::NodeProto& onnx_node return {false, 0.0}; } -static TensorVariant createTensor(float value) { +// Create vector tensor filled with the given value +static TensorVariant createTensor(float value, const mir::Shape& shape) { mir::DTYPE element_type = mir::DTYPE::FLOAT32; size_t element_size = sizeof(value); - size_t buffer_size = 1 * element_size; - const char* src_data = reinterpret_cast(&value); - std::shared_ptr data(new char[buffer_size], std::default_delete()); - std::memcpy(data.get(), src_data, buffer_size); - Shape shape{1}; - return mir::TensorVariant(shape, data, element_type, element_size); + float* dst_ptr = new float[shape.numElements()]; + for (int i = 0; i < shape.numElements(); i++) { + dst_ptr[i] = value; + } + std::shared_ptr data((char*)dst_ptr, std::default_delete()); + return mir::TensorVariant({shape.numElements()}, data, element_type, element_size); } struct KernelStridesPadding { @@ -143,12 +146,14 @@ std::vector ONNXOpCreator::convertConv2D(InputOps& inputs, inputs.resize(1); std::vector outputs; - outputs = createOp(inputs[0]->getOutput(0), transposed, cdata.strides_shape, + // Transpose ONNX NCHW to MIR NHWC + auto t_input = convertONNXToMIR(inputs[0]->getOutput(0)); + outputs = createOp(t_input[0]->getOutput(0), transposed, cdata.strides_shape, cdata.padding_before, cdata.padding_after); if (input_bias) outputs = createOp(outputs[0]->getOutput(0), input_bias->getValue()); - return outputs; + return convertMIRToONNX(outputs[0]->getOutput(0)); } std::vector ONNXOpCreator::convertConcat(InputOps& inputs, @@ -180,18 +185,22 @@ std::vector ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode std::vector result; KernelStridesPadding cdata; + // Transpose ONNX NCHW to MIR NHWC + auto t_input = convertONNXToMIR(inputs[0]->getOutput(0)); switch (op_code) { - case ONNXOpCode::opGlobalAveragePool: + case ONNXOpCode::opGlobalAveragePool: { // GlobalAveragePool is equivalent to AveragePool with kernel size equal // to the spatial dimension of input tensor - return createOp(inputs[0]->getOutput(0), - ops::PoolOp::PoolingType::AVG, - inputs[0]->getOutputShape(0), // kernel_shape - Shape({1, 1}), // strides_shape - cdata.padding_before, cdata.padding_after, - ops::PoolOp::BorderType::ZEROFILLED, - ops::PoolOp::RoundMode::floor); + result = createOp(t_input[0]->getOutput(0), + ops::PoolOp::PoolingType::AVG, + t_input[0]->getOutputShape(0), // kernel_shape + Shape({1, 1}), // strides_shape + cdata.padding_before, cdata.padding_after, + ops::PoolOp::BorderType::ZEROFILLED, + ops::PoolOp::RoundMode::floor); + return convertMIRToONNX(result[0]->getOutput(0)); 
+ } case ONNXOpCode::opAveragePool: border_type = ops::PoolOp::BorderType::ZEROFILLED; pool_type = ops::PoolOp::PoolingType::AVG; @@ -206,16 +215,16 @@ std::vector ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode // Proceed with Average or Max Pool getKernelStridesPadding(onnx_node, cdata); - result = createOp(inputs[0]->getOutput(0), pool_type, + result = createOp(t_input[0]->getOutput(0), pool_type, cdata.kernel_shape, cdata.strides_shape, cdata.padding_before, cdata.padding_after, border_type, ops::PoolOp::RoundMode::floor); - return result; + return convertMIRToONNX(result[0]->getOutput(0)); } std::vector ONNXOpCreator::convertSoftmax(InputOps& inputs, - const onnx::NodeProto& onnx_node) { + const onnx::NodeProto& onnx_node) { int axis; bool found; std::tie (found, axis) = getIntAttribute(onnx_node); @@ -223,8 +232,36 @@ std::vector ONNXOpCreator::convertSoftmax(InputOps& inputs, return createOp(inputs[0]->getOutput(0), axis); } -std::vector ONNXOpCreator::convertReshape(Operation* inputData, Shape outputShape) { - auto outputs = createOp(inputData->getOutput(0), outputShape); +std::vector ONNXOpCreator::convertReshape(InputOps& inputs) { + // The original shape + auto in_shape = inputs[0]->getInputShape(0); + + // Input tensor describing the new shape + // TODO: could it be not a constant? + auto* op = dynamic_cast(inputs[1]); + assert(op && "We support constants only"); + auto shape_tensor = op->getValue(); + Shape shape_tensor_shape = (shape_tensor).getShape(); + assert(shape_tensor_shape.rank() == 1); + // The rank of the new shape + auto cnt = shape_tensor_shape.numElements(); + // The vector to build the new shape from + std::vector shape_vector(cnt); + ShapeRange out_range(shape_tensor_shape); + Tensor tensor_accessor(shape_tensor); + + int i = 0; + for (auto idx : out_range) { + if (tensor_accessor.at(idx) == 0) + shape_vector[i] = in_shape.dim(i); + else if (tensor_accessor.at(idx) == -1) + shape_vector[i] = Shape::autoDim; + else + shape_vector[i] = tensor_accessor.at(idx); + i++; + } + auto out_shape = Shape(shape_vector); + auto outputs = createOp(inputs[0]->getOutput(0), out_shape); return outputs; } @@ -267,40 +304,39 @@ std::vector ONNXOpCreator::convertElementwise(InputOps& inputs, descriptors.push_back(input->getOutput(0)); return createOp(descriptors, op_type); } - std::vector ONNXOpCreator::convertBatchNorm(InputOps& inputs, - const onnx::NodeProto& onnx_node, - InputTensors& input_tensors) { + const onnx::NodeProto& onnx_node, + InputTensors& input_tensors) { + // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias bool found; float value; - std::tie(found, value) = getFloatAttribute(onnx_node, "epsilon"); - float epsilon = found ? value : 1e-05; - std::tie(found, value) = getFloatAttribute(onnx_node, "momentum"); - float momentum = found ? value : 0.9; - // FIXME: spatial vs. scale_factor - //std::tie(found, value) = getFloatAttribute(onnx_node, "spatial"); - float scale_factor = 0.0f; - // Scale tensor - assert(input_tensors.find(inputs[1]->getName()) != input_tensors.end()); - auto ptensor = input_tensors.at(inputs[1]->getName()); - Tensor nnc_scale(ptensor); - // Bias tensor - assert(input_tensors.find(inputs[2]->getName()) != input_tensors.end()); - auto nnc_bias = input_tensors.at(inputs[2]->getName()); - // TODO: there are 2 training tensors in the inputs + float epsilon = found ? 
value : 1e-05f; - inputs.resize(1); - auto mean_outputs = createOp(inputs[0]->getOutput(0), nnc_bias); + const auto& scale = input_tensors.at(inputs[1]->getName()); + const auto& bias = input_tensors.at(inputs[2]->getName()); + const auto& mean = input_tensors.at(inputs[3]->getName()); + const auto& var = input_tensors.at(inputs[4]->getName()); + + // res1 = X - mean + Tensor bias_data(mean); + for (auto& idx: ShapeRange(bias_data.getShape())) + bias_data.at(idx) *= -1; - // create scale argument from variance: - // multiply elements of variance by scaleFactor and - // normalize biased input using scale operation - for (Index idx : ShapeRange(nnc_scale.getShape())) - nnc_scale.at(idx) = 1.0f / std::sqrt(nnc_scale.at(idx) * scale_factor + epsilon); + auto data = convertONNXToMIR(inputs[0]->getOutput(0)); + auto bias_add_1 = createOp(data[0]->getOutput(0), mean); - auto variance_outputs = createOp(mean_outputs[0]->getOutput(0), ptensor); - return variance_outputs; + // res2 = res1 * scale / (var + epsilon) + Tensor multiplier(scale); + Tensor var_accessor(var); + for (auto& idx: ShapeRange(scale.getShape())) + multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon); + auto scale_op = createOp(bias_add_1[0]->getOutput(0), scale); + + // overall_res = res2 + bias + auto bias_add_2 = createOp(scale_op[0]->getOutput(0), bias); + + return {convertMIRToONNX(bias_add_2[0]->getOutput(0))}; } std::vector ONNXOpCreator::convertDropout(InputOps& inputs, @@ -318,7 +354,8 @@ std::vector ONNXOpCreator::convertScale(InputOps& inputs, float value; std::tie(found, value) = getFloatAttribute(onnx_node, "scale"); float scale = found ? value : 1.0; - auto outputs = createOp(inputs[0]->getOutput(0), createTensor(scale)); + auto outputs = createOp(inputs[0]->getOutput(0), + createTensor(scale, inputs[0]->getOutputShape(0))); return outputs; } @@ -328,28 +365,75 @@ std::vector ONNXOpCreator::convertGemm(InputOps& inputs, int ivalue; float fvalue; + // Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), + // input tensor B has shape (K, N) or (N, K), + // input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N). + // A will be transposed before doing the computation if attribute transA is non-zero, + // same for B and transB. This operator supports unidirectional broadcasting + // (tensor C should be unidirectional broadcastable to tensor A * B). + std::tie (found, ivalue) = getIntAttribute(onnx_node, "transA"); - bool transA = found ? ivalue : 0; + bool trans_a = found ? ivalue : 0; std::tie (found, ivalue) = getIntAttribute(onnx_node, "transB"); - bool transB = found ? ivalue : 0; - std::tie (found, fvalue) = getIntAttribute(onnx_node, "alpha"); + bool trans_b = found ? ivalue : 0; + std::tie (found, fvalue) = getFloatAttribute(onnx_node, "alpha"); float alpha = found ? fvalue : 1.0; - std::tie (found, fvalue) = getIntAttribute(onnx_node, "beta"); + std::tie (found, fvalue) = getFloatAttribute(onnx_node, "beta"); float beta = found ? fvalue : 1.0; + // 1. 
Prepare input matrix A // Flatten the shape by dim(0) mir::Shape shape0 ({inputs[0]->getOutputShape(0).dim(0), inputs[0]->getOutputShape(0).numElements() / inputs[0]->getOutputShape(0).dim(0)}); - auto reshape = createOp(inputs[0]->getOutput(0), shape0); + auto input_a = createOp(inputs[0]->getOutput(0), shape0); + if (trans_a) + input_a = createOp(input_a[0]->getOutput(0), std::vector{1, 0}); + if (alpha != 1.0) + input_a = createOp(input_a[0]->getOutput(0), + createTensor(alpha, input_a[0]->getOutputShape(0))); + + // 2. Prepare input matrix B + // + auto input_b = inputs[1]->getOutput(0); + if (trans_b) + input_b = createOp(input_b, std::vector{1, 0})[0]->getOutput(0); + // Number of cols in tensor A must be equal to number of rows in tensor B + assert(input_a[0]->getOutput(0).op->getOutputShape(0).dim(1) == + input_b.op->getOutputShape(0).dim(0)); + Shape mult_a_b({input_a[0]->getOutput(0).op->getOutputShape(0).dim(0), + input_b.op->getOutputShape(0).dim(1)}); + + // 3. Prepare input matrix C + // + auto input_c = inputs[2]->getOutput(0); + auto beta_tensor = createTensor(beta, input_c.op->getOutputShape(0)); + if ((mult_a_b.rank() == 2) && (input_c.op->getOutputShape(0).rank() == 1)) { + beta_tensor = TensorVariant(beta_tensor, mult_a_b); + } + auto constant = createOp(beta_tensor)[0]->getOutput(0); + std::vector descriptors = {constant, input_c}; + auto c_mult = createOp(descriptors, ops::ElementwiseOp::OpType::mul); + assert(c_mult[0]->getOutput(0).op->getOutputShape(0) == mult_a_b); + return createOp(input_a[0]->getOutput(0), input_b, c_mult[0]->getOutput(0)); +} - std::vector descriptors; - descriptors.push_back(reshape[0]->getOutput(0)); - descriptors.push_back(inputs[1]->getOutput(0)); - descriptors.push_back(inputs[2]->getOutput(0)); +std::vector +ONNXOpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) { + // TODO For now we only support convolutional networks with one element per batch. 
+ assert(input_shape.rank() == 4 && input_shape.dim(0) == 1); + auto variable = _graph->create(input_name, input_shape); + return {variable}; +} + +std::vector ONNXOpCreator::convertONNXToMIR(const mir::IODescriptor& arg) { + // NCHW -> NHWC + return createOp(arg, std::vector{0, 2, 3, 1}); +} - return createOp(reshape[0]->getOutput(0), inputs[1]->getOutput(0), - inputs[2]->getOutput(0), transA, transB, alpha, beta); +std::vector ONNXOpCreator::convertMIRToONNX(const mir::IODescriptor& arg) { + // NHWC -> NCHW + return createOp(arg, std::vector{0, 3, 1, 2}); } } // namespace nnc diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h index 479492a..30999b5 100644 --- a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h +++ b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h @@ -43,7 +43,7 @@ public: std::vector convertPool(InputOps& inputs, ONNXOpCode op_code, const onnx::NodeProto& onnx_node); std::vector convertSoftmax(InputOps& inputs, const onnx::NodeProto& onnx_node); - std::vector convertReshape(mir::Operation* input_data, mir::Shape output_shape); + std::vector convertReshape(InputOps& inputs); std::vector convertRelu(InputOps& inputs); std::vector convertSigmoid(InputOps& inputs); @@ -58,6 +58,10 @@ public: std::vector convertGather(InputOps& inputs, const onnx::NodeProto& onnx_node); std::vector convertGemm(InputOps& inputs, const onnx::NodeProto& onnx_node); + std::vector createInput(const std::string&, const mir::Shape&); + std::vector convertONNXToMIR(const mir::IODescriptor& arg); + std::vector convertMIRToONNX(const mir::IODescriptor& arg); + private: template std::vector createOp(Types&&... args); diff --git a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp index db124f3..bd0c827 100644 --- a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp +++ b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp @@ -28,7 +28,6 @@ #include "core/modelIR/operations/BiasAddOp.h" #include "core/modelIR/operations/CappedReluOp.h" #include "core/modelIR/operations/ConcatOp.h" -#include "core/modelIR/operations/ConstantOp.h" #include "core/modelIR/operations/Conv2DOp.h" #include "core/modelIR/operations/Deconv2DOp.h" #include "core/modelIR/operations/DepthwiseConv2DOp.h" @@ -36,8 +35,8 @@ #include "core/modelIR/operations/ElementwiseOp.h" #include "core/modelIR/operations/EluOp.h" #include "core/modelIR/operations/FullyConnectedOp.h" -#include "core/modelIR/operations/GemmOp.h" #include "core/modelIR/operations/GatherOp.h" +#include "core/modelIR/operations/GemmOp.h" #include "core/modelIR/operations/PadOp.h" #include "core/modelIR/operations/PoolOp.h" #include "core/modelIR/operations/ReduceFOp.h" @@ -88,7 +87,7 @@ void ModelAnalyzer::addOpDescr(Operation* op, const string& opName) { nodeTid = allocateTensor(name, TensorDescription::Type::OUT); _named_tensors.push_back(nodeTid); type = OpDescr::Type::OUT; - } else { + } else { // process ordinary op nodeTid = allocateTensor(); } @@ -141,8 +140,6 @@ void ModelAnalyzer::analyze(const mir::Graph* g) { // Collect all inputs and constants vector init_ops(g->collectInputs()); - vector constant_ops(g->collectConstants()); - init_ops.insert(init_ops.end(), constant_ops.begin(), constant_ops.end()); // Walk all network inputs for (Operation* in : init_ops) { diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def new file mode 100644 index 0000000..1b9e094 --- /dev/null +++ 
b/contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +inline void gemm(const float* tensor_a_data, const Dims<4>& tensor_a_dims, + const float* tensor_b_data, const Dims<4>& tensor_b_dims, + const float* tensor_c_data, const Dims<4>& tensor_c_dims, + float* output_data, const Dims<4>& output_dims) { + const auto tensor_a_map = + MapAsMatrixWithFirstDimAsRows(tensor_a_data, tensor_a_dims); + const auto tensor_b_map = + MapAsMatrixWithFirstDimAsRows(tensor_b_data, tensor_b_dims); + const auto tensor_c_map = + MapAsMatrixWithFirstDimAsRows(tensor_c_data, tensor_c_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + Gemm(tensor_a_map, tensor_b_map, &output_matrix_map); + auto size = tensor_a_dims.sizes[0] * tensor_a_dims.sizes[1] * + tensor_a_dims.sizes[2] * tensor_a_dims.sizes[3]; + for (int i = 0; i < size; i++) { + output_data[i] = output_data[i] + tensor_c_data[i]; + } +} -- 2.7.4
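
A minimal reference sketch, assuming plain row-major float buffers and illustrative names only (not the NNC API), of the ONNX Gemm semantics that this patch decomposes into Reshape/Transpose/Scale/Elementwise plus the simplified three-input GemmOp: Y = alpha * op(A) * op(B) + beta * C, where op() transposes its argument when the corresponding flag is set.

#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics of ONNX Gemm: Y = alpha * op(A) * op(B) + beta * C.
// A is M x K (K x M if trans_a), B is K x N (N x K if trans_b),
// C is either a length-N vector broadcast over rows or a full M x N matrix.
std::vector<float> onnx_gemm_reference(const std::vector<float>& a,
                                       const std::vector<float>& b,
                                       const std::vector<float>& c,
                                       int m, int k, int n,
                                       float alpha, float beta,
                                       bool trans_a, bool trans_b) {
  assert(a.size() == static_cast<std::size_t>(m) * k);
  assert(b.size() == static_cast<std::size_t>(k) * n);
  assert(c.size() == static_cast<std::size_t>(n) ||
         c.size() == static_cast<std::size_t>(m) * n);
  std::vector<float> y(static_cast<std::size_t>(m) * n, 0.0f);
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      float acc = 0.0f;
      for (int i = 0; i < k; ++i) {
        float a_elem = trans_a ? a[i * m + row] : a[row * k + i];
        float b_elem = trans_b ? b[col * k + i] : b[i * n + col];
        acc += a_elem * b_elem;
      }
      // A rank-1 C is broadcast over the rows, matching the unidirectional
      // broadcast the importer performs when it builds the beta*C tensor.
      float c_elem = (c.size() == static_cast<std::size_t>(n)) ? c[col]
                                                               : c[row * n + col];
      y[row * n + col] = alpha * acc + beta * c_elem;
    }
  }
  return y;
}

The interpreter kernel in ops/Gemm.h corresponds to the alpha = 1, beta = 1, no-transpose case of this: convertGemm folds alpha into A with a Scale op, folds beta into C with an Elementwise multiply, and handles the transA/transB flags with Transpose ops before the GemmOp is created.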