From d4dc7c4898e5aadd762cd30b0b0c55ede98e2daf Mon Sep 17 00:00:00 2001
From: Pavel Iliutchenko/AI Tools Lab/SRR/Engineer/Samsung Electronics
Date: Tue, 6 Aug 2019 19:48:45 +0300
Subject: [PATCH] [mir_onnx] Add NodeConverters for all supported operations
 from OpCreator (#6291)

* Added OpConverters
* Registered all converters
* Removed OpCreator

Signed-off-by: Pavel Iliutchenko
---
 compiler/mir-onnx-importer/CMakeLists.txt            |  52 ++-
 compiler/mir-onnx-importer/ONNXOpCreator.cpp         | 505 ---------------------
 compiler/mir-onnx-importer/ONNXOpCreator.h           | 119 -----
 compiler/mir-onnx-importer/ONNXOpRegistration.h      |  52 ++-
 compiler/mir-onnx-importer/Op/Add.cpp                |  36 ++
 compiler/mir-onnx-importer/Op/Add.h                  |  30 ++
 compiler/mir-onnx-importer/Op/AveragePool.cpp        |  49 ++
 compiler/mir-onnx-importer/Op/AveragePool.h          |  30 ++
 .../mir-onnx-importer/Op/BatchNormalization.cpp      |  73 +++
 compiler/mir-onnx-importer/Op/BatchNormalization.h   |  30 ++
 compiler/mir-onnx-importer/Op/Concat.cpp             |  40 ++
 compiler/mir-onnx-importer/Op/Concat.h               |  30 ++
 compiler/mir-onnx-importer/Op/Constant.cpp           |  45 ++
 compiler/mir-onnx-importer/Op/Constant.h             |  30 ++
 compiler/mir-onnx-importer/Op/Conv.cpp               |  83 ++++
 compiler/mir-onnx-importer/Op/Conv.h                 |  30 ++
 compiler/mir-onnx-importer/Op/Dropout.cpp            |  39 ++
 compiler/mir-onnx-importer/Op/Dropout.h              |  30 ++
 compiler/mir-onnx-importer/Op/Gather.cpp             |  39 ++
 compiler/mir-onnx-importer/Op/Gather.h               |  30 ++
 compiler/mir-onnx-importer/Op/Gemm.cpp               | 101 +++++
 compiler/mir-onnx-importer/Op/Gemm.h                 |  30 ++
 compiler/mir-onnx-importer/Op/GivenTensorFill.cpp    |  47 ++
 compiler/mir-onnx-importer/Op/GivenTensorFill.h      |  30 ++
 .../mir-onnx-importer/Op/GlobalAveragePool.cpp       |  49 ++
 compiler/mir-onnx-importer/Op/GlobalAveragePool.h    |  30 ++
 compiler/mir-onnx-importer/Op/Max.cpp                |  36 ++
 compiler/mir-onnx-importer/Op/Max.h                  |  30 ++
 compiler/mir-onnx-importer/Op/MaxPool.cpp            |  49 ++
 compiler/mir-onnx-importer/Op/MaxPool.h              |  30 ++
 compiler/mir-onnx-importer/Op/Mul.cpp                |  36 ++
 compiler/mir-onnx-importer/Op/Mul.h                  |  30 ++
 compiler/mir-onnx-importer/Op/Pad.cpp                |  56 +++
 compiler/mir-onnx-importer/Op/Pad.h                  |  30 ++
 compiler/mir-onnx-importer/Op/Relu.cpp               |  36 ++
 compiler/mir-onnx-importer/Op/Relu.h                 |  30 ++
 compiler/mir-onnx-importer/Op/Reshape.cpp            |  67 +++
 compiler/mir-onnx-importer/Op/Reshape.h              |  30 ++
 compiler/mir-onnx-importer/Op/Scale.cpp              |  43 ++
 compiler/mir-onnx-importer/Op/Scale.h                |  30 ++
 compiler/mir-onnx-importer/Op/Shape.cpp              |  46 ++
 compiler/mir-onnx-importer/Op/Shape.h                |  30 ++
 compiler/mir-onnx-importer/Op/Sigmoid.cpp            |  36 ++
 compiler/mir-onnx-importer/Op/Sigmoid.h              |  30 ++
 compiler/mir-onnx-importer/Op/Softmax.cpp            |  39 ++
 compiler/mir-onnx-importer/Op/Softmax.h              |  30 ++
 compiler/mir-onnx-importer/Op/Sum.cpp                |  36 ++
 compiler/mir-onnx-importer/Op/Sum.h                  |  30 ++
 compiler/mir-onnx-importer/Op/Unsqueeze.cpp          |  55 +++
 compiler/mir-onnx-importer/Op/Unsqueeze.h            |  30 ++
 compiler/mir-onnx-importer/Op/Upsample.cpp           |  62 +++
 compiler/mir-onnx-importer/Op/Upsample.h             |  30 ++
 52 files changed, 2018 insertions(+), 628 deletions(-)
 delete mode 100644 compiler/mir-onnx-importer/ONNXOpCreator.cpp
 delete mode 100644 compiler/mir-onnx-importer/ONNXOpCreator.h
 create mode 100644 compiler/mir-onnx-importer/Op/Add.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Add.h
 create mode 100644 compiler/mir-onnx-importer/Op/AveragePool.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/AveragePool.h
 create mode 100644 compiler/mir-onnx-importer/Op/BatchNormalization.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/BatchNormalization.h
 create mode 100644 compiler/mir-onnx-importer/Op/Concat.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Concat.h
 create mode 100644 compiler/mir-onnx-importer/Op/Constant.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Constant.h
 create mode 100644 compiler/mir-onnx-importer/Op/Conv.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Conv.h
 create mode 100644 compiler/mir-onnx-importer/Op/Dropout.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Dropout.h
 create mode 100644 compiler/mir-onnx-importer/Op/Gather.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Gather.h
 create mode 100644 compiler/mir-onnx-importer/Op/Gemm.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Gemm.h
 create mode 100644 compiler/mir-onnx-importer/Op/GivenTensorFill.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/GivenTensorFill.h
 create mode 100644 compiler/mir-onnx-importer/Op/GlobalAveragePool.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/GlobalAveragePool.h
 create mode 100644 compiler/mir-onnx-importer/Op/Max.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Max.h
 create mode 100644 compiler/mir-onnx-importer/Op/MaxPool.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/MaxPool.h
 create mode 100644 compiler/mir-onnx-importer/Op/Mul.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Mul.h
 create mode 100644 compiler/mir-onnx-importer/Op/Pad.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Pad.h
 create mode 100644 compiler/mir-onnx-importer/Op/Relu.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Relu.h
 create mode 100644 compiler/mir-onnx-importer/Op/Reshape.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Reshape.h
 create mode 100644 compiler/mir-onnx-importer/Op/Scale.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Scale.h
 create mode 100644 compiler/mir-onnx-importer/Op/Shape.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Shape.h
 create mode 100644 compiler/mir-onnx-importer/Op/Sigmoid.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Sigmoid.h
 create mode 100644 compiler/mir-onnx-importer/Op/Softmax.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Softmax.h
 create mode 100644 compiler/mir-onnx-importer/Op/Sum.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Sum.h
 create mode 100644 compiler/mir-onnx-importer/Op/Unsqueeze.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Unsqueeze.h
 create mode 100644 compiler/mir-onnx-importer/Op/Upsample.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/Upsample.h
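
This patch replaces the monolithic ONNXOpCreator with one NodeConverter subclass per ONNX operation, registered by op name in a singleton registry and looked up at import time. A minimal standalone sketch of that dispatch pattern follows; the `Converter` and `Registry` names and the simplified `convert` signature are illustrative stand-ins, not the actual mir_onnx API:

    #include <iostream>
    #include <map>
    #include <memory>
    #include <stdexcept>
    #include <string>

    // Stand-in for mir_onnx::NodeConverter: one subclass per op kind.
    struct Converter
    {
      virtual ~Converter() = default;
      virtual void convert(const std::string &node_name) const = 0;
    };

    // Stand-in for mir_onnx::NodeConverterRegistry: a singleton keyed by op type.
    class Registry
    {
    public:
      static Registry &getInstance()
      {
        static Registry instance;
        return instance;
      }

      void registerConverter(const std::string &op_type, std::unique_ptr<Converter> converter)
      {
        _converters[op_type] = std::move(converter);
      }

      const Converter *lookup(const std::string &op_type) const
      {
        auto it = _converters.find(op_type);
        return it == _converters.end() ? nullptr : it->second.get();
      }

    private:
      std::map<std::string, std::unique_ptr<Converter>> _converters;
    };

    struct AddConverter : Converter
    {
      void convert(const std::string &node_name) const override
      {
        std::cout << "converting Add node: " << node_name << "\n";
      }
    };

    int main()
    {
      auto &registry = Registry::getInstance();
      registry.registerConverter("Add", std::unique_ptr<Converter>(new AddConverter()));

      // The importer resolves each node's op type at import time.
      if (const Converter *c = registry.lookup("Add"))
        c->convert("add_1");
      else
        throw std::runtime_error("unsupported operation");
    }

Keying the map by the ONNX op-type string keeps the importer's dispatch loop generic: supporting a new operation becomes a local change (one new class plus one registration line) instead of another method on a god object.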
diff --git a/compiler/mir-onnx-importer/CMakeLists.txt b/compiler/mir-onnx-importer/CMakeLists.txt
index 1f2d03e..4cc01d0 100644
--- a/compiler/mir-onnx-importer/CMakeLists.txt
+++ b/compiler/mir-onnx-importer/CMakeLists.txt
@@ -10,9 +10,55 @@ set(MIR_ONNX_IMPORTER_SOURCES
     ONNXImporterImpl.cpp
     ONNXImporterImpl.h
     ONNXNodeConverterRegistry.h
-    ONNXOpCreator.cpp
-    ONNXOpCreator.h
-    ONNXOpRegistration.h)
+    ONNXOpRegistration.h
+    Op/Add.cpp
+    Op/Add.h
+    Op/AveragePool.cpp
+    Op/AveragePool.h
+    Op/BatchNormalization.cpp
+    Op/BatchNormalization.h
+    Op/Concat.cpp
+    Op/Concat.h
+    Op/Constant.cpp
+    Op/Constant.h
+    Op/Conv.cpp
+    Op/Conv.h
+    Op/Dropout.cpp
+    Op/Dropout.h
+    Op/Gather.cpp
+    Op/Gather.h
+    Op/Gemm.cpp
+    Op/Gemm.h
+    Op/GivenTensorFill.cpp
+    Op/GivenTensorFill.h
+    Op/GlobalAveragePool.cpp
+    Op/GlobalAveragePool.h
+    Op/Max.cpp
+    Op/Max.h
+    Op/MaxPool.cpp
+    Op/MaxPool.h
+    Op/Mul.cpp
+    Op/Mul.h
+    Op/Pad.cpp
+    Op/Pad.h
+    Op/Relu.cpp
+    Op/Relu.h
+    Op/Reshape.cpp
+    Op/Reshape.h
+    Op/Scale.cpp
+    Op/Scale.h
+    Op/Shape.cpp
+    Op/Shape.h
+    Op/Sigmoid.cpp
+    Op/Sigmoid.h
+    Op/Softmax.cpp
+    Op/Softmax.h
+    Op/Sum.cpp
+    Op/Sum.h
+    Op/Unsqueeze.cpp
+    Op/Unsqueeze.h
+    Op/Upsample.cpp
+    Op/Upsample.h)
 
 add_library(mir_onnx_importer STATIC ${MIR_ONNX_IMPORTER_SOURCES})
 set_target_properties(mir_onnx_importer PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/compiler/mir-onnx-importer/ONNXOpCreator.cpp b/compiler/mir-onnx-importer/ONNXOpCreator.cpp
deleted file mode 100644
index b1f6046..0000000
--- a/compiler/mir-onnx-importer/ONNXOpCreator.cpp
+++ /dev/null
@@ -1,505 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ONNXOpCreator.h"
-#include "ONNXHelpers.h"
-#include "ONNXImporterImpl.h"
-
-#include "mir/ops/BatchNormOp.h"
-#include "mir/ops/BiasAddOp.h"
-#include "mir/ops/CappedReluOp.h"
-#include "mir/ops/ConcatOp.h"
-#include "mir/ops/ConstantOp.h"
-#include "mir/ops/Conv2DOp.h"
-#include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/DropoutOp.h"
-#include "mir/ops/FullyConnectedOp.h"
-#include "mir/ops/GatherOp.h"
-#include "mir/ops/GemmOp.h"
-#include "mir/ops/InputOp.h"
-#include "mir/ops/PadOp.h"
-#include "mir/ops/PoolOp.h"
-#include "mir/ops/ReluOp.h"
-#include "mir/ops/ReshapeOp.h"
-#include "mir/ops/ResizeOp.h"
-#include "mir/ops/ScaleOp.h"
-#include "mir/ops/SigmoidOp.h"
-#include "mir/ops/SoftmaxOp.h"
-#include "mir/ops/TransposeOp.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/Index.h"
-#include "mir/Graph.h"
-#include "mir/Scalar.h"
-#include "mir/ShapeRange.h"
-#include "mir/Tensor.h"
-#include "mir/TensorUtil.h"
-
-#include <cmath>
-#include <stdexcept>
-#include <utility>
-
-namespace mir_onnx
-{
-
-using namespace mir;
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertConv2D(const std::vector<mir::Operation::Output *> &inputs,
-                             const onnx::NodeProto &onnx_node)
-{
-  assert(inputs.size() >= 2);
-
-  KernelStridesPadding cdata;
-  getKernelStridesPadding(onnx_node, cdata);
-  // FIXME: It can be non-constant value.
-  auto *in_weights = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
-  assert(in_weights && "Weights could be a constant tensor only");
-  const auto &in_weights_tensor = in_weights->getValue();
-  // We should transpose ONNX MC(IO)HW to HWOI
-  auto kernel_tensor = transposeTensor<2, 3, 1, 0>(in_weights_tensor);
-  auto in_group_size = kernel_tensor.getShape().dim(2);
-  auto out_channels = kernel_tensor.getShape().dim(3);
-  bool found;
-  int num_groups;
-  std::tie(found, num_groups) = getIntAttribute(onnx_node, "group");
-  if (!found)
-    num_groups = 1;
-  bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups);
-
-  mir::Operation *result;
-  auto transposed_input = convertONNXToMIR(_graph, inputs[0]);
-  if (is_depthwise)
-  {
-    // TODO handle properly kernel with layer multiplier
-    auto transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
-    auto kernel = createOp<ops::ConstantOp>(_graph, transposed_tensor)->getOutput(0);
-    result = createOp<ops::DepthwiseConv2DOp>(_graph, transposed_input, kernel,
-                                              cdata.strides_shape, cdata.padding_before,
-                                              cdata.padding_after);
-  }
-  else
-  {
-    // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
-    if (num_groups != 1)
-      kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
-    kernel_tensor = transposeTensor<3, 0, 1, 2>(kernel_tensor);
-    auto kernel = createOp<ops::ConstantOp>(_graph, kernel_tensor)->getOutput(0);
-    result = createOp<ops::Conv2DOp>(_graph, transposed_input, kernel, cdata.strides_shape,
-                                     cdata.padding_before, cdata.padding_after);
-  }
-
-  if (inputs.size() > 2)
-    result = createOp<ops::BiasAddOp>(_graph, result->getOutput(0), inputs[2]);
-
-  return {convertMIRToONNX(_graph, result->getOutput(0))};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertConcat(const std::vector<mir::Operation::Output *> &inputs,
-                             const onnx::NodeProto &onnx_node)
-{
-  bool found;
-  int axis;
-  std::tie(found, axis) = getIntAttribute(onnx_node);
-  if (!found)
-    throw std::runtime_error("Concat must have 'axis' attribute");
-  auto result = createOp<ops::ConcatOp>(_graph, inputs, axis);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertGather(const std::vector<mir::Operation::Output *> &inputs,
-                             const onnx::NodeProto &onnx_node)
-{
-  bool found;
-  int value;
-  std::tie(found, value) = getIntAttribute(onnx_node, "axis");
-  int axis = found ? value : 0;
-  auto result = createOp<ops::GatherOp>(_graph, inputs[0], inputs[1], axis);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertPad(const std::vector<mir::Operation::Output *> &inputs,
-                          const onnx::NodeProto &onnx_node)
-{
-  bool found;
-  float value;
-  std::tie(found, value) = getFloatAttribute(onnx_node, "value");
-  assert(found);
-  auto padsAtt = findAttribute(onnx_node, "pads");
-  assert(padsAtt);
-  auto modeAtt = findAttribute(onnx_node, "mode");
-  assert(modeAtt);
-  auto mode = modeAtt->s();
-  const mir::Scalar scalar(reinterpret_cast<const char *>(&value), DTYPE::FLOAT32, sizeof(float));
-  assert(padsAtt->ints_size() > 0);
-  int cnt = padsAtt->ints_size() / 2;
-  assert(cnt % 2 == 0);
-  int last = padsAtt->ints_size() - 1;
-  std::vector<std::pair<int32_t, int32_t>> vec(cnt);
-  auto *data = padsAtt->ints().data();
-  for (int i = 0; i < cnt; i++)
-  {
-    auto pair = std::make_pair(data[i], data[last - i]);
-    vec[i] = pair;
-  }
-  auto result = createOp<ops::PadOp>(_graph, inputs[0], inputs[0]->getShape().rank(), vec, scalar);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertPool(const std::vector<mir::Operation::Output *> &inputs,
-                           mir::ops::PoolOp::PoolingType pool_type,
-                           const onnx::NodeProto &onnx_node)
-{
-  ops::PoolOp::BorderType border_type;
-
-  KernelStridesPadding cdata;
-  // Transpose ONNX NCHW to MIR NHWC
-  auto t_input = convertONNXToMIR(_graph, inputs[0]);
-
-  switch (pool_type)
-  {
-    case mir::ops::PoolOp::PoolingType::AVG:
-      border_type = ops::PoolOp::BorderType::ZEROFILLED;
-      pool_type = ops::PoolOp::PoolingType::AVG;
-      getKernelStridesPadding(onnx_node, cdata);
-      break;
-    case mir::ops::PoolOp::PoolingType::MAX:
-      border_type = ops::PoolOp::BorderType::EMPTY;
-      pool_type = ops::PoolOp::PoolingType::MAX;
-      getKernelStridesPadding(onnx_node, cdata);
-      break;
-    default:
-      assert(false);
-  }
-  auto result =
-      createOp<ops::PoolOp>(_graph, t_input, pool_type, cdata.kernel_shape, cdata.strides_shape,
-                            cdata.padding_before, cdata.padding_after, border_type);
-  return {convertMIRToONNX(_graph, result->getOutput(0))};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertSoftmax(const std::vector<mir::Operation::Output *> &inputs,
-                              const onnx::NodeProto &onnx_node)
-{
-  int axis;
-  bool found;
-  std::tie(found, axis) = getIntAttribute(onnx_node);
-  axis = found ? axis : 1;
-  auto result = createOp<ops::SoftmaxOp>(_graph, inputs[0], axis);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertReshape(const std::vector<mir::Operation::Output *> &inputs)
-{
-  // The original shape
-  const auto &in_shape = inputs[0]->getShape();
-
-  // Input tensor describing the new shape
-  // TODO: could it be not a constant?
-  auto *op = dynamic_cast<ops::ConstantOp *>(inputs[1]->getNode());
-  assert(op && "We support constants only");
-  auto shape_tensor = op->getValue();
-  Shape shape_tensor_shape = (shape_tensor).getShape();
-  assert(shape_tensor_shape.rank() == 1);
-  // The rank of the new shape
-  auto cnt = shape_tensor_shape.numElements();
-  // The vector to build the new shape from
-  std::vector<int32_t> shape_vector(cnt);
-  ShapeRange out_range(shape_tensor_shape);
-  Tensor<int32_t> tensor_accessor(shape_tensor);
-
-  int i = 0;
-  for (auto idx : out_range)
-  {
-    if (tensor_accessor.at(idx) == 0)
-      shape_vector[i] = in_shape.dim(i);
-    else if (tensor_accessor.at(idx) == -1)
-      shape_vector[i] = Shape::autoDim;
-    else
-      shape_vector[i] = tensor_accessor.at(idx);
-    i++;
-  }
-  auto out_shape = Shape(shape_vector);
-  auto result = createOp<ops::ReshapeOp>(_graph, inputs[0], out_shape);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertUnsqueeze(const std::vector<mir::Operation::Output *> &inputs,
-                                const onnx::NodeProto &onnx_node)
-{
-  auto *axes = findAttribute(onnx_node, "axes");
-  assert(axes && axes->ints_size());
-  const Shape &input_shape = inputs[0]->getShape();
-  const int out_rank = input_shape.rank() + axes->ints_size();
-  Shape out_shape(out_rank);
-  auto ints_iterator = axes->ints().begin();
-  int j = 0;
-  for (int i = 0; i < out_rank; i++)
-  {
-    if (ints_iterator < axes->ints().end() && i == *ints_iterator)
-    {
-      out_shape.dim(i) = 1;
-      ints_iterator++;
-    }
-    else
-    {
-      out_shape.dim(i) = input_shape.dim(j);
-      j++;
-    }
-  }
-  auto result = createOp<ops::ReshapeOp>(_graph, inputs[0], out_shape);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertRelu(const std::vector<mir::Operation::Output *> &inputs)
-{
-  assert(inputs.size() == 1);
-  auto result = createOp<ops::ReluOp>(_graph, inputs[0]);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertSigmoid(const std::vector<mir::Operation::Output *> &inputs)
-{
-  assert(inputs.size() == 1);
-  auto result = createOp<ops::SigmoidOp>(_graph, inputs[0]);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertElementwise(const std::vector<mir::Operation::Output *> &inputs,
-                                  mir::ops::ElementwiseOp::OpType op_type)
-{
-  auto result = createOp<ops::ElementwiseOp>(_graph, inputs, op_type);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertUpsample(const std::vector<mir::Operation::Output *> &inputs,
-                               const onnx::NodeProto &onnx_node)
-{
-  bool success;
-  std::string mode;
-  std::tie(success, mode) = getStringAttribute(onnx_node, "mode");
-  if (!success)
-    mode = "nearest";
-  assert(mode == "nearest" && "Unsupported upscale mode!");
-
-  // relies on attributes being lifted to constants (ONNX optimization pass)
-  assert(inputs.size() > 1);
-  auto *scales = dynamic_cast<ops::ConstantOp *>(inputs[1]->getNode());
-  assert(scales && "Weights could be a constant tensor only");
-  auto scales_tensor = Tensor<float>(scales->getValue());
-  int rank = inputs[0]->getShape().rank();
-  assert(scales_tensor.getShape().numElements() == rank &&
-         "The number of elements of 'scales' should be the same as the rank of input 'X'");
-  assert(rank == 4 && "Only rank 4 is supported");
-  std::vector<float> scales_vector(4);
-  const int onnx2mir[] = {0, 3, 1, 2};
-  assert(scales_tensor.getShape().rank() == 1 && "Scales are a 1d tensor");
-  for (int i = 0; i < scales_tensor.getShape().numElements(); i++)
-    scales_vector[onnx2mir[i]] = scales_tensor.atOffset(i);
-  return {convertMIRToONNX(
-      _graph, createOp<ops::ResizeOp>(_graph, convertONNXToMIR(_graph, inputs[0]),
                                      ops::ResizeOp::ResizeMethod::nearestNeighbor, scales_vector)
-                  ->getOutput(0))};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertBatchNorm(const std::vector<mir::Operation::Output *> &inputs,
-                                const onnx::NodeProto &onnx_node, InputTensors &input_tensors)
-{
-  // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
-  bool found;
-  float value;
-  std::tie(found, value) = getFloatAttribute(onnx_node, "epsilon");
-  float epsilon = found ? value : 1e-05f;
-
-  // TODO: it's better to do it via inputs
-  const auto &scale_tensor = input_tensors.at(inputs[1]->getNode()->getName());
-  const auto &bias_tensor = input_tensors.at(inputs[2]->getNode()->getName());
-  const auto &mean_tensor = input_tensors.at(inputs[3]->getNode()->getName());
-  const auto &var_tensor = input_tensors.at(inputs[4]->getNode()->getName());
-
-  // res1 = X - mean
-  Tensor<float> bias_data(mean_tensor);
-  for (auto &idx : ShapeRange(bias_data.getShape()))
-    bias_data.at(idx) *= -1;
-
-  auto data = convertONNXToMIR(_graph, inputs[0]);
-  auto mean = createOp<ops::ConstantOp>(_graph, mean_tensor)->getOutput(0);
-  auto result = createOp<ops::BiasAddOp>(_graph, data, mean);
-
-  // res2 = res1 * scale / (var + epsilon)
-  Tensor<float> multiplier(scale_tensor);
-  Tensor<float> var_accessor(var_tensor);
-  for (auto &idx : ShapeRange(scale_tensor.getShape()))
-    multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
-  auto scale = createOp<ops::ConstantOp>(_graph, scale_tensor)->getOutput(0);
-  result = createOp<ops::ScaleOp>(_graph, result->getOutput(0), scale);
-
-  // overall_res = res2 + bias
-  auto bias = createOp<ops::ConstantOp>(_graph, bias_tensor)->getOutput(0);
-  result = createOp<ops::BiasAddOp>(_graph, result->getOutput(0), bias);
-
-  return {convertMIRToONNX(_graph, result->getOutput(0))};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertDropout(const std::vector<mir::Operation::Output *> &inputs,
-                              const onnx::NodeProto &onnx_node)
-{
-  bool found;
-  float value;
-  std::tie(found, value) = getFloatAttribute(onnx_node, "ratio");
-  float ratio = found ? value : 1.0;
-  auto result = createOp<ops::DropoutOp>(_graph, inputs[0], ratio);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertScale(const std::vector<mir::Operation::Output *> &inputs,
-                            const onnx::NodeProto &onnx_node)
-{
-  bool found;
-  float value;
-  std::tie(found, value) = getFloatAttribute(onnx_node, "scale");
-  float scale_val = found ? value : 1.0;
-  const auto &shape = inputs[0]->getShape();
-  auto scale_tensor = createScalarTensor(scale_val, shape);
-  auto scale = createOp<ops::ConstantOp>(_graph, scale_tensor)->getOutput(0);
-  auto result = createOp<ops::ScaleOp>(_graph, inputs[0], scale);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertShape(const std::vector<mir::Operation::Output *> &inputs)
-{
-  const auto &input_shape = inputs[0]->getShape();
-  int size = input_shape.rank();
-  Shape output_shape{size};
-  std::vector<float> data(static_cast<std::size_t>(size));
-  for (int i = 0; i < size; i++)
-  {
-    data[i] = input_shape.dim(i);
-  }
-  TensorVariant tensor(DTYPE::FLOAT32, output_shape, data.data());
-  auto result = createOp<ops::ConstantOp>(_graph, tensor);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertGivenTensorFill(const onnx::NodeProto &onnx_node, InputTensors &input_tensors)
-{
-  auto values_att = findAttribute(onnx_node, "values");
-  auto shape_att = findAttribute(onnx_node, "shape");
-  assert(values_att && shape_att);
-  assert(values_att->floats_size() > 0 && shape_att->ints_size() > 0);
-  Shape shape(shape_att->ints_size());
-  for (int i = 0; i < shape_att->ints_size(); i++)
-    shape.dim(i) = shape_att->ints(i);
-  TensorVariant tensor(DTYPE::FLOAT32, shape, values_att->floats().data());
-  input_tensors.insert(std::make_pair(onnx_node.output(0), tensor));
-  auto result = createOp<ops::ConstantOp>(_graph, tensor);
-  return {result->getOutput(0)};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertConstant(const onnx::NodeProto &onnx_node, InputTensors &input_tensors)
-{
-  assert((onnx_node.attribute_size() == 1) &&
-         (onnx_node.attribute(0).type() == onnx::AttributeProto_AttributeType_TENSOR) &&
-         (onnx_node.attribute(0).tensors_size() == 0));
-  assert(onnx_node.attribute(0).name() == "value");
-  auto name = onnx_node.output(0);
-  auto &onnx_tensor = onnx_node.attribute(0).t();
-  auto mir_tensor = createTensor(&onnx_tensor);
-  input_tensors.insert(std::make_pair(name, mir_tensor));
-  auto op = _graph->create<mir::ops::ConstantOp>(name, mir_tensor)->getOutput(0);
-  return {op};
-}
-
-std::vector<mir::Operation::Output *>
-ONNXOpCreator::convertGemm(const std::vector<mir::Operation::Output *> &inputs,
-                           const onnx::NodeProto &onnx_node)
-{
-  bool found;
-  int ivalue;
-  float fvalue;
-
-  // Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
-  // input tensor B has shape (K, N) or (N, K),
-  // input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
-  // A will be transposed before doing the computation if attribute transA is non-zero,
-  // same for B and transB. This operator supports unidirectional broadcasting
-  // (tensor C should be unidirectional broadcastable to tensor A * B).
-
-  std::tie(found, ivalue) = getIntAttribute(onnx_node, "transA");
-  bool trans_a = found ? static_cast<bool>(ivalue) : false;
-  std::tie(found, ivalue) = getIntAttribute(onnx_node, "transB");
-  bool trans_b = found ? static_cast<bool>(ivalue) : false;
-  std::tie(found, fvalue) = getFloatAttribute(onnx_node, "alpha");
-  float alpha_val = found ? fvalue : 1.0f;
-  std::tie(found, fvalue) = getFloatAttribute(onnx_node, "beta");
-  float beta_val = found ? fvalue : 1.0f;
-
-  // 1. Prepare input matrix A
-  // Flatten the shape by dim(0)
-  const auto &in_shape = inputs[0]->getShape();
-  mir::Shape shape0{in_shape.dim(0), in_shape.numElements() / in_shape.dim(0)};
-  auto input_a = createOp<ops::ReshapeOp>(_graph, inputs[0], shape0)->getOutput(0);
-  if (trans_a)
-    input_a =
-        createOp<ops::TransposeOp>(_graph, input_a, std::vector<std::size_t>{1, 0})->getOutput(0);
-  if (alpha_val != 1.0)
-  {
-    auto alpha_tensor = createScalarTensor(alpha_val, input_a->getShape());
-    auto alpha = createOp<ops::ConstantOp>(_graph, alpha_tensor)->getOutput(0);
-    input_a = createOp<ops::ScaleOp>(_graph, input_a, alpha)->getOutput(0);
-  }
-
-  // 2. Prepare input matrix B
-  //
-  auto input_b = inputs[1];
-  if (trans_b)
-    input_b =
-        createOp<ops::TransposeOp>(_graph, input_b, std::vector<std::size_t>{1, 0})->getOutput(0);
-  // Number of cols in tensor A must be equal to number of rows in tensor B
-  assert(input_a->getShape().dim(1) == input_b->getShape().dim(0));
-  Shape mult_a_b{input_a->getShape().dim(0), input_b->getShape().dim(1)};
-
-  // 3. Prepare input matrix C
-  //
-  auto input_c = inputs[2];
-  auto beta_tensor = createScalarTensor(beta_val, input_c->getShape());
-  if ((mult_a_b.rank() == 2) && (input_c->getShape().rank() == 1))
-  {
-    beta_tensor = TensorVariant(beta_tensor, mult_a_b);
-  }
-  auto beta = createOp<ops::ConstantOp>(_graph, beta_tensor)->getOutput(0);
-  std::vector<mir::Operation::Output *> mul_inputs = {beta, input_c};
-  auto c_mult = createOp<ops::ElementwiseOp>(_graph, mul_inputs, ops::ElementwiseOp::OpType::mul)
-                    ->getOutput(0);
-  assert(c_mult->getShape() == mult_a_b);
-  auto result = createOp<ops::GemmOp>(_graph, input_a, input_b, c_mult);
-  return {result->getOutput(0)};
-}
-} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/ONNXOpCreator.h b/compiler/mir-onnx-importer/ONNXOpCreator.h
deleted file mode 100644
index d6caa99..0000000
--- a/compiler/mir-onnx-importer/ONNXOpCreator.h
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef _MIR_ONNX_OP_CREATOR_H
-#define _MIR_ONNX_OP_CREATOR_H
-
-#include "onnx/onnx.pb.h"
-
-#include "mir/ops/CommonProps.h"
-#include "mir/ops/ElementwiseOp.h"
-#include "mir/ops/PoolOp.h"
-#include "mir/Graph.h"
-#include "mir/Shape.h"
-#include "mir/TensorVariant.h"
-
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-namespace mir_onnx
-{
-
-class ONNXOpCreator
-{
-public:
-  using InputTensors = std::map<std::string, mir::TensorVariant>;
-
-  ONNXOpCreator() = default;
-
-  void setMirGraph(mir::Graph *g) { _graph = g; };
-
-  std::vector<mir::Operation::Output *>
-  convertConv2D(const std::vector<mir::Operation::Output *> &inputs,
-                const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertConcat(const std::vector<mir::Operation::Output *> &inputs,
-                const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertGivenTensorFill(const onnx::NodeProto &onnx_node, InputTensors &input_tensors);
-
-  std::vector<mir::Operation::Output *>
-  convertConstant(const onnx::NodeProto &onnx_node, InputTensors &input_tensors);
-
-  std::vector<mir::Operation::Output *>
-  convertPool(const std::vector<mir::Operation::Output *> &inputs,
-              mir::ops::PoolOp::PoolingType pool_type, const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertPad(const std::vector<mir::Operation::Output *> &inputs,
-             const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertSoftmax(const std::vector<mir::Operation::Output *> &inputs,
-                 const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertReshape(const std::vector<mir::Operation::Output *> &inputs);
-
-  std::vector<mir::Operation::Output *>
-  convertRelu(const std::vector<mir::Operation::Output *> &inputs);
-
-  std::vector<mir::Operation::Output *>
-  convertSigmoid(const std::vector<mir::Operation::Output *> &inputs);
-
-  std::vector<mir::Operation::Output *>
-  convertUnsqueeze(const std::vector<mir::Operation::Output *> &inputs,
-                   const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertUpsample(const std::vector<mir::Operation::Output *> &inputs,
-                  const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertElementwise(const std::vector<mir::Operation::Output *> &inputs,
-                     mir::ops::ElementwiseOp::OpType op_type);
-
-  std::vector<mir::Operation::Output *>
-  convertScale(const std::vector<mir::Operation::Output *> &inputs,
-               const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertShape(const std::vector<mir::Operation::Output *> &inputs);
-
-  std::vector<mir::Operation::Output *>
-  convertBatchNorm(const std::vector<mir::Operation::Output *> &inputs,
-                   const onnx::NodeProto &onnx_node, InputTensors &input_tensors);
-
-  std::vector<mir::Operation::Output *>
-  convertDropout(const std::vector<mir::Operation::Output *> &inputs,
-                 const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertGather(const std::vector<mir::Operation::Output *> &inputs,
-                const onnx::NodeProto &onnx_node);
-
-  std::vector<mir::Operation::Output *>
-  convertGemm(const std::vector<mir::Operation::Output *> &inputs,
-              const onnx::NodeProto &onnx_node);
-
-private:
-  mir::Graph *_graph = nullptr;
-};
-} // namespace mir_onnx
-#endif // _MIR_ONNX_OP_CREATOR_H
diff --git a/compiler/mir-onnx-importer/ONNXOpRegistration.h b/compiler/mir-onnx-importer/ONNXOpRegistration.h
index 33ae841..1028478 100644
--- a/compiler/mir-onnx-importer/ONNXOpRegistration.h
+++ b/compiler/mir-onnx-importer/ONNXOpRegistration.h
@@ -16,13 +16,63 @@
 
 #include "ONNXNodeConverterRegistry.h"
 
+#include "Op/Add.h"
+#include "Op/AveragePool.h"
+#include "Op/BatchNormalization.h"
+#include "Op/Concat.h"
+#include "Op/Constant.h"
+#include "Op/Conv.h"
+#include "Op/Dropout.h"
+#include "Op/Gather.h"
+#include "Op/Gemm.h"
+#include "Op/GivenTensorFill.h"
+#include "Op/GlobalAveragePool.h"
+#include "Op/Max.h"
+#include "Op/MaxPool.h"
+#include "Op/Mul.h"
+#include "Op/Pad.h"
+#include "Op/Relu.h"
+#include "Op/Reshape.h"
+#include "Op/Scale.h"
+#include "Op/Shape.h"
+#include "Op/Sigmoid.h"
+#include "Op/Softmax.h"
+#include "Op/Sum.h"
+#include "Op/Unsqueeze.h"
+#include "Op/Upsample.h"
+
 namespace mir_onnx
 {
 
 inline void registerSupportedOps()
 {
   auto &registry = NodeConverterRegistry::getInstance();
-  // registry.registerConverter("Add", stdex::make_unique<AddNodeConverter>());
+  registry.registerConverter("Add", stdex::make_unique<AddNodeConverter>());
registry.registerConverter("AveragePool", stdex::make_unique()); + registry.registerConverter("BatchNormalization", + stdex::make_unique()); + registry.registerConverter("Concat", stdex::make_unique()); + registry.registerConverter("Constant", stdex::make_unique()); + registry.registerConverter("Conv", stdex::make_unique()); + registry.registerConverter("Dropout", stdex::make_unique()); + registry.registerConverter("Gather", stdex::make_unique()); + registry.registerConverter("Gemm", stdex::make_unique()); + registry.registerConverter("GivenTensorFill", stdex::make_unique()); + registry.registerConverter("GlobalAveragePool", + stdex::make_unique()); + registry.registerConverter("Max", stdex::make_unique()); + registry.registerConverter("MaxPool", stdex::make_unique()); + registry.registerConverter("Mul", stdex::make_unique()); + registry.registerConverter("Pad", stdex::make_unique()); + registry.registerConverter("Relu", stdex::make_unique()); + registry.registerConverter("Reshape", stdex::make_unique()); + registry.registerConverter("Scale", stdex::make_unique()); + registry.registerConverter("Shape", stdex::make_unique()); + registry.registerConverter("Sigmoid", stdex::make_unique()); + registry.registerConverter("Softmax", stdex::make_unique()); + registry.registerConverter("Sum", stdex::make_unique()); + registry.registerConverter("Unsqueeze", stdex::make_unique()); + registry.registerConverter("Upsample", stdex::make_unique()); } } // namespace mir_onnx diff --git a/compiler/mir-onnx-importer/Op/Add.cpp b/compiler/mir-onnx-importer/Op/Add.cpp new file mode 100644 index 0000000..ed413d8 --- /dev/null +++ b/compiler/mir-onnx-importer/Op/Add.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Add.h" + +#include "ONNXHelpers.h" + +#include "mir/ops/ElementwiseOp.h" + +namespace mir_onnx +{ + +std::vector +AddNodeConverter::convert(const onnx::NodeProto &onnx_node, + const std::vector &inputs, + mir::Graph *graph) const +{ + auto result = + createOp(graph, inputs, mir::ops::ElementwiseOp::OpType::add); + return {result->getOutput(0)}; +} + +} // namespace mir_onnx diff --git a/compiler/mir-onnx-importer/Op/Add.h b/compiler/mir-onnx-importer/Op/Add.h new file mode 100644 index 0000000..67fb125 --- /dev/null +++ b/compiler/mir-onnx-importer/Op/Add.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/compiler/mir-onnx-importer/Op/Add.cpp b/compiler/mir-onnx-importer/Op/Add.cpp
new file mode 100644
index 0000000..ed413d8
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Add.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Add.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ElementwiseOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+AddNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                          const std::vector<mir::Operation::Output *> &inputs,
+                          mir::Graph *graph) const
+{
+  auto result =
+      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::add);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Add.h b/compiler/mir-onnx-importer/Op/Add.h
new file mode 100644
index 0000000..67fb125
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Add.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class AddNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/AveragePool.cpp b/compiler/mir-onnx-importer/Op/AveragePool.cpp
new file mode 100644
index 0000000..9aac33e
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/AveragePool.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "AveragePool.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/PoolOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+AveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                  const std::vector<mir::Operation::Output *> &inputs,
+                                  mir::Graph *graph) const
+{
+  // TODO Set some asserts
+  mir::ops::PoolOp::BorderType border_type;
+  mir::ops::PoolOp::PoolingType pool_type;
+
+  KernelStridesPadding cdata;
+  // Transpose ONNX NCHW to MIR NHWC
+  auto t_input = convertONNXToMIR(graph, inputs[0]);
+
+  border_type = mir::ops::PoolOp::BorderType::ZEROFILLED;
+  pool_type = mir::ops::PoolOp::PoolingType::AVG;
+  getKernelStridesPadding(onnx_node, cdata);
+
+  auto result =
+      createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape,
+                                 cdata.strides_shape, cdata.padding_before, cdata.padding_after,
+                                 border_type);
+  return {convertMIRToONNX(graph, result->getOutput(0))};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/AveragePool.h b/compiler/mir-onnx-importer/Op/AveragePool.h
new file mode 100644
index 0000000..f282281
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/AveragePool.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class AveragePoolNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/BatchNormalization.cpp b/compiler/mir-onnx-importer/Op/BatchNormalization.cpp
new file mode 100644
index 0000000..5fb587b
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/BatchNormalization.cpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BatchNormalization.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ShapeRange.h"
+#include "mir/Tensor.h"
+
+#include "mir/ops/BiasAddOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/ScaleOp.h"
+
+#include <cmath>
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                         const std::vector<mir::Operation::Output *> &inputs,
+                                         mir::Graph *graph) const
+{
+  // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
+  bool found;
+  float value;
+  std::tie(found, value) = getFloatAttribute(onnx_node, "epsilon");
+  float epsilon = found ? value : 1e-05f;
+
+  const auto &scale_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode())->getValue();
+  const auto &bias_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[2]->getNode())->getValue();
+  const auto &mean_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[3]->getNode())->getValue();
+  const auto &var_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[4]->getNode())->getValue();
+
+  // res1 = X - mean
+  mir::Tensor<float> bias_data(mean_tensor);
+  for (auto &idx : mir::ShapeRange(bias_data.getShape()))
+    bias_data.at(idx) *= -1;
+
+  auto data = convertONNXToMIR(graph, inputs[0]);
+  auto mean = createOp<mir::ops::ConstantOp>(graph, mean_tensor)->getOutput(0);
+  auto result = createOp<mir::ops::BiasAddOp>(graph, data, mean);
+
+  // res2 = res1 * scale / (var + epsilon)
+  mir::Tensor<float> multiplier(scale_tensor);
+  mir::Tensor<float> var_accessor(var_tensor);
+  for (auto &idx : mir::ShapeRange(scale_tensor.getShape()))
+    multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
+  auto scale = createOp<mir::ops::ConstantOp>(graph, scale_tensor)->getOutput(0);
+  result = createOp<mir::ops::ScaleOp>(graph, result->getOutput(0), scale);
+
+  // overall_res = res2 + bias
+  auto bias = createOp<mir::ops::ConstantOp>(graph, bias_tensor)->getOutput(0);
+  result = createOp<mir::ops::BiasAddOp>(graph, result->getOutput(0), bias);
+
+  return {convertMIRToONNX(graph, result->getOutput(0))};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/BatchNormalization.h b/compiler/mir-onnx-importer/Op/BatchNormalization.h
new file mode 100644
index 0000000..79eb8e3
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/BatchNormalization.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class BatchNormalizationNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Concat.cpp b/compiler/mir-onnx-importer/Op/Concat.cpp
new file mode 100644
index 0000000..cfeb854
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Concat.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Concat.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ConcatOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ConcatNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                             const std::vector<mir::Operation::Output *> &inputs,
+                             mir::Graph *graph) const
+{
+  bool found;
+  int axis;
+  std::tie(found, axis) = getIntAttribute(onnx_node);
+  if (!found)
+    throw std::runtime_error("Concat must have 'axis' attribute");
+  auto result = createOp<mir::ops::ConcatOp>(graph, inputs, axis);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Concat.h b/compiler/mir-onnx-importer/Op/Concat.h
new file mode 100644
index 0000000..06af0b2
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Concat.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ConcatNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Constant.cpp b/compiler/mir-onnx-importer/Op/Constant.cpp
new file mode 100644
index 0000000..0a1484b
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Constant.cpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/TensorVariant.h"
+#include "mir/ops/ConstantOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ConstantNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                               const std::vector<mir::Operation::Output *> &inputs,
+                               mir::Graph *graph) const
+{
+  assert((onnx_node.attribute_size() == 1) &&
+         (onnx_node.attribute(0).type() == ::onnx::AttributeProto_AttributeType_TENSOR) &&
+         (onnx_node.attribute(0).tensors_size() == 0));
+  assert(onnx_node.attribute(0).name() == "value");
+  auto name = onnx_node.output(0);
+  auto &onnx_tensor = onnx_node.attribute(0).t();
+  auto mir_tensor = createTensor(&onnx_tensor);
+  // TODO Check that dropping input_tensors here is correct
+  // input_tensors.insert(std::make_pair(name, mir_tensor));
+  auto op = graph->create<mir::ops::ConstantOp>(name, mir_tensor);
+  return {op->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Constant.h b/compiler/mir-onnx-importer/Op/Constant.h
new file mode 100644
index 0000000..92aa6e6
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Constant.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ConstantNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Conv.cpp b/compiler/mir-onnx-importer/Op/Conv.cpp
new file mode 100644
index 0000000..879018f
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Conv.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Conv.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/TensorUtil.h"
+
+#include "mir/ops/BiasAddOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/Conv2DOp.h"
+#include "mir/ops/DepthwiseConv2DOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ConvNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                           const std::vector<mir::Operation::Output *> &inputs,
+                           mir::Graph *graph) const
+{
+  assert(inputs.size() >= 2);
+
+  KernelStridesPadding cdata;
+  getKernelStridesPadding(onnx_node, cdata);
+  // FIXME: It can be non-constant value.
+  auto *in_weights = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
+  assert(in_weights && "Weights could be a constant tensor only");
+  const auto &in_weights_tensor = in_weights->getValue();
+  // We should transpose ONNX MC(IO)HW to HWOI
+  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(in_weights_tensor);
+  auto in_group_size = kernel_tensor.getShape().dim(2);
+  auto out_channels = kernel_tensor.getShape().dim(3);
+  bool found;
+  int num_groups;
+  std::tie(found, num_groups) = getIntAttribute(onnx_node, "group");
+  if (!found)
+    num_groups = 1;
+  bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups);
+
+  mir::Operation *result;
+  auto transposed_input = convertONNXToMIR(graph, inputs[0]);
+  if (is_depthwise)
+  {
+    // TODO handle properly kernel with layer multiplier
+    auto transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
+    auto kernel = createOp<mir::ops::ConstantOp>(graph, transposed_tensor)->getOutput(0);
+    result =
+        createOp<mir::ops::DepthwiseConv2DOp>(graph, transposed_input, kernel,
+                                              cdata.strides_shape, cdata.padding_before,
+                                              cdata.padding_after);
+  }
+  else
+  {
+    // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
+    if (num_groups != 1)
+      kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
+    kernel_tensor = mir::transposeTensor<3, 0, 1, 2>(kernel_tensor);
+    auto kernel = createOp<mir::ops::ConstantOp>(graph, kernel_tensor)->getOutput(0);
+    result = createOp<mir::ops::Conv2DOp>(graph, transposed_input, kernel, cdata.strides_shape,
+                                          cdata.padding_before, cdata.padding_after);
+  }
+
+  if (inputs.size() > 2)
+    result = createOp<mir::ops::BiasAddOp>(graph, result->getOutput(0), inputs[2]);
+
+  return {convertMIRToONNX(graph, result->getOutput(0))};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Conv.h b/compiler/mir-onnx-importer/Op/Conv.h
new file mode 100644
index 0000000..5849e65
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Conv.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ConvNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Dropout.cpp b/compiler/mir-onnx-importer/Op/Dropout.cpp
new file mode 100644
index 0000000..e30923e
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Dropout.cpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Dropout.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/DropoutOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+DropoutNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                              const std::vector<mir::Operation::Output *> &inputs,
+                              mir::Graph *graph) const
+{
+  bool found;
+  float value;
+  std::tie(found, value) = getFloatAttribute(onnx_node, "ratio");
+  float ratio = found ? value : 0.5; // default 0.5
+  auto result = createOp<mir::ops::DropoutOp>(graph, inputs[0], ratio);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Dropout.h b/compiler/mir-onnx-importer/Op/Dropout.h
new file mode 100644
index 0000000..f57fa49
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Dropout.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class DropoutNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Gather.cpp b/compiler/mir-onnx-importer/Op/Gather.cpp
new file mode 100644
index 0000000..7e82b6a
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Gather.cpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Gather.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/GatherOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+GatherNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                             const std::vector<mir::Operation::Output *> &inputs,
+                             mir::Graph *graph) const
+{
+  bool found;
+  int value;
+  std::tie(found, value) = getIntAttribute(onnx_node, "axis");
+  int axis = found ? value : 0;
+  auto result = createOp<mir::ops::GatherOp>(graph, inputs[0], inputs[1], axis);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Gather.h b/compiler/mir-onnx-importer/Op/Gather.h
new file mode 100644
index 0000000..64c770a
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Gather.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class GatherNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *>
+  convert(const onnx::NodeProto &onnx_node,
+          const std::vector<mir::Operation::Output *> &inputs,
+          mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Gemm.cpp b/compiler/mir-onnx-importer/Op/Gemm.cpp
new file mode 100644
index 0000000..d1b6adf
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Gemm.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Gemm.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/TensorVariant.h"
+
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/GemmOp.h"
+#include "mir/ops/ReshapeOp.h"
+#include "mir/ops/ScaleOp.h"
+#include "mir/ops/TransposeOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+GemmNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                           const std::vector<mir::Operation::Output *> &inputs,
+                           mir::Graph *graph) const
+{
+  bool found;
+  int ivalue;
+  float fvalue;
+
+  // Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
+  // input tensor B has shape (K, N) or (N, K),
+  // input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+  // A will be transposed before doing the computation if attribute transA is non-zero,
+  // same for B and transB. This operator supports unidirectional broadcasting
+  // (tensor C should be unidirectional broadcastable to tensor A * B).
+
+  std::tie(found, ivalue) = getIntAttribute(onnx_node, "transA");
+  bool trans_a = found ? static_cast<bool>(ivalue) : false;
+  std::tie(found, ivalue) = getIntAttribute(onnx_node, "transB");
+  bool trans_b = found ? static_cast<bool>(ivalue) : false;
+  std::tie(found, fvalue) = getFloatAttribute(onnx_node, "alpha");
+  float alpha_val = found ? fvalue : 1.0f;
+  std::tie(found, fvalue) = getFloatAttribute(onnx_node, "beta");
+  float beta_val = found ? fvalue : 1.0f;
+
+  // 1. Prepare input matrix A
+  // Flatten the shape by dim(0)
+  const auto &in_shape = inputs[0]->getShape();
+  mir::Shape shape0{in_shape.dim(0), in_shape.numElements() / in_shape.dim(0)};
+  auto input_a = createOp<mir::ops::ReshapeOp>(graph, inputs[0], shape0)->getOutput(0);
+  if (trans_a)
+    input_a = createOp<mir::ops::TransposeOp>(graph, input_a, std::vector<std::size_t>{1, 0})
+                  ->getOutput(0);
+  if (alpha_val != 1.0)
+  {
+    auto alpha_tensor = createScalarTensor(alpha_val, input_a->getShape());
+    auto alpha = createOp<mir::ops::ConstantOp>(graph, alpha_tensor)->getOutput(0);
+    input_a = createOp<mir::ops::ScaleOp>(graph, input_a, alpha)->getOutput(0);
+  }
+
+  // 2. Prepare input matrix B
+  //
+  auto input_b = inputs[1];
+  if (trans_b)
+    input_b = createOp<mir::ops::TransposeOp>(graph, input_b, std::vector<std::size_t>{1, 0})
+                  ->getOutput(0);
+  // Number of cols in tensor A must be equal to number of rows in tensor B
+  assert(input_a->getShape().dim(1) == input_b->getShape().dim(0));
+  mir::Shape mult_a_b{input_a->getShape().dim(0), input_b->getShape().dim(1)};
+
+  // 3. Prepare input matrix C
+  //
+  auto input_c = inputs[2];
+  auto beta_tensor = createScalarTensor(beta_val, input_c->getShape());
+  if ((mult_a_b.rank() == 2) && (input_c->getShape().rank() == 1))
+  {
+    beta_tensor = mir::TensorVariant(beta_tensor, mult_a_b);
+  }
+  auto beta = createOp<mir::ops::ConstantOp>(graph, beta_tensor)->getOutput(0);
+  std::vector<mir::Operation::Output *> mul_inputs = {beta, input_c};
+  auto c_mult =
+      createOp<mir::ops::ElementwiseOp>(graph, mul_inputs, mir::ops::ElementwiseOp::OpType::mul)
+          ->getOutput(0);
+  assert(c_mult->getShape() == mult_a_b);
+  auto result = createOp<mir::ops::GemmOp>(graph, input_a, input_b, c_mult);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
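
The Gemm converter's shape bookkeeping can be checked with a small worked example; the dimension values below are illustrative, not taken from the patch:

    #include <cassert>

    // Worked example of the Gemm shape handling above:
    // A is (M, K), B is (K, N), C is rank-1 (N,) broadcast to (M, N).
    int main()
    {
      const int M = 2, K = 3, N = 4;
      const bool trans_a = false, trans_b = false;
      // After the optional transposes, cols(A) must equal rows(B),
      // mirroring the assert on input_a/input_b dims in the converter.
      const int a_rows = trans_a ? K : M, a_cols = trans_a ? M : K;
      const int b_rows = trans_b ? N : K, b_cols = trans_b ? K : N;
      assert(a_cols == b_rows);
      // mult_a_b in the converter; beta_tensor is expanded to this shape
      // so the rank-1 C broadcasts to (M, N).
      const int y_rows = a_rows, y_cols = b_cols;
      assert(y_rows == M && y_cols == N); // Y = alpha * A' * B' + beta * C is (M, N)
      return 0;
    }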
diff --git a/compiler/mir-onnx-importer/Op/Gemm.h b/compiler/mir-onnx-importer/Op/Gemm.h
new file mode 100644
index 0000000..461ebfd
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Gemm.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class GemmNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/GivenTensorFill.cpp b/compiler/mir-onnx-importer/Op/GivenTensorFill.cpp
new file mode 100644
index 0000000..c3febc1
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/GivenTensorFill.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GivenTensorFill.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/TensorVariant.h"
+
+#include "mir/ops/ConstantOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+GivenTensorFillNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                      const std::vector<mir::Operation::Output *> &inputs,
+                                      mir::Graph *graph) const
+{
+  auto values_att = findAttribute(onnx_node, "values");
+  auto shape_att = findAttribute(onnx_node, "shape");
+  assert(values_att && shape_att);
+  assert(values_att->floats_size() > 0 && shape_att->ints_size() > 0);
+  mir::Shape shape(shape_att->ints_size());
+  for (int i = 0; i < shape_att->ints_size(); i++)
+    shape.dim(i) = shape_att->ints(i);
+  mir::TensorVariant tensor(mir::DTYPE::FLOAT32, shape, values_att->floats().data());
+  // TODO Check whether it is correct to remove input_tensors
+  // input_tensors.insert(std::make_pair(onnx_node.output(0), tensor));
+  auto result = createOp<mir::ops::ConstantOp>(graph, tensor);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/GivenTensorFill.h b/compiler/mir-onnx-importer/Op/GivenTensorFill.h
new file mode 100644
index 0000000..806a4b6
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/GivenTensorFill.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class GivenTensorFillNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/GlobalAveragePool.cpp b/compiler/mir-onnx-importer/Op/GlobalAveragePool.cpp
new file mode 100644
index 0000000..d4f9736
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/GlobalAveragePool.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GlobalAveragePool.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/PoolOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+GlobalAveragePoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                        const std::vector<mir::Operation::Output *> &inputs,
+                                        mir::Graph *graph) const
+{
+  mir::ops::PoolOp::BorderType border_type = mir::ops::PoolOp::BorderType::ZEROFILLED;
+  mir::ops::PoolOp::PoolingType pool_type = mir::ops::PoolOp::PoolingType::AVG;
+
+  KernelStridesPadding cdata;
+  // Transpose ONNX NCHW to MIR NHWC
+  auto t_input = convertONNXToMIR(graph, inputs[0]);
+
+  // GlobalAveragePool is equivalent to AveragePool with kernel size equal
+  // to the spatial dimensions of the input tensor
+  cdata.kernel_shape = {t_input->getShape().dim(1), t_input->getShape().dim(2)};
+  cdata.strides_shape = {1, 1};
+
+  auto result =
+      createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape,
+                                 cdata.strides_shape, cdata.padding_before, cdata.padding_after,
+                                 border_type);
+  return {convertMIRToONNX(graph, result->getOutput(0))};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/GlobalAveragePool.h b/compiler/mir-onnx-importer/Op/GlobalAveragePool.h
new file mode 100644
index 0000000..48cfa8e
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/GlobalAveragePool.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class GlobalAveragePoolNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Max.cpp b/compiler/mir-onnx-importer/Op/Max.cpp
new file mode 100644
index 0000000..b32ca43
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Max.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Max.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ElementwiseOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+MaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                          const std::vector<mir::Operation::Output *> &inputs,
+                          mir::Graph *graph) const
+{
+  auto result =
+      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::max);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Max.h b/compiler/mir-onnx-importer/Op/Max.h
new file mode 100644
index 0000000..80797b6
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Max.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class MaxNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/MaxPool.cpp b/compiler/mir-onnx-importer/Op/MaxPool.cpp
new file mode 100644
index 0000000..4d415e8
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/MaxPool.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MaxPool.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/PoolOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+MaxPoolNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                              const std::vector<mir::Operation::Output *> &inputs,
+                              mir::Graph *graph) const
+{
+  // TODO Set some asserts
+  mir::ops::PoolOp::BorderType border_type;
+  mir::ops::PoolOp::PoolingType pool_type;
+
+  KernelStridesPadding cdata;
+  // Transpose ONNX NCHW to MIR NHWC
+  auto t_input = convertONNXToMIR(graph, inputs[0]);
+
+  border_type = mir::ops::PoolOp::BorderType::EMPTY;
+  pool_type = mir::ops::PoolOp::PoolingType::MAX;
+  getKernelStridesPadding(onnx_node, cdata);
+
+  auto result =
+      createOp<mir::ops::PoolOp>(graph, t_input, pool_type, cdata.kernel_shape,
+                                 cdata.strides_shape, cdata.padding_before, cdata.padding_after,
+                                 border_type);
+  return {convertMIRToONNX(graph, result->getOutput(0))};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/MaxPool.h b/compiler/mir-onnx-importer/Op/MaxPool.h
new file mode 100644
index 0000000..cf7058b
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/MaxPool.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class MaxPoolNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Mul.cpp b/compiler/mir-onnx-importer/Op/Mul.cpp
new file mode 100644
index 0000000..e095233
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Mul.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Mul.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ElementwiseOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+MulNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                          const std::vector<mir::Operation::Output *> &inputs,
+                          mir::Graph *graph) const
+{
+  auto result =
+      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::mul);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Mul.h b/compiler/mir-onnx-importer/Op/Mul.h
new file mode 100644
index 0000000..a25cf23
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Mul.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class MulNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Pad.cpp b/compiler/mir-onnx-importer/Op/Pad.cpp
new file mode 100644
index 0000000..c3d3a68
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Pad.cpp
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Pad.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/PadOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+PadNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                          const std::vector<mir::Operation::Output *> &inputs,
+                          mir::Graph *graph) const
+{
+  bool found;
+  float value;
+  std::tie(found, value) = getFloatAttribute(onnx_node, "value");
+  assert(found);
+  auto padsAtt = findAttribute(onnx_node, "pads");
+  assert(padsAtt);
+  auto modeAtt = findAttribute(onnx_node, "mode");
+  assert(modeAtt);
+  auto mode = modeAtt->s();
+  const mir::Scalar scalar(reinterpret_cast<const char *>(&value), mir::DTYPE::FLOAT32,
+                           sizeof(float));
+  assert(padsAtt->ints_size() > 0);
+  int axis_size = padsAtt->ints_size() / 2;
+  std::vector<std::pair<int32_t, int32_t>> vec(axis_size);
+  auto *data = padsAtt->ints().data();
+  for (int i = 0; i < axis_size; i++)
+  {
+    auto pair = std::make_pair(data[i], data[axis_size + i]);
+    vec[i] = pair;
+  }
+  auto result =
+      createOp<mir::ops::PadOp>(graph, inputs[0], inputs[0]->getShape().rank(), vec, scalar);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
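
(Aside: the ONNX 'pads' attribute handled above packs all per-axis paddings-before first, then all paddings-after, so pair i is (pads[i], pads[axis_size + i]). A minimal standalone sketch of that unpacking, not part of the patch:)

#include <cstdio>
#include <utility>
#include <vector>

int main()
{
  // rank-2 example: {before_axis0, before_axis1, after_axis0, after_axis1}
  const std::vector<long long> pads = {1, 2, 3, 4};
  const int axis_size = static_cast<int>(pads.size()) / 2;
  std::vector<std::pair<long long, long long>> vec(axis_size);
  for (int i = 0; i < axis_size; ++i)
    vec[i] = std::make_pair(pads[i], pads[axis_size + i]);
  for (int i = 0; i < axis_size; ++i)
    std::printf("axis %d: before=%lld after=%lld\n", i, vec[i].first, vec[i].second);
  return 0;
}
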
diff --git a/compiler/mir-onnx-importer/Op/Pad.h b/compiler/mir-onnx-importer/Op/Pad.h
new file mode 100644
index 0000000..a2801af
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Pad.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class PadNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Relu.cpp b/compiler/mir-onnx-importer/Op/Relu.cpp
new file mode 100644
index 0000000..500e449
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Relu.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ReluOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ReluNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                           const std::vector<mir::Operation::Output *> &inputs,
+                           mir::Graph *graph) const
+{
+  assert(inputs.size() == 1);
+  auto result = createOp<mir::ops::ReluOp>(graph, inputs[0]);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Relu.h b/compiler/mir-onnx-importer/Op/Relu.h
new file mode 100644
index 0000000..4b2ee9e
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Relu.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ReluNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Reshape.cpp b/compiler/mir-onnx-importer/Op/Reshape.cpp
new file mode 100644
index 0000000..f764b3e
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Reshape.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reshape.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/Tensor.h"
+
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/ReshapeOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ReshapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                              const std::vector<mir::Operation::Output *> &inputs,
+                              mir::Graph *graph) const
+{
+  // The original shape
+  const auto &in_shape = inputs[0]->getShape();
+
+  // Input tensor describing the new shape
+  // TODO: can it be something other than a constant?
+  auto *op = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
+  assert(op && "We support constants only");
+  auto shape_tensor = op->getValue();
+  mir::Shape shape_tensor_shape = (shape_tensor).getShape();
+  assert(shape_tensor_shape.rank() == 1);
+  // The rank of the new shape
+  auto cnt = shape_tensor_shape.numElements();
+  // The vector to build the new shape from
+  std::vector<int32_t> shape_vector(cnt);
+  mir::ShapeRange out_range(shape_tensor_shape);
+  mir::Tensor<float> tensor_accessor(shape_tensor);
+
+  int i = 0;
+  for (auto idx : out_range)
+  {
+    if (tensor_accessor.at(idx) == 0)
+      shape_vector[i] = in_shape.dim(i);
+    else if (tensor_accessor.at(idx) == -1)
+      shape_vector[i] = mir::Shape::autoDim;
+    else
+      shape_vector[i] = tensor_accessor.at(idx);
+    i++;
+  }
+  auto out_shape = mir::Shape(shape_vector);
+  auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Reshape.h b/compiler/mir-onnx-importer/Op/Reshape.h
new file mode 100644
index 0000000..c8558d8
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Reshape.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ReshapeNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
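
(Aside: in the Reshape converter above, a target dimension of 0 copies the input dimension at the same index, and -1 marks the single dimension to be inferred (mir::Shape::autoDim). A standalone sketch of those conventions, not part of the patch:)

#include <cstdio>
#include <vector>

int main()
{
  const std::vector<int> in_shape = {2, 3, 4};
  const std::vector<int> new_shape = {0, -1}; // 0: keep dim 0; -1: infer
  long total = 1;
  for (int d : in_shape)
    total *= d; // 24 elements
  std::vector<int> out;
  for (std::size_t i = 0; i < new_shape.size(); ++i)
    out.push_back(new_shape[i] == 0 ? in_shape[i] : new_shape[i]);
  long known = 1;
  for (int d : out)
    if (d != -1)
      known *= d;
  for (int &d : out)
    if (d == -1)
      d = static_cast<int>(total / known); // inferred: 12
  std::printf("out shape: {%d, %d}\n", out[0], out[1]); // {2, 12}
  return 0;
}
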
diff --git a/compiler/mir-onnx-importer/Op/Scale.cpp b/compiler/mir-onnx-importer/Op/Scale.cpp
new file mode 100644
index 0000000..f888a53
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Scale.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Scale.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/ScaleOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ScaleNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                            const std::vector<mir::Operation::Output *> &inputs,
+                            mir::Graph *graph) const
+{
+  bool found;
+  float value;
+  std::tie(found, value) = getFloatAttribute(onnx_node, "scale");
+  float scale_val = found ? value : 1.0f;
+  const auto &shape = inputs[0]->getShape();
+  auto scale_tensor = createScalarTensor(scale_val, shape);
+  auto scale = createOp<mir::ops::ConstantOp>(graph, scale_tensor)->getOutput(0);
+  auto result = createOp<mir::ops::ScaleOp>(graph, inputs[0], scale);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Scale.h b/compiler/mir-onnx-importer/Op/Scale.h
new file mode 100644
index 0000000..55c2e3a
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Scale.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ScaleNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Shape.cpp b/compiler/mir-onnx-importer/Op/Shape.cpp
new file mode 100644
index 0000000..7344d45
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Shape.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Shape.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/TensorVariant.h"
+
+#include "mir/ops/ConstantOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+ShapeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                            const std::vector<mir::Operation::Output *> &inputs,
+                            mir::Graph *graph) const
+{
+  const auto &input_shape = inputs[0]->getShape();
+  int size = input_shape.rank();
+  mir::Shape output_shape{size};
+  std::vector<float> data(static_cast<std::size_t>(size));
+  for (int i = 0; i < size; i++)
+  {
+    data[i] = input_shape.dim(i);
+  }
+  mir::TensorVariant tensor(mir::DTYPE::FLOAT32, output_shape, data.data());
+  auto result = createOp<mir::ops::ConstantOp>(graph, tensor);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Shape.h b/compiler/mir-onnx-importer/Op/Shape.h
new file mode 100644
index 0000000..52ab97f
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Shape.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ShapeNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Sigmoid.cpp b/compiler/mir-onnx-importer/Op/Sigmoid.cpp
new file mode 100644
index 0000000..e537b07
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Sigmoid.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Sigmoid.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/SigmoidOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+SigmoidNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                              const std::vector<mir::Operation::Output *> &inputs,
+                              mir::Graph *graph) const
+{
+  assert(inputs.size() == 1);
+  auto result = createOp<mir::ops::SigmoidOp>(graph, inputs[0]);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Sigmoid.h b/compiler/mir-onnx-importer/Op/Sigmoid.h
new file mode 100644
index 0000000..b738c23
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Sigmoid.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class SigmoidNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Softmax.cpp b/compiler/mir-onnx-importer/Op/Softmax.cpp
new file mode 100644
index 0000000..7fc338d
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Softmax.cpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Softmax.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/SoftmaxOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+SoftmaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                              const std::vector<mir::Operation::Output *> &inputs,
+                              mir::Graph *graph) const
+{
+  int axis;
+  bool found;
+  std::tie(found, axis) = getIntAttribute(onnx_node);
+  axis = found ? axis : 1;
+  auto result = createOp<mir::ops::SoftmaxOp>(graph, inputs[0], axis);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Softmax.h b/compiler/mir-onnx-importer/Op/Softmax.h
new file mode 100644
index 0000000..4600ee7
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Softmax.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class SoftmaxNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Sum.cpp b/compiler/mir-onnx-importer/Op/Sum.cpp
new file mode 100644
index 0000000..d2fa94c
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Sum.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Sum.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ElementwiseOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+SumNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                          const std::vector<mir::Operation::Output *> &inputs,
+                          mir::Graph *graph) const
+{
+  auto result =
+      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::add);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Sum.h b/compiler/mir-onnx-importer/Op/Sum.h
new file mode 100644
index 0000000..a9c64b8
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Sum.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class SumNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Unsqueeze.cpp b/compiler/mir-onnx-importer/Op/Unsqueeze.cpp
new file mode 100644
index 0000000..9487b90
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Unsqueeze.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Unsqueeze.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/ops/ReshapeOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+UnsqueezeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                const std::vector<mir::Operation::Output *> &inputs,
+                                mir::Graph *graph) const
+{
+  auto *axes = findAttribute(onnx_node, "axes");
+  assert(axes && axes->ints_size());
+  const mir::Shape &input_shape = inputs[0]->getShape();
+  const int out_rank = input_shape.rank() + axes->ints_size();
+  mir::Shape out_shape(out_rank);
+  auto ints_iterator = axes->ints().begin();
+  int j = 0;
+  for (int i = 0; i < out_rank; i++)
+  {
+    if (ints_iterator < axes->ints().end() && i == *ints_iterator)
+    {
+      out_shape.dim(i) = 1;
+      ints_iterator++;
+    }
+    else
+    {
+      out_shape.dim(i) = input_shape.dim(j);
+      j++;
+    }
+  }
+  auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape);
+  return {result->getOutput(0)};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Unsqueeze.h b/compiler/mir-onnx-importer/Op/Unsqueeze.h
new file mode 100644
index 0000000..c9eae0b
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Unsqueeze.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class UnsqueezeNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
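
(Aside: the Unsqueeze converter above merges the sorted 'axes' list with the input shape to build the output shape; e.g. input shape {3, 4} with axes {0, 2} yields {1, 3, 1, 4}. A standalone sketch of that loop, not part of the patch:)

#include <cstdio>
#include <vector>

int main()
{
  const std::vector<int> input_shape = {3, 4};
  const std::vector<int> axes = {0, 2}; // assumed sorted, as the converter expects
  const int out_rank = static_cast<int>(input_shape.size() + axes.size());
  std::vector<int> out_shape(out_rank);
  std::size_t a = 0, j = 0;
  for (int i = 0; i < out_rank; ++i)
  {
    if (a < axes.size() && i == axes[a])
    {
      out_shape[i] = 1; // new unit dimension
      ++a;
    }
    else
      out_shape[i] = input_shape[j++];
  }
  for (int d : out_shape)
    std::printf("%d ", d); // prints: 1 3 1 4
  std::printf("\n");
  return 0;
}
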
diff --git a/compiler/mir-onnx-importer/Op/Upsample.cpp b/compiler/mir-onnx-importer/Op/Upsample.cpp
new file mode 100644
index 0000000..4353c3a
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Upsample.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Upsample.h"
+
+#include "ONNXHelpers.h"
+
+#include "mir/Tensor.h"
+
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/ResizeOp.h"
+
+namespace mir_onnx
+{
+
+std::vector<mir::Operation::Output *>
+UpsampleNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                               const std::vector<mir::Operation::Output *> &inputs,
+                               mir::Graph *graph) const
+{
+  bool success;
+  std::string mode;
+  std::tie(success, mode) = getStringAttribute(onnx_node, "mode");
+  if (!success)
+    mode = "nearest";
+  assert(mode == "nearest" && "Unsupported upscale mode!");
+
+  // relies on attributes being lifted to constants (ONNX optimization pass)
+  assert(inputs.size() > 1);
+  auto *scales = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
+  assert(scales && "Scales must be a constant tensor");
+  auto scales_tensor = mir::Tensor<float>(scales->getValue());
+  int rank = inputs[0]->getShape().rank();
+  assert(scales_tensor.getShape().numElements() == rank &&
+         "The number of elements of 'scales' should be the same as the rank of input 'X'");
+  assert(rank == 4 && "Only rank 4 is supported");
+  std::vector<float> scales_vector(4);
+  // ONNX NCHW scales reordered to MIR NHWC
+  const int onnx2mir[] = {0, 3, 1, 2};
+  assert(scales_tensor.getShape().rank() == 1 && "Scales are a 1d tensor");
+  for (int i = 0; i < scales_tensor.getShape().numElements(); i++)
+    scales_vector[onnx2mir[i]] = scales_tensor.atOffset(i);
+  return {convertMIRToONNX(
+      graph,
+      createOp<mir::ops::ResizeOp>(graph, convertONNXToMIR(graph, inputs[0]),
+                                   mir::ops::ResizeOp::ResizeMethod::nearestNeighbor,
+                                   scales_vector)
+          ->getOutput(0))};
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/Upsample.h b/compiler/mir-onnx-importer/Op/Upsample.h
new file mode 100644
index 0000000..9c2d2a5
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/Upsample.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class UpsampleNodeConverter : public NodeConverter
+{
+public:
+  std::vector<mir::Operation::Output *> convert(const onnx::NodeProto &onnx_node,
+                                                const std::vector<mir::Operation::Output *> &inputs,
+                                                mir::Graph *graph) const override;
+};
+
+} // namespace mir_onnx
-- 
2.7.4