From: Андрей Шедько/AI Tools Lab /SRR/Assistant Engineer/삼성전자 Date: Wed, 28 Nov 2018 15:13:26 +0000 (+0300) Subject: [nnc] Add Resize Nearest Neighbor (#2315) X-Git-Tag: nncc_backup~1241 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ffeb6044669d1c967f40d26762fb7270c78296c8;p=platform%2Fcore%2Fml%2Fnnfw.git [nnc] Add Resize Nearest Neighbor (#2315) Added Resize Nearest Neighbor to tflite importer and interpreter. Added shape inference tests some cases. The op allows adding more resize types by just adding new values to the enum and backends. The corresponding Op in ONNX is Upsample and it can be supported without modifications Signed-off-by: Andrei Shedko --- diff --git a/contrib/nnc/core/modelIR/IrDotDumper.cpp b/contrib/nnc/core/modelIR/IrDotDumper.cpp index dc2ef09..c1f0986 100644 --- a/contrib/nnc/core/modelIR/IrDotDumper.cpp +++ b/contrib/nnc/core/modelIR/IrDotDumper.cpp @@ -15,6 +15,7 @@ */ #include + #include "core/modelIR/IrDotDumper.h" namespace nnc { @@ -223,5 +224,15 @@ void mir::IrDotDumper::visit(ops::PadOp& op) { dotBuilder.updateWithOp(&op, node_info); } +void IrDotDumper::visit(ops::ResizeOp& op) { + auto node_info = DotIrNodeInfo().withType("Resize", op.getName()) + .withInShapes(getInputShapes(op)) + .withOutShapes(getOutputShapes(op)) + .withMisc("Mode", (int) op.getMode()); + // scale and resShape are only needed in Shape Inference + + dotBuilder.updateWithOp(&op, node_info); +} + } // namespace mir } // namespace nnc diff --git a/contrib/nnc/core/modelIR/Operation.cpp b/contrib/nnc/core/modelIR/Operation.cpp index c6e3e81..84a65b0 100644 --- a/contrib/nnc/core/modelIR/Operation.cpp +++ b/contrib/nnc/core/modelIR/Operation.cpp @@ -24,6 +24,7 @@ #include "core/modelIR/operations/PoolOp.h" #include "core/modelIR/operations/VariableOp.h" #include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/EluOp.h" #include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/BiasAddOp.h" diff --git a/contrib/nnc/core/modelIR/Shape.cpp b/contrib/nnc/core/modelIR/Shape.cpp index 8867e6b..0ddb7f9 100644 --- a/contrib/nnc/core/modelIR/Shape.cpp +++ b/contrib/nnc/core/modelIR/Shape.cpp @@ -24,6 +24,8 @@ namespace nnc namespace mir { +constexpr int32_t mir::Shape::autoDim; + Shape::Shape(std::initializer_list &&l) : _dims{l} { // DO NOTHING @@ -57,7 +59,7 @@ int32_t Shape::numElements() const for (int32_t axis = 0; axis < rank(); ++axis) { - assert(dim(axis) != Shape::AUTO_DIM); + assert(dim(axis) != Shape::autoDim); res *= dim(axis); } @@ -92,7 +94,7 @@ std::ostream &operator<<(std::ostream &s, const Shape &sh) { if (axis != 0) s << ", "; - if (sh.dim(axis) == Shape::AUTO_DIM) + if (sh.dim(axis) == Shape::autoDim) s << "AUTO"; else s << sh.dim(axis); diff --git a/contrib/nnc/core/modelIR/ShapeInference.cpp b/contrib/nnc/core/modelIR/ShapeInference.cpp index 3ef6069..98d98a1 100644 --- a/contrib/nnc/core/modelIR/ShapeInference.cpp +++ b/contrib/nnc/core/modelIR/ShapeInference.cpp @@ -32,6 +32,7 @@ #include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/BiasAddOp.h" #include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/BatchNormOp.h" #include "core/modelIR/operations/ScaleOp.h" #include "core/modelIR/operations/DropoutOp.h" @@ -149,6 +150,32 @@ void ShapeInference::visit(ops::ReluOp& op) { op.setOutputShape(0, op.getInputShape(0)); } +void ShapeInference::visit(ops::ResizeOp& op) { + fillInputShapes(op); + const auto& in_s = op.getInputShape(0); + Shape out_s = in_s; + auto res_s = op.getResultShape(); + const std::vector& scales = op.getScales(); + + if (scales.size() > 0) { + assert( + in_s.rank() == static_cast(scales.size()) + && "Scaling parameter incompatible with input shape"); + for (int32_t i = 0; i < in_s.rank(); i++) { + out_s.dim(i) = (int32_t)lroundf(scales[i] * in_s.dim(i)); + } + } else { + // Assume batch is cut off + assert(in_s.rank() == 3); + out_s.dim(0) = res_s.dim(0); + out_s.dim(1) = res_s.dim(1); + out_s.dim(2) = in_s.dim(2); + op.setScales({static_cast (out_s.dim(0)) / in_s.dim(0), + static_cast (out_s.dim(1)) / in_s.dim(1), 1.0f}); + } + op.setOutputShape(0, out_s); +} + void ShapeInference::visit(ops::SoftmaxOp& op) { fillInputShapes(op); op.setOutputShape(0, op.getInputShape(0)); @@ -232,14 +259,14 @@ void ShapeInference::visit(ops::ReshapeOp& op) { //can't use num_elements due to -1 in input shape and Shape using unsigned ints for dimensions for( int32_t d = 0; d < outShape.rank(); ++d ) { auto dim = outShape.dim(d); - if( dim != Shape::AUTO_DIM) { + if( dim != Shape::autoDim) { outElementsNum *= dim; } } for( int32_t d = 0; d < outShape.rank(); ++d ) { auto& dim = outShape.dim(d); - if( dim == Shape::AUTO_DIM ) { + if( dim == Shape::autoDim ) { dim = static_cast(inElementsNum / outElementsNum); } } diff --git a/contrib/nnc/include/core/modelIR/IrDotDumper.h b/contrib/nnc/include/core/modelIR/IrDotDumper.h index f053843..567cc82 100644 --- a/contrib/nnc/include/core/modelIR/IrDotDumper.h +++ b/contrib/nnc/include/core/modelIR/IrDotDumper.h @@ -31,6 +31,7 @@ #include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/BiasAddOp.h" #include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/BatchNormOp.h" #include "core/modelIR/operations/ScaleOp.h" #include "core/modelIR/operations/DropoutOp.h" @@ -63,6 +64,7 @@ public: void visit(ops::BiasAddOp& op) override; void visit(ops::VariableOp& op) override; void visit(ops::ReshapeOp& op) override; + void visit(ops::ResizeOp& op) override; void visit(ops::ScaleOp& op) override; void visit(ops::BatchNormOp& op) override; void visit(ops::DropoutOp& op) override; diff --git a/contrib/nnc/include/core/modelIR/Shape.h b/contrib/nnc/include/core/modelIR/Shape.h index e3d5eab..6c245e1 100644 --- a/contrib/nnc/include/core/modelIR/Shape.h +++ b/contrib/nnc/include/core/modelIR/Shape.h @@ -30,7 +30,7 @@ namespace mir class Shape { public: - static const auto AUTO_DIM = static_cast(-1); + static constexpr int32_t autoDim = -1; Shape() = default; Shape(std::initializer_list &&l); diff --git a/contrib/nnc/include/core/modelIR/ShapeInference.h b/contrib/nnc/include/core/modelIR/ShapeInference.h index 3dcd9af..17e9f4a 100644 --- a/contrib/nnc/include/core/modelIR/ShapeInference.h +++ b/contrib/nnc/include/core/modelIR/ShapeInference.h @@ -38,6 +38,7 @@ public: void visit(ops::CappedReluOp& op) override; void visit(ops::BiasAddOp& op) override; void visit(ops::ReshapeOp& op) override; + void visit(ops::ResizeOp& op) override; void visit(ops::VariableOp& op) override; void visit(ops::ScaleOp& op) override; void visit(ops::BatchNormOp& op) override; diff --git a/contrib/nnc/include/core/modelIR/operations/ResizeOp.h b/contrib/nnc/include/core/modelIR/operations/ResizeOp.h new file mode 100644 index 0000000..b18a7f0 --- /dev/null +++ b/contrib/nnc/include/core/modelIR/operations/ResizeOp.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _NNC_CORE_IR_MODEL_RESIZEOP_H_ +#define _NNC_CORE_IR_MODEL_RESIZEOP_H_ + +#include "core/modelIR/Operation.h" +#include "core/modelIR/Shape.h" +#include + +namespace nnc { +namespace mir { +namespace ops { + +class ResizeOp : public Operation { +public: + + enum class ResizeMethod { + nearestNeighbor, // TODO: BICUBIC and BILINEAR + }; + + explicit ResizeOp(const IODescriptor& arg, ResizeMethod mode, const std::vector& scales) : + Operation(Type::resizeIm, {arg}), _mode(mode), _scales(scales), + _resultShape({}) {} + + explicit ResizeOp(const IODescriptor& arg, ResizeMethod mode, const Shape& shape) : + Operation(Type::resizeIm, {arg}), _mode(mode), + _scales({}), _resultShape(shape) {} + + /** @return The resize mode */ + ResizeMethod getMode() const { return _mode; } + + const Shape& getResultShape() const { return _resultShape; } + + const std::vector& getScales() const { return _scales; } + + void setScales(const std::vector& scales) { _scales = scales; } + +private: + std::vector _scales; + Shape _resultShape; + ResizeMethod _mode; +}; + +} // namespace ops +} // namespace mir +} // namespace nnc + + +#endif //_NNC_CORE_IR_MODEL_RESIZEOP_H_ diff --git a/contrib/nnc/include/core/modelIR/operations/operations.lst.h b/contrib/nnc/include/core/modelIR/operations/operations.lst.h index c04286a..c89e716 100644 --- a/contrib/nnc/include/core/modelIR/operations/operations.lst.h +++ b/contrib/nnc/include/core/modelIR/operations/operations.lst.h @@ -29,6 +29,7 @@ HANDLE_OP(biasAdd, BiasAddOp) HANDLE_OP(variable, VariableOp) HANDLE_OP(ReLU, ReluOp) HANDLE_OP(reshape, ReshapeOp) +HANDLE_OP(resizeIm, ResizeOp) HANDLE_OP(scale, ScaleOp) HANDLE_OP(batchNorm, BatchNormOp) HANDLE_OP(dropout, DropoutOp) diff --git a/contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h b/contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h index 6c9af1e..c8c2e7e 100644 --- a/contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h +++ b/contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h @@ -58,6 +58,7 @@ public: void visit(mir::ops::VariableOp& op) override; void visit(mir::ops::ReluOp& op) override; void visit(mir::ops::ReshapeOp& op) override; + void visit(mir::ops::ResizeOp& op) override; void visit(mir::ops::ScaleOp& op) override; void visit(mir::ops::BatchNormOp& op) override; void visit(mir::ops::DropoutOp& op) override; diff --git a/contrib/nnc/include/passes/interpreter/Interpreter.h b/contrib/nnc/include/passes/interpreter/Interpreter.h index 1bc2af0..de0742a 100644 --- a/contrib/nnc/include/passes/interpreter/Interpreter.h +++ b/contrib/nnc/include/passes/interpreter/Interpreter.h @@ -47,6 +47,7 @@ public: void visit(ops::BiasAddOp& op) override; void visit(ops::VariableOp& op) override; void visit(ops::ReshapeOp& op) override; + void visit(ops::ResizeOp& op) override; void visit(ops::ScaleOp& op) override; void visit(ops::BatchNormOp& op) override; void visit(ops::DropoutOp& op) override; diff --git a/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp b/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp index 7b06637..d593b91 100644 --- a/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp +++ b/contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp @@ -15,6 +15,7 @@ #include "core/modelIR/operations/CappedReluOp.h" #include "core/modelIR/operations/TanhOp.h" #include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/DepthwiseConv2DOp.h" #include "core/modelIR/operations/FullyConnectedOp.h" #include "core/modelIR/operations/ConcatOp.h" @@ -681,5 +682,9 @@ void AclCppOpGenerator::visit(ops::SqueezeOp& op) { assert(false && "Unimplemented operation: Squeeze"); } +void AclCppOpGenerator::visit(mir::ops::ResizeOp& op) { + assert(false && "Unimplemented operation: Resize"); +} + } // namespace nnc diff --git a/contrib/nnc/passes/interpreter/Interpreter.cpp b/contrib/nnc/passes/interpreter/Interpreter.cpp index 877f945..c2e9611 100644 --- a/contrib/nnc/passes/interpreter/Interpreter.cpp +++ b/contrib/nnc/passes/interpreter/Interpreter.cpp @@ -30,6 +30,7 @@ #include "core/modelIR/operations/PoolOp.h" #include "core/modelIR/operations/VariableOp.h" #include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/EluOp.h" #include "core/modelIR/operations/ConcatOp.h" #include "core/modelIR/operations/BiasAddOp.h" @@ -282,4 +283,26 @@ void NNInterpreter::visit(ops::PadOp& op) { var(op.getId()) = Pad(input, op)(); } +void NNInterpreter::visit(ops::ResizeOp& op) { + mapByName(&op); + auto operand = op.getPrevNodes()[0]; + Tensor input(var(operand.op->getId())[operand.index]); + assert(input.getShape().rank() == 3 && "Must be rank 3 (for now)"); + switch (op.getMode()) { + case ops::ResizeOp::ResizeMethod::nearestNeighbor: { + auto scales = op.getScales(); + var(op.getId()) = Fill(op.getOutputShape(0), [&scales, &input, &op](const Index& id) { + const Index in_idx = {static_cast (lroundf(scales[0] * id.at(0))), + static_cast (lroundf(scales[1] * id.at(1))), + static_cast (lroundf(scales[2] * id.at(2)))}; + return input.at(in_idx); + })(); + break; + } + default: + assert(false && "Not supported Optype"); + } + +} + } // namespace nnc diff --git a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp index 9d93201..cf0c795 100644 --- a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp +++ b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp @@ -225,6 +225,10 @@ void ModelAnalyzer::visit(ops::ReshapeOp& op) { addOpDescr(&op, "reshape"); } +void ModelAnalyzer::visit(mir::ops::ResizeOp& op) { + assert(false && "Not implemented"); +} + void ModelAnalyzer::visit(ops::DropoutOp& op) { addOpDescr(&op, "dropout"); } diff --git a/contrib/nnc/passes/soft_backend/ModelAnalyzer.h b/contrib/nnc/passes/soft_backend/ModelAnalyzer.h index 77831e0..79cd275 100644 --- a/contrib/nnc/passes/soft_backend/ModelAnalyzer.h +++ b/contrib/nnc/passes/soft_backend/ModelAnalyzer.h @@ -98,6 +98,7 @@ public: void visit(mir::ops::VariableOp& op) override; void visit(mir::ops::ReluOp& op) override; void visit(mir::ops::ReshapeOp& op) override; + void visit(mir::ops::ResizeOp& op) override; void visit(mir::ops::ScaleOp& op) override; void visit(mir::ops::BatchNormOp& op) override; void visit(mir::ops::DropoutOp& op) override; diff --git a/contrib/nnc/passes/soft_backend/SBSerializer.cpp b/contrib/nnc/passes/soft_backend/SBSerializer.cpp index 7c09d23..7b51dd5 100644 --- a/contrib/nnc/passes/soft_backend/SBSerializer.cpp +++ b/contrib/nnc/passes/soft_backend/SBSerializer.cpp @@ -30,6 +30,7 @@ #include "core/modelIR/operations/CappedReluOp.h" #include "core/modelIR/operations/BiasAddOp.h" #include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/EluOp.h" #include "core/modelIR/operations/ReshapeOp.h" #include "core/modelIR/operations/BatchNormOp.h" @@ -320,4 +321,8 @@ void Serializer::visit(mir::ops::PadOp& op) { throw PassException("Not implemented yet"); } +void Serializer::visit(mir::ops::ResizeOp& op) { + throw PassException("Not implemented yet"); +} + } // namespace nnc diff --git a/contrib/nnc/passes/soft_backend/SBSerializer.h b/contrib/nnc/passes/soft_backend/SBSerializer.h index f5135a7..e1c3c62 100644 --- a/contrib/nnc/passes/soft_backend/SBSerializer.h +++ b/contrib/nnc/passes/soft_backend/SBSerializer.h @@ -52,6 +52,7 @@ public: void visit(mir::ops::VariableOp& op) override; void visit(mir::ops::ReluOp& op) override; void visit(mir::ops::ReshapeOp& op) override; + void visit(mir::ops::ResizeOp& op) override; void visit(mir::ops::ScaleOp& op) override; void visit(mir::ops::BatchNormOp& op) override; void visit(mir::ops::DropoutOp& op) override; diff --git a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp index 28f1dc5..0269df9 100644 --- a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp +++ b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp @@ -92,6 +92,7 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) { break; case BuiltinOperator_SOFTMAX: case BuiltinOperator_RESHAPE: + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: case BuiltinOperator_SQUEEZE: case BuiltinOperator_PAD: case BuiltinOperator_ADD: @@ -182,6 +183,10 @@ void TfliteImporter::walkOperator(const Operator* op) { case BuiltinOperator_RESHAPE: outputs = _opCreator->convertReshape(inputs, params, op->builtin_options_as()); break; + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: + outputs = _opCreator->convertResizeNN(inputs, params, + op->builtin_options_as()); + break; case BuiltinOperator_FULLY_CONNECTED: outputs = _opCreator->convertFullyConnected(inputs, params, op->builtin_options_as()); diff --git a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp index e61378c..566f72f 100644 --- a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp +++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp @@ -22,6 +22,7 @@ #include "core/modelIR/operations/DepthwiseConv2DOp.h" #include "core/modelIR/operations/FullyConnectedOp.h" #include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/CappedReluOp.h" #include "core/modelIR/operations/TanhOp.h" #include "core/modelIR/operations/ElementwiseOp.h" @@ -35,6 +36,10 @@ #include "core/modelIR/Tensor.h" #include "pass/PassException.h" +#include "core/modelIR/Tensor.h" +#include "core/modelIR/Shape.h" +#include "core/modelIR/ShapeRange.h" + using namespace nnc::mir; using namespace ::tflite; @@ -133,6 +138,7 @@ std::vector TFLiteOpCreator::convertReshape(InputOps inputs, In return outputs; } + std::vector TFLiteOpCreator::createTransposeConv(InputOps& inputs, InputParams& params, const ::tflite::TransposeConvOptions* opts) { @@ -142,6 +148,23 @@ TFLiteOpCreator::createTransposeConv(InputOps& inputs, InputParams& params, paddingMap[opts->padding()]); } +std::vector TFLiteOpCreator::convertResizeNN( + InputOps& inputs, InputParams& params, + const ::tflite::ResizeNearestNeighborOptions* opts) { + // TODO support aligned corners + assert(!opts->align_corners() && "Aligned corners not currently supported"); + + mir::Tensor out_shapes = mir::Tensor(*params[0].get()); + std::vector res_shape; + for (const auto& i : mir::ShapeRange(out_shapes.getShape())) + res_shape.push_back(out_shapes.at(i)); + res_shape.push_back(Shape::autoDim); + // assume no batch + return createOp(ActivationFunctionType_NONE, inputs[0]->getOutput(0), + ops::ResizeOp::ResizeMethod::nearestNeighbor, Shape(res_shape)); +} + + std::vector TFLiteOpCreator::createAdd(InputOps& inputs, InputParams&, const ::tflite::AddOptions* opts) { std::vector descriptors; @@ -235,8 +258,8 @@ mir::Operation* TFLiteOpCreator::addFusedActivation(mir::Operation* input, } } -std::vector TFLiteOpCreator::createSqueeze(InputOps inputs, InputParams params, - const ::tflite::SqueezeOptions* opts) { +std::vector TFLiteOpCreator::createSqueeze( + InputOps inputs, InputParams params, const ::tflite::SqueezeOptions* opts) { std::vector squeeze_dims{opts->squeeze_dims()->begin(), opts->squeeze_dims()->end()}; diff --git a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h index 8574e15..5b8a14d 100644 --- a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h +++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h @@ -69,6 +69,9 @@ public: std::vector convertFullyConnected(InputOps, InputParams, const ::tflite::FullyConnectedOptions*); + std::vector convertResizeNN(InputOps, InputParams, + const ::tflite::ResizeNearestNeighborOptions*); + std::vector createSqueeze(InputOps& inputs, InputParams& params, const ::tflite::SqueezeOptions* opts); diff --git a/contrib/nnc/unittests/core/ShapeInference.cpp b/contrib/nnc/unittests/core/ShapeInference.cpp index 9fefe88..924b1a0 100644 --- a/contrib/nnc/unittests/core/ShapeInference.cpp +++ b/contrib/nnc/unittests/core/ShapeInference.cpp @@ -17,8 +17,8 @@ #include "core/modelIR/Graph.h" #include "core/modelIR/ShapeInference.h" #include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/SqueezeOp.h" -#include "core/modelIR/Shape.h" #include "gtest/gtest.h" @@ -31,8 +31,9 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) { Shape input_shape{10, 2, 5}; Shape expected_shape{10, 1, 10}; + auto input = g.create("input", input_shape); - auto op = g.create("reshape", input->getOutput(0), Shape{10, 1, Shape::AUTO_DIM}); + auto op = g.create("reshape", input->getOutput(0), Shape{10, 1, Shape::autoDim}); op->setInputShape(0, input_shape); si.visit(*dynamic_cast(op)); @@ -40,6 +41,42 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) { ASSERT_EQ(expected_shape, op->getOutputShape(0)); } +TEST(ShapeInferenceTest, ResizeWithShape) { + Graph g; + ShapeInference si; + + Shape result_shape{10, 10, 3}; + + auto input = g.create("input", Shape{5, 5, 3}); + + auto op = g.create( + "Resize", input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor, + Shape{10, 10, Shape::autoDim} + ); + + g.accept(&si); + + ASSERT_EQ(result_shape, op->getOutputShape(0)); +} + +TEST(ShapeInferenceTest, ResizeWithScale) { + Graph g; + ShapeInference si; + + Shape result_shape{30, 10, 3}; + + auto input = g.create("input", Shape{5, 5, 3}); + + auto op = g.create( + "Resize", input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor, + std::vector{6, 2, 1} + ); + + g.accept(&si); + + ASSERT_EQ(result_shape, op->getOutputShape(0)); +} + TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) { Graph g; ShapeInference si; @@ -48,7 +85,7 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) { Shape result_shape_shrink{10, 20}; auto input = g.create("input", input_shape); - auto op = g.create("reshape", input->getOutput(0), Shape{10, Shape::AUTO_DIM}); + auto op = g.create("reshape", input->getOutput(0), Shape{10, Shape::autoDim}); op->setInputShape(0, input_shape); si.visit(*dynamic_cast(op)); @@ -63,7 +100,8 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) { Shape result_shape_expand{5, 10, 2, 2}; auto input = g.create("input", input_shape); - auto op = g.create("reshape", input->getOutput(0), Shape{5, Shape::AUTO_DIM, 2, 2}); + auto op = g.create("reshape", input->getOutput(0), + Shape{5, Shape::autoDim, 2, 2}); op->setInputShape(0, input_shape); si.visit(*dynamic_cast(op));