From: Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 Date: Wed, 12 Dec 2018 10:00:58 +0000 (+0300) Subject: [nnc] Introduce Gather operation (#2626) X-Git-Tag: nncc_backup~1100 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ac7a4af64a600647356bef637af6ecf87990bb23;p=platform%2Fcore%2Fml%2Fnnfw.git [nnc] Introduce Gather operation (#2626) Add GatherOp class to modelIR. Signed-off-by: Sergei Barannikov --- diff --git a/contrib/nnc/core/CMakeLists.txt b/contrib/nnc/core/CMakeLists.txt index cc4f450..52aba7f 100644 --- a/contrib/nnc/core/CMakeLists.txt +++ b/contrib/nnc/core/CMakeLists.txt @@ -3,6 +3,7 @@ set(SOURCES "modelIR/operations/ConcatOp.cpp" "modelIR/operations/DeConv2DOp.cpp" "modelIR/operations/DepthwiseConv2DOp.cpp" "modelIR/operations/FullyConnectedOp.cpp" + "modelIR/operations/GatherOp.cpp" "modelIR/operations/PadOp.cpp" "modelIR/operations/PoolOp.cpp" "modelIR/operations/SqueezeOp.cpp" diff --git a/contrib/nnc/core/modelIR/IrDotDumper.cpp b/contrib/nnc/core/modelIR/IrDotDumper.cpp index b43961c..c0ad967 100644 --- a/contrib/nnc/core/modelIR/IrDotDumper.cpp +++ b/contrib/nnc/core/modelIR/IrDotDumper.cpp @@ -17,6 +17,31 @@ #include #include "core/modelIR/IrDotDumper.h" +#include "core/modelIR/operations/BatchNormOp.h" +#include "core/modelIR/operations/BiasAddOp.h" +#include "core/modelIR/operations/CappedReluOp.h" +#include "core/modelIR/operations/ConcatOp.h" +#include "core/modelIR/operations/ConstantOp.h" +#include "core/modelIR/operations/Conv2DOp.h" +#include "core/modelIR/operations/Deconv2DOp.h" +#include "core/modelIR/operations/DepthwiseConv2DOp.h" +#include "core/modelIR/operations/DropoutOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/EluOp.h" +#include "core/modelIR/operations/FullyConnectedOp.h" +#include "core/modelIR/operations/GatherOp.h" +#include "core/modelIR/operations/PadOp.h" +#include "core/modelIR/operations/PoolOp.h" +#include "core/modelIR/operations/ReduceFOp.h" +#include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" +#include "core/modelIR/operations/ScaleOp.h" +#include "core/modelIR/operations/SoftmaxOp.h" +#include "core/modelIR/operations/SqueezeOp.h" +#include "core/modelIR/operations/TanhOp.h" +#include "core/modelIR/operations/TransposeOp.h" +#include "core/modelIR/operations/VariableOp.h" namespace nnc { namespace mir { @@ -264,6 +289,10 @@ void IrDotDumper::visit(ops::TransposeOp& op) { dotBuilder.updateWithOp(&op, node_info); } +void IrDotDumper::visit(ops::GatherOp& op) { + auto node_info = DotIrNodeInfo().withType("GatherOp", op.getName()); +} + } // namespace mir } // namespace nnc diff --git a/contrib/nnc/core/modelIR/Operation.cpp b/contrib/nnc/core/modelIR/Operation.cpp index 4952d32..2005ea6 100644 --- a/contrib/nnc/core/modelIR/Operation.cpp +++ b/contrib/nnc/core/modelIR/Operation.cpp @@ -15,30 +15,31 @@ */ #include "core/modelIR/Operation.h" -#include "core/modelIR/operations/FullyConnectedOp.h" -#include "core/modelIR/operations/SoftmaxOp.h" +#include "core/modelIR/operations/BatchNormOp.h" +#include "core/modelIR/operations/BiasAddOp.h" #include "core/modelIR/operations/CappedReluOp.h" -#include "core/modelIR/operations/DepthwiseConv2DOp.h" +#include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/ConstantOp.h" #include "core/modelIR/operations/Conv2DOp.h" #include "core/modelIR/operations/Deconv2DOp.h" +#include "core/modelIR/operations/DepthwiseConv2DOp.h" +#include "core/modelIR/operations/DropoutOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/EluOp.h" +#include "core/modelIR/operations/FullyConnectedOp.h" +#include "core/modelIR/operations/GatherOp.h" +#include "core/modelIR/operations/PadOp.h" #include "core/modelIR/operations/PoolOp.h" -#include "core/modelIR/operations/VariableOp.h" +#include "core/modelIR/operations/ReduceFOp.h" #include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ReshapeOp.h" #include "core/modelIR/operations/ResizeOp.h" -#include "core/modelIR/operations/EluOp.h" -#include "core/modelIR/operations/ConcatOp.h" -#include "core/modelIR/operations/BiasAddOp.h" -#include "core/modelIR/operations/BatchNormOp.h" #include "core/modelIR/operations/ScaleOp.h" -#include "core/modelIR/operations/DropoutOp.h" -#include "core/modelIR/operations/TanhOp.h" -#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/SoftmaxOp.h" #include "core/modelIR/operations/SqueezeOp.h" -#include "core/modelIR/operations/ReshapeOp.h" -#include "core/modelIR/operations/PadOp.h" -#include "core/modelIR/operations/ReduceFOp.h" +#include "core/modelIR/operations/TanhOp.h" #include "core/modelIR/operations/TransposeOp.h" +#include "core/modelIR/operations/VariableOp.h" namespace nnc { namespace mir { diff --git a/contrib/nnc/core/modelIR/operations/GatherOp.cpp b/contrib/nnc/core/modelIR/operations/GatherOp.cpp new file mode 100644 index 0000000..e0d1bae --- /dev/null +++ b/contrib/nnc/core/modelIR/operations/GatherOp.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "core/modelIR/operations/GatherOp.h" + +namespace nnc { +namespace mir { +namespace ops { + +void GatherOp::inferOutputShapes() { + const auto& data_shape = getInputShape(0); + const auto& indices_shape = getInputShape(1); + + auto data_rank = data_shape.rank(); + auto indices_rank = indices_shape.rank(); + auto output_rank = data_rank + indices_rank - 1; + + assert(_axis >= -data_rank && _axis < data_rank); + int32_t axis = _axis < 0 ? _axis + data_rank : _axis; + + Shape output_shape; + output_shape.resize(output_rank); + + // Output shape is data.shape[:axis] + indices.shape + data.shape[axis + 1:]. + int32_t output_index = 0; + for (int32_t i = 0; i < axis; ++i) + output_shape.dim(output_index++) = data_shape.dim(i); + for (int32_t i = 0; i < indices_rank; ++i) + output_shape.dim(output_index++) = indices_shape.dim(i); + for (int32_t i = axis + 1; i < data_rank; ++i) + output_shape.dim(output_index++) = data_shape.dim(i); + + setOutputShape(0, output_shape); +} + +} // namespace ops +} // namespace mir +} // namespace nnc diff --git a/contrib/nnc/include/core/modelIR/IrDotDumper.h b/contrib/nnc/include/core/modelIR/IrDotDumper.h index cc54ab3..f5a907a 100644 --- a/contrib/nnc/include/core/modelIR/IrDotDumper.h +++ b/contrib/nnc/include/core/modelIR/IrDotDumper.h @@ -18,30 +18,6 @@ #define _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_ #include "core/modelIR/Visitor.h" -#include "core/modelIR/operations/FullyConnectedOp.h" -#include "core/modelIR/operations/SoftmaxOp.h" -#include "core/modelIR/operations/CappedReluOp.h" -#include "core/modelIR/operations/ConstantOp.h" -#include "core/modelIR/operations/DepthwiseConv2DOp.h" -#include "core/modelIR/operations/Conv2DOp.h" -#include "core/modelIR/operations/Deconv2DOp.h" -#include "core/modelIR/operations/PoolOp.h" -#include "core/modelIR/operations/VariableOp.h" -#include "core/modelIR/operations/ReluOp.h" -#include "core/modelIR/operations/EluOp.h" -#include "core/modelIR/operations/ConcatOp.h" -#include "core/modelIR/operations/BiasAddOp.h" -#include "core/modelIR/operations/ReshapeOp.h" -#include "core/modelIR/operations/ResizeOp.h" -#include "core/modelIR/operations/BatchNormOp.h" -#include "core/modelIR/operations/ScaleOp.h" -#include "core/modelIR/operations/DropoutOp.h" -#include "core/modelIR/operations/TanhOp.h" -#include "core/modelIR/operations/ElementwiseOp.h" -#include "core/modelIR/operations/SqueezeOp.h" -#include "core/modelIR/operations/PadOp.h" -#include "core/modelIR/operations/ReduceFOp.h" -#include "core/modelIR/operations/TransposeOp.h" #include "core/modelIR/ir_dot_builder.h" @@ -56,30 +32,32 @@ namespace mir */ class IrDotDumper : public IVisitor { public: + void visit(ops::BatchNormOp& op) override; + void visit(ops::BiasAddOp& op) override; + void visit(ops::CappedReluOp& op) override; void visit(ops::ConcatOp& op) override; void visit(ops::ConstantOp& op) override; - void visit(ops::ReluOp& op) override; void visit(ops::Conv2DOp& op) override; + void visit(ops::DeConv2DOp& op) override; void visit(ops::DepthwiseConv2DOp& op) override; - void visit(ops::SoftmaxOp& op) override; - void visit(ops::PoolOp& op) override; + void visit(ops::DropoutOp& op) override; + void visit(ops::ElementwiseOp& op) override; + void visit(ops::EluOp& op) override; void visit(ops::FullyConnectedOp& op) override; - void visit(ops::CappedReluOp& op) override; - void visit(ops::BiasAddOp& op) override; - void visit(ops::VariableOp& op) override; + void visit(ops::GatherOp& op) override; + void visit(ops::PadOp& op) override; + void visit(ops::PoolOp& op) override; + void visit(ops::ReduceFOp& op) override; + void visit(ops::ReluOp& op) override; void visit(ops::ReshapeOp& op) override; void visit(ops::ResizeOp& op) override; void visit(ops::ScaleOp& op) override; - void visit(ops::BatchNormOp& op) override; - void visit(ops::DropoutOp& op) override; - void visit(ops::DeConv2DOp& op) override; - void visit(ops::EluOp& op) override; - void visit(ops::TanhOp& op) override; - void visit(ops::ElementwiseOp& op) override; + void visit(ops::SoftmaxOp& op) override; void visit(ops::SqueezeOp& op) override; - void visit(ops::PadOp& op) override; - void visit(ops::ReduceFOp& op) override; + void visit(ops::TanhOp& op) override; void visit(ops::TransposeOp& op) override; + void visit(ops::VariableOp& op) override; + void writeDot(std::ostream &os) { dotBuilder.writeDot(os); }; diff --git a/contrib/nnc/include/core/modelIR/operations/GatherOp.h b/contrib/nnc/include/core/modelIR/operations/GatherOp.h new file mode 100644 index 0000000..3c9bbc4 --- /dev/null +++ b/contrib/nnc/include/core/modelIR/operations/GatherOp.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _NNC_CORE_IR_MODEL_GATHER_H_ +#define _NNC_CORE_IR_MODEL_GATHER_H_ + +#include "core/modelIR/Operation.h" + +namespace nnc { +namespace mir { +namespace ops { + +/** + * @brief Gather operation as defined by ONNX spec. + * https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gather + * https://www.tensorflow.org/api_docs/python/tf/gather + */ +class GatherOp : public Operation { +public: + GatherOp(const IODescriptor& data, const IODescriptor& indices, int32_t axis) + : Operation(Type::gather, {data, indices}), _axis(axis) { + inferOutputShapes(); + } + +private: + void inferOutputShapes(); + + int32_t _axis; +}; + +} // namespace ops +} // namespace mir +} // namespace nnc + +#endif //_NNC_CORE_IR_MODEL_GATHER_H_ diff --git a/contrib/nnc/include/core/modelIR/operations/operations.lst.h b/contrib/nnc/include/core/modelIR/operations/operations.lst.h index 040070b..c4ae1d2 100644 --- a/contrib/nnc/include/core/modelIR/operations/operations.lst.h +++ b/contrib/nnc/include/core/modelIR/operations/operations.lst.h @@ -21,6 +21,7 @@ HANDLE_OP(concat, ConcatOp) HANDLE_OP(conv2D, Conv2DOp) HANDLE_OP(depthwiseConv, DepthwiseConv2DOp) +HANDLE_OP(gather, GatherOp) HANDLE_OP(softmax, SoftmaxOp) HANDLE_OP(pool, PoolOp) HANDLE_OP(fullyConnected, FullyConnectedOp) diff --git a/contrib/nnc/include/passes/interpreter/Interpreter.h b/contrib/nnc/include/passes/interpreter/Interpreter.h index ce4d938..f1ddba3 100644 --- a/contrib/nnc/include/passes/interpreter/Interpreter.h +++ b/contrib/nnc/include/passes/interpreter/Interpreter.h @@ -36,30 +36,31 @@ class NNInterpreter : public IVisitor { public: explicit NNInterpreter() = default; + void visit(ops::BatchNormOp& op) override; + void visit(ops::BiasAddOp& op) override; + void visit(ops::CappedReluOp& op) override; void visit(ops::ConcatOp& op) override; void visit(ops::ConstantOp& op) override; void visit(ops::Conv2DOp& op) override; + void visit(ops::DeConv2DOp& op) override; void visit(ops::DepthwiseConv2DOp& op) override; - void visit(ops::ReluOp& op) override; - void visit(ops::SoftmaxOp& op) override; - void visit(ops::PoolOp& op) override; - void visit(ops::FullyConnectedOp& op) override; - void visit(ops::CappedReluOp& op) override; - void visit(ops::BiasAddOp& op) override; - void visit(ops::VariableOp& op) override; - void visit(ops::ReshapeOp& op) override; - void visit(ops::ResizeOp& op) override; - void visit(ops::ScaleOp& op) override; - void visit(ops::BatchNormOp& op) override; void visit(ops::DropoutOp& op) override; - void visit(ops::TanhOp& op) override; void visit(ops::ElementwiseOp& op) override; - void visit(ops::DeConv2DOp& op) override; void visit(ops::EluOp& op) override; - void visit(ops::SqueezeOp& op) override; + void visit(ops::FullyConnectedOp& op) override; + void visit(ops::GatherOp& op) override; void visit(ops::PadOp& op) override; + void visit(ops::PoolOp& op) override; void visit(ops::ReduceFOp& op) override; + void visit(ops::ReluOp& op) override; + void visit(ops::ReshapeOp& op) override; + void visit(ops::ResizeOp& op) override; + void visit(ops::ScaleOp& op) override; + void visit(ops::SoftmaxOp& op) override; + void visit(ops::SqueezeOp& op) override; + void visit(ops::TanhOp& op) override; void visit(ops::TransposeOp& op) override; + void visit(ops::VariableOp& op) override; void setInput(const std::string &name, const TensorVariant& data); std::vector &getResult(Operation* op); diff --git a/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp b/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp index 1b32ab2..5516041 100644 --- a/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp +++ b/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp @@ -6,26 +6,26 @@ #include "core/modelIR/Tensor.h" #include "core/modelIR/Operation.h" -#include "core/modelIR/operations/VariableOp.h" -#include "core/modelIR/operations/SoftmaxOp.h" -#include "core/modelIR/operations/Conv2DOp.h" -#include "core/modelIR/operations/ConstantOp.h" -#include "core/modelIR/operations/ScaleOp.h" #include "core/modelIR/operations/BatchNormOp.h" -#include "core/modelIR/operations/DropoutOp.h" -#include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/BiasAddOp.h" #include "core/modelIR/operations/CappedReluOp.h" -#include "core/modelIR/operations/TanhOp.h" -#include "core/modelIR/operations/ReshapeOp.h" -#include "core/modelIR/operations/ResizeOp.h" +#include "core/modelIR/operations/ConcatOp.h" +#include "core/modelIR/operations/ConstantOp.h" +#include "core/modelIR/operations/Conv2DOp.h" +#include "core/modelIR/operations/Deconv2DOp.h" #include "core/modelIR/operations/DepthwiseConv2DOp.h" +#include "core/modelIR/operations/DropoutOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" #include "core/modelIR/operations/FullyConnectedOp.h" -#include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/PoolOp.h" -#include "core/modelIR/operations/BiasAddOp.h" -#include "core/modelIR/operations/ElementwiseOp.h" -#include "core/modelIR/operations/Deconv2DOp.h" #include "core/modelIR/operations/ReduceFOp.h" +#include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" +#include "core/modelIR/operations/ScaleOp.h" +#include "core/modelIR/operations/SoftmaxOp.h" +#include "core/modelIR/operations/TanhOp.h" +#include "core/modelIR/operations/VariableOp.h" #include @@ -883,5 +883,9 @@ void AclCppOpGenerator::visit(mir::ops::TransposeOp& op) { assert(false && "Unimplemented operation: TransposeOp"); } +void AclCppOpGenerator::visit(mir::ops::GatherOp& op) { + assert(false && "Unimplemented operation: GatherOp"); +} + } // namespace nnc diff --git a/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h b/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h index 077049b..b6108fc 100644 --- a/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h +++ b/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h @@ -47,30 +47,31 @@ public: * @brief Implementations of the IVisitor visitors. * @param op */ + void visit(mir::ops::BatchNormOp& op) override; + void visit(mir::ops::BiasAddOp& op) override; + void visit(mir::ops::CappedReluOp& op) override; void visit(mir::ops::ConcatOp& op) override; void visit(mir::ops::ConstantOp& op) override; void visit(mir::ops::Conv2DOp& op) override; + void visit(mir::ops::DeConv2DOp& op) override; void visit(mir::ops::DepthwiseConv2DOp& op) override; - void visit(mir::ops::SoftmaxOp& op) override; - void visit(mir::ops::PoolOp& op) override; + void visit(mir::ops::DropoutOp& op) override; + void visit(mir::ops::ElementwiseOp& op) override; + void visit(mir::ops::EluOp& op) override; void visit(mir::ops::FullyConnectedOp& op) override; - void visit(mir::ops::CappedReluOp& op) override; - void visit(mir::ops::BiasAddOp& op) override; - void visit(mir::ops::VariableOp& op) override; + void visit(mir::ops::GatherOp& op) override; + void visit(mir::ops::PadOp& op) override; + void visit(mir::ops::PoolOp& op) override; + void visit(mir::ops::ReduceFOp& op) override; void visit(mir::ops::ReluOp& op) override; void visit(mir::ops::ReshapeOp& op) override; void visit(mir::ops::ResizeOp& op) override; void visit(mir::ops::ScaleOp& op) override; - void visit(mir::ops::BatchNormOp& op) override; - void visit(mir::ops::DropoutOp& op) override; - void visit(mir::ops::TanhOp& op) override; - void visit(mir::ops::ElementwiseOp& op) override; - void visit(mir::ops::DeConv2DOp& op) override; - void visit(mir::ops::EluOp& op) override; + void visit(mir::ops::SoftmaxOp& op) override; void visit(mir::ops::SqueezeOp& op) override; - void visit(mir::ops::PadOp& op) override; - void visit(mir::ops::ReduceFOp& op) override; + void visit(mir::ops::TanhOp& op) override; void visit(mir::ops::TransposeOp& op) override; + void visit(mir::ops::VariableOp& op) override; private: using AF = ArtifactFactory; diff --git a/contrib/nnc/passes/interpreter/Interpreter.cpp b/contrib/nnc/passes/interpreter/Interpreter.cpp index a908b39..76cd84a 100644 --- a/contrib/nnc/passes/interpreter/Interpreter.cpp +++ b/contrib/nnc/passes/interpreter/Interpreter.cpp @@ -347,4 +347,8 @@ void NNInterpreter::visit(ops::TransposeOp& op) { var(op.getId()) = Transpose(input, op)(); } +void NNInterpreter::visit(ops::GatherOp& op) { + assert(false && "Not yet imlemented"); +} + } // namespace nnc diff --git a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp index 8cbbcb5..3fe75d3 100644 --- a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp +++ b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp @@ -24,29 +24,30 @@ #include "core/modelIR/ShapeRange.h" #include "core/modelIR/Graph.h" +#include "core/modelIR/operations/BatchNormOp.h" +#include "core/modelIR/operations/BiasAddOp.h" +#include "core/modelIR/operations/CappedReluOp.h" #include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/ConstantOp.h" #include "core/modelIR/operations/Conv2DOp.h" #include "core/modelIR/operations/Deconv2DOp.h" #include "core/modelIR/operations/DepthwiseConv2DOp.h" -#include "core/modelIR/operations/SoftmaxOp.h" -#include "core/modelIR/operations/PoolOp.h" +#include "core/modelIR/operations/DropoutOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/EluOp.h" #include "core/modelIR/operations/FullyConnectedOp.h" -#include "core/modelIR/operations/CappedReluOp.h" -#include "core/modelIR/operations/BiasAddOp.h" +#include "core/modelIR/operations/GatherOp.h" +#include "core/modelIR/operations/PadOp.h" +#include "core/modelIR/operations/PoolOp.h" +#include "core/modelIR/operations/ReduceFOp.h" #include "core/modelIR/operations/ReluOp.h" -#include "core/modelIR/operations/EluOp.h" #include "core/modelIR/operations/ReshapeOp.h" -#include "core/modelIR/operations/BatchNormOp.h" #include "core/modelIR/operations/ScaleOp.h" -#include "core/modelIR/operations/DropoutOp.h" -#include "core/modelIR/operations/TanhOp.h" -#include "core/modelIR/operations/ElementwiseOp.h" -#include "core/modelIR/operations/VariableOp.h" +#include "core/modelIR/operations/SoftmaxOp.h" #include "core/modelIR/operations/SqueezeOp.h" -#include "core/modelIR/operations/PadOp.h" -#include "core/modelIR/operations/ReduceFOp.h" +#include "core/modelIR/operations/TanhOp.h" #include "core/modelIR/operations/TransposeOp.h" +#include "core/modelIR/operations/VariableOp.h" using namespace std; @@ -303,4 +304,8 @@ void ModelAnalyzer::visit(mir::ops::TransposeOp& op) { addOpDescr(&op, "transpose"); } +void ModelAnalyzer::visit(mir::ops::GatherOp& op) { + addOpDescr(&op, "gather"); +} + } // namespace nnc diff --git a/contrib/nnc/passes/soft_backend/ModelAnalyzer.h b/contrib/nnc/passes/soft_backend/ModelAnalyzer.h index 0b16711..a9ce551 100644 --- a/contrib/nnc/passes/soft_backend/ModelAnalyzer.h +++ b/contrib/nnc/passes/soft_backend/ModelAnalyzer.h @@ -89,30 +89,31 @@ public: */ void analyze(const mir::Graph* g); + void visit(mir::ops::BatchNormOp& op) override; + void visit(mir::ops::BiasAddOp& op) override; + void visit(mir::ops::CappedReluOp& op) override; void visit(mir::ops::ConcatOp& op) override; void visit(mir::ops::ConstantOp& op) override; void visit(mir::ops::Conv2DOp& op) override; + void visit(mir::ops::DeConv2DOp& op) override; void visit(mir::ops::DepthwiseConv2DOp& op) override; - void visit(mir::ops::SoftmaxOp& op) override; - void visit(mir::ops::PoolOp& op) override; + void visit(mir::ops::DropoutOp& op) override; + void visit(mir::ops::ElementwiseOp& op) override; + void visit(mir::ops::EluOp& op) override; void visit(mir::ops::FullyConnectedOp& op) override; - void visit(mir::ops::CappedReluOp& op) override; - void visit(mir::ops::BiasAddOp& op) override; - void visit(mir::ops::VariableOp& op) override; + void visit(mir::ops::GatherOp& op) override; + void visit(mir::ops::PadOp& op) override; + void visit(mir::ops::PoolOp& op) override; + void visit(mir::ops::ReduceFOp& op) override; void visit(mir::ops::ReluOp& op) override; void visit(mir::ops::ReshapeOp& op) override; void visit(mir::ops::ResizeOp& op) override; void visit(mir::ops::ScaleOp& op) override; - void visit(mir::ops::BatchNormOp& op) override; - void visit(mir::ops::DropoutOp& op) override; - void visit(mir::ops::TanhOp& op) override; - void visit(mir::ops::ElementwiseOp& op) override; - void visit(mir::ops::DeConv2DOp& op) override; - void visit(mir::ops::EluOp& op) override; + void visit(mir::ops::SoftmaxOp& op) override; void visit(mir::ops::SqueezeOp& op) override; - void visit(mir::ops::PadOp& op) override; - void visit(mir::ops::ReduceFOp& op) override; + void visit(mir::ops::TanhOp& op) override; void visit(mir::ops::TransposeOp& op) override; + void visit(mir::ops::VariableOp& op) override; /** * @return vector of id's of network input tensors diff --git a/contrib/nnc/passes/soft_backend/SBSerializer.cpp b/contrib/nnc/passes/soft_backend/SBSerializer.cpp index fe7b9e4..c2241ff 100644 --- a/contrib/nnc/passes/soft_backend/SBSerializer.cpp +++ b/contrib/nnc/passes/soft_backend/SBSerializer.cpp @@ -363,4 +363,9 @@ void Serializer::visit(mir::ops::TransposeOp& op) { serializeShape(op.getOutputShape(0)); } +void Serializer::visit(mir::ops::GatherOp& op) { + _curOp->_paramStartOffset = _buffer.size(); + assert(false && "Not yet implemented"); +} + } // namespace nnc diff --git a/contrib/nnc/passes/soft_backend/SBSerializer.h b/contrib/nnc/passes/soft_backend/SBSerializer.h index dcac3cb..748db4e 100644 --- a/contrib/nnc/passes/soft_backend/SBSerializer.h +++ b/contrib/nnc/passes/soft_backend/SBSerializer.h @@ -41,30 +41,31 @@ namespace nnc class Serializer: public mir::IVisitor { public: + void visit(mir::ops::BatchNormOp& op) override; + void visit(mir::ops::BiasAddOp& op) override; + void visit(mir::ops::CappedReluOp& op) override; void visit(mir::ops::ConcatOp& op) override; void visit(mir::ops::ConstantOp& op) override; void visit(mir::ops::Conv2DOp& op) override; + void visit(mir::ops::DeConv2DOp& op) override; void visit(mir::ops::DepthwiseConv2DOp& op) override; - void visit(mir::ops::SoftmaxOp& op) override; - void visit(mir::ops::PoolOp& op) override; + void visit(mir::ops::DropoutOp& op) override; + void visit(mir::ops::ElementwiseOp& op) override; + void visit(mir::ops::EluOp& op) override; void visit(mir::ops::FullyConnectedOp& op) override; - void visit(mir::ops::CappedReluOp& op) override; - void visit(mir::ops::BiasAddOp& op) override; - void visit(mir::ops::VariableOp& op) override; + void visit(mir::ops::GatherOp& op) override; + void visit(mir::ops::PadOp& op) override; + void visit(mir::ops::PoolOp& op) override; + void visit(mir::ops::ReduceFOp& op) override; void visit(mir::ops::ReluOp& op) override; void visit(mir::ops::ReshapeOp& op) override; void visit(mir::ops::ResizeOp& op) override; void visit(mir::ops::ScaleOp& op) override; - void visit(mir::ops::BatchNormOp& op) override; - void visit(mir::ops::DropoutOp& op) override; - void visit(mir::ops::TanhOp& op) override; - void visit(mir::ops::ElementwiseOp& op) override; - void visit(mir::ops::DeConv2DOp& op) override; - void visit(mir::ops::EluOp& op) override; + void visit(mir::ops::SoftmaxOp& op) override; void visit(mir::ops::SqueezeOp& op) override; - void visit(mir::ops::PadOp& op) override; - void visit(mir::ops::ReduceFOp& op) override; + void visit(mir::ops::TanhOp& op) override; void visit(mir::ops::TransposeOp& op) override; + void visit(mir::ops::VariableOp& op) override; void serialize(std::list &inferenceSequence);