From 6ce2fe8f34e8f37f0522f865d9fa1a9f70a14afe Mon Sep 17 00:00:00 2001
From: Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자
Date: Wed, 21 Nov 2018 19:23:36 +0300
Subject: [PATCH] [nnc] Add `shape` parameter to constructors of operations
 that cannot infer shapes of their outputs (#2338)

* Add `shape` parameter to constructor of `VariableOp`;
* Add `shape` parameter to constructor of `ReshapeOp`.

Signed-off-by: Sergei Barannikov
---
 contrib/nnc/core/modelIR/Graph.cpp                 |  2 +-
 .../include/core/modelIR/operations/ReshapeOp.h    |  4 +-
 .../include/core/modelIR/operations/VariableOp.h   |  4 +-
 .../include/passes/common_frontend/shape_helper.h  |  2 +-
 .../nnc/passes/caffe_frontend/caffe_op_creator.cpp | 13 ++---
 .../nnc/passes/common_frontend/shape_helper.cpp    |  4 +-
 .../nnc/passes/onnx_frontend/ONNXImporterImpl.cpp  | 21 ++++----
 contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp |  4 +-
 .../nnc/passes/tflite_frontend/tflite_importer.cpp |  6 +--
 .../passes/tflite_frontend/tflite_op_creator.cpp   | 16 +++---
 contrib/nnc/tests/interpreter/graph_creator.cpp    | 11 ++--
 contrib/nnc/tests/soft_backend/CompileCPP.cpp      |  4 +-
 contrib/nnc/unittests/core/Graph.cpp               |  8 +--
 contrib/nnc/unittests/core/NodeReplacer.cpp        |  2 +-
 contrib/nnc/unittests/core/ShapeInference.cpp      | 62 ++++++++++------------
 contrib/nnc/unittests/core/operation.cpp           | 10 ++--
 .../nnc/unittests/soft_backend/CPPOperations.cpp   | 13 ++---
 contrib/nnc/unittests/soft_backend/Generator.cpp   |  3 +-
 .../nnc/unittests/soft_backend/ModelAnalyzer.cpp   |  3 +-
 19 files changed, 83 insertions(+), 109 deletions(-)

diff --git a/contrib/nnc/core/modelIR/Graph.cpp b/contrib/nnc/core/modelIR/Graph.cpp
index 0ad9831..aa5981b 100644
--- a/contrib/nnc/core/modelIR/Graph.cpp
+++ b/contrib/nnc/core/modelIR/Graph.cpp
@@ -150,12 +150,12 @@ void Graph::replaceNode(const Operation* op, Operation* with) {
 }
 
 ops::VariableOp* Graph::replaceWithInputNode(const Operation* op) {
-  auto in = create<ops::VariableOp>(op->getName());
   assert(op->getNumOutputs() <= 1
          && "Only operations with single output value can be replaced with input node");
   assert(op->getNextNodes().size() <= 1
          && "Node with multiple outputs cannot be changed into input");
 
+  auto in = create<ops::VariableOp>(op->getName(), op->getOutputShape(0));
   replaceNode(op, in);
 
   //replaceNode adds all connections of original node,
diff --git a/contrib/nnc/include/core/modelIR/operations/ReshapeOp.h b/contrib/nnc/include/core/modelIR/operations/ReshapeOp.h
index 31010c3..ab9604b 100644
--- a/contrib/nnc/include/core/modelIR/operations/ReshapeOp.h
+++ b/contrib/nnc/include/core/modelIR/operations/ReshapeOp.h
@@ -22,7 +22,9 @@ namespace ops {
 
 class ReshapeOp : public Operation {
 public:
-  explicit ReshapeOp(const IODescriptor& arg) : Operation(Type::reshape, {arg}) {}
+  ReshapeOp(const IODescriptor& arg, const Shape& shape) : Operation(Type::reshape, {arg}) {
+    setOutputShape(0, shape);
+  }
 };
 
 } // namespace ops
diff --git a/contrib/nnc/include/core/modelIR/operations/VariableOp.h b/contrib/nnc/include/core/modelIR/operations/VariableOp.h
index e1bec2c..ba122fc 100644
--- a/contrib/nnc/include/core/modelIR/operations/VariableOp.h
+++ b/contrib/nnc/include/core/modelIR/operations/VariableOp.h
@@ -25,7 +25,9 @@ namespace ops {
 
 class VariableOp : public Operation {
 public:
-  VariableOp() : Operation(Type::variable, {}) {}
+  explicit VariableOp(const Shape& shape) : Operation(Type::variable, {}) {
+    setOutputShape(0, shape);
+  }
 };
 
 } // namespace ops
diff --git a/contrib/nnc/include/passes/common_frontend/shape_helper.h b/contrib/nnc/include/passes/common_frontend/shape_helper.h
index e7d6f84..ec1ac22 100644
--- a/contrib/nnc/include/passes/common_frontend/shape_helper.h
+++ b/contrib/nnc/include/passes/common_frontend/shape_helper.h
@@ -28,7 +28,7 @@ public:
   template <typename Iterable>
   static mir::Shape createShape(const Iterable &iter, std::size_t);
 
-  static mir::Shape &cutOffBatchDim(mir::Shape &shape);
+  static void cutOffBatchDim(mir::Shape& shape);
 };
 
 template <typename Iterable>
diff --git a/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
index 3965ec6..99979f3 100644
--- a/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
+++ b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
@@ -227,16 +227,15 @@ std::vector<mir::IODescriptor> CaffeOpCreator::convertInput(const LayerParameter& lay
 {
     const auto& blob_name = layer.top(i);
     const auto& blob_shape = params.shape(num_shapes == 1 ? 0 : i);
-    auto variable = createOp<ops::VariableOp>();
-    variable->setName(blob_name);
     Shape shape = ShapeHelper::createShape(blob_shape.dim(), blob_shape.dim_size());
-    shape = ShapeHelper::cutOffBatchDim(shape);
+    ShapeHelper::cutOffBatchDim(shape);
     // WARNING! Temporary solution! Assuming that every 4D input will be used for a convolution,
     // so we change every 4D input from Caffe NCHW to Model IR HWC (batch is cut off earlier).
     // TODO: Implement a more consistent way of handling shapes within the model.
     if (shape.rank() == 3)
       shape = Shape{shape.dim(1), shape.dim(2), shape.dim(0)};
-    variable->setOutputShape(0, shape);
+    auto variable = createOp<ops::VariableOp>(shape);
+    variable->setName(blob_name);
     descriptors.push_back(variable->getOutput(0));
   }
 
@@ -331,10 +330,9 @@
 CaffeOpCreator::convertInnerProduct(const std::vector<mir::IODescriptor>& inputs,
                                     const caffe::InnerProductParameter& opts) {
   // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize]
   // It is needed because Caffe InnerProduct layer takes NCHW input and flattens the CHW part.
-  auto reshape = createOp<ops::ReshapeOp>(inputs[0]);
   int32_t fc_input_size = static_cast<int32_t>(
     params[0]->getShape().numElements()) / opts.num_output();
-  reshape->setOutputShape(0, {1, fc_input_size});
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0], Shape{1, fc_input_size});
   auto fully_connected = createOp<ops::FullyConnectedOp>(reshape->getOutput(0),
                                                          *params[0]);
 
@@ -434,9 +432,8 @@ void CaffeOpCreator::checkReshape(const ReshapeParameter& opts,
 
 std::vector<mir::IODescriptor>
 CaffeOpCreator::convertReshape(const std::vector<mir::IODescriptor>& inputs,
                                const caffe::ReshapeParameter& opts) {
-  auto reshape = createOp<ops::ReshapeOp>(inputs[0]);
   Shape new_shape = ShapeHelper::createShape(opts.shape().dim(), opts.shape().dim_size());
-  reshape->setOutputShape(0, new_shape);
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0], new_shape);
 
   return {reshape->getOutput(0)};
 }
diff --git a/contrib/nnc/passes/common_frontend/shape_helper.cpp b/contrib/nnc/passes/common_frontend/shape_helper.cpp
index 69605ff..e06f7f9 100644
--- a/contrib/nnc/passes/common_frontend/shape_helper.cpp
+++ b/contrib/nnc/passes/common_frontend/shape_helper.cpp
@@ -23,7 +23,7 @@
 namespace nnc
 {
 
-mir::Shape &ShapeHelper::cutOffBatchDim(mir::Shape &shape)
+void ShapeHelper::cutOffBatchDim(mir::Shape& shape)
 {
   if (shape.dim(0) != 1)
   {
@@ -38,8 +38,6 @@ mir::Shape &ShapeHelper::cutOffBatchDim(mir::Shape &shape)
     shape.dim(i) = shape.dim(i + 1);
   }
   shape.resize(shape.rank() - 1);
-
-  return shape;
 }
 
 } // namespace nnc
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
index e6dfc10..dddfb50 100644
--- a/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
+++ b/contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
@@ -115,25 +115,24 @@ void ONNXImporterImpl::createGraphInputs() {
     assert(input.has_name());
     auto name = input.name();
 
-    // Every VariableOp relates to one graph input
-    auto op = _graph->create<ops::VariableOp>(name);
-    _opsForBlobsTheyOutput[name] = op;
-
+    mir::Shape input_shape;
     if (onnx_tensors.find(name) != onnx_tensors.end()) {
      const onnx::TensorProto* onnx_tensor = onnx_tensors[name];
       _inputTensors[name] = createTensor(onnx_tensor);
-      mir::Shape input_shape = ShapeHelper::createShape(onnx_tensor->dims(),
-                                                        static_cast<std::size_t>(onnx_tensor->dims_size()));
-      // WARNING! Temporary solution!
-      op->setOutputShape(0, input_shape);
+      input_shape = ShapeHelper::createShape(onnx_tensor->dims(),
+                                             static_cast<std::size_t>(onnx_tensor->dims_size()));
     } else {
       assert(!name.compare("data"));
       _inputTensors[name] = createTensor(nullptr);
       // TODO: should we update op with special shape?
-      mir::Shape input_shape = ShapeHelper::createShape(std::vector<int>(), 0);
       // WARNING! Temporary solution!
-      op->setOutputShape(0, input_shape);
+      input_shape = ShapeHelper::createShape(std::vector<int>(), 0);
     }
+
+    // Every VariableOp relates to one graph input
+    auto op = _graph->create<ops::VariableOp>(name, input_shape);
+    _opsForBlobsTheyOutput[name] = op;
+
     std::cout << "Node name '" << name << "' added\n"; // << std::endl;
   }
 }
@@ -238,7 +237,7 @@ mir::Graph *ONNXImporterImpl::createIR() {
       if (!outputs.size()) {
         // FIXME: it's for debugging only
         for (auto name : onnxNode.output()) {
-          auto node = _graph->create<ops::VariableOp>(name);
+          auto node = _graph->create<ops::VariableOp>(name, mir::Shape{});
           outputs.push_back(node);
         }
       } else {
diff --git a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
index 441fc16..220e9fd 100644
--- a/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
+++ b/contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
@@ -63,9 +63,7 @@ std::vector<mir::Operation*> ONNXOpCreator::createSoftmax(InputOps inputs, int axis)
 }
 
 std::vector<mir::Operation*> ONNXOpCreator::createReshape(Operation* inputData, Shape outputShape) {
-  auto outputs = createOp<ops::ReshapeOp>(inputData->getOutput(0));
-  outputs[0]->setOutputShape(0, outputShape);
-  return outputs;
+  return createOp<ops::ReshapeOp>(inputData->getOutput(0), outputShape);
 }
 
 std::vector<mir::Operation*> ONNXOpCreator::createRelu(InputOps inputs) {
diff --git a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
index d96fe5a..14ccd66 100644
--- a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
+++ b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
@@ -135,14 +135,12 @@ void TfliteImporter::walkSubGraph(const SubGraph* s) {
   for (auto i : *s->inputs()) {
     const Tensor* t = (*s->tensors())[i];
 
-    auto node = _graph->create<ops::VariableOp>(t->name()->c_str());
-    _opsForTensorsTheyOutput[i] = node;
-
     Shape inputShape = ShapeHelper::createShape(*t->shape(), t->shape()->size());
     // So far we assume that if the first dimension is equal to 1,
     // then it is the batch dimension and should be ignored
     ShapeHelper::cutOffBatchDim(inputShape);
-    node->setOutputShape(0, inputShape);
+    auto node = _graph->create<ops::VariableOp>(t->name()->c_str(), inputShape);
+    _opsForTensorsTheyOutput[i] = node;
   }
 
   for (auto op: *(s->operators()))
diff --git a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
index cf21918..a1bb93e 100644
--- a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
+++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
@@ -123,13 +123,11 @@ std::vector<mir::Operation*> TFLiteOpCreator::createSoftmax(InputOps inputs, Inp
 
 std::vector<mir::Operation*> TFLiteOpCreator::convertReshape(InputOps inputs, InputParams params,
                                                              const ReshapeOptions* opts) {
-  auto outputs = createOp<ops::ReshapeOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0));
-
   // TODO: we should also support "-1" values in new_shape, which means that correct
   // shape values must be calculated. Better do it in the shape inference module.
-  Shape newShape = ShapeHelper::createShape(*opts->new_shape(), opts->new_shape()->size());
-
-  outputs[0]->setOutputShape(0, newShape);
+  Shape new_shape = ShapeHelper::createShape(*opts->new_shape(), opts->new_shape()->size());
+  auto outputs = createOp<ops::ReshapeOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0),
+                                          new_shape);
   return outputs;
 }
 
@@ -179,10 +177,10 @@
 std::vector<mir::Operation*> TFLiteOpCreator::convertFullyConnected(InputOps& inputs,
                                                                     InputParams& params,
                                                                     const FullyConnectedOptions* opts) {
-  // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize]
-  auto outputs = createOp<ops::ReshapeOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0));
-  int32_t fcInputSize = params[0]->getShape().dim(0);
-  outputs[0]->setOutputShape(0, {1, fcInputSize});
+  // Add Reshape operation to make sure the input for FC operation has shape [1, fc_input_size]
+  int32_t fc_input_size = params[0]->getShape().dim(0);
+  auto outputs = createOp<ops::ReshapeOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0),
+                                          Shape{1, fc_input_size});
   auto fc_outputs = createOp<ops::FullyConnectedOp>(ActivationFunctionType_NONE,
                                                     outputs[0]->getOutput(0), *params[0]);
diff --git a/contrib/nnc/tests/interpreter/graph_creator.cpp b/contrib/nnc/tests/interpreter/graph_creator.cpp
index 864be5d..011700a 100644
--- a/contrib/nnc/tests/interpreter/graph_creator.cpp
+++ b/contrib/nnc/tests/interpreter/graph_creator.cpp
@@ -76,9 +76,7 @@ static Operation* createConcatenation(std::unique_ptr<Graph>& g,
 
 static Operation* createReshape(std::unique_ptr<Graph>& g, const std::vector<IODescriptor>& inputs,
                                 const opinfo::OperatorInfo* opInfo) {
-  auto op = g->create<ops::ReshapeOp>("y", inputs[0]);
-  op->setOutputShape(0, getShapeParam(opInfo, 0));
-  return op;
+  return g->create<ops::ReshapeOp>("y", inputs[0], getShapeParam(opInfo, 0));
 }
 
 static Operation* createReLU(std::unique_ptr<Graph>& g,
@@ -140,14 +138,11 @@ std::unique_ptr<Graph> make_graph(const opinfo::OperatorInfo* opInfo) {
   std::unique_ptr<Graph> g(new Graph());
   std::vector<IODescriptor> inputs;
 
+  // Create inputs
   for (unsigned int i = 0; i < opInfo->inputs()->size(); ++i) {
-    // Create i-th input node
-    auto inputOp = g->create<ops::VariableOp>("x" + std::to_string(i));
-
-    // Set input shape
     auto inputShapeIter = opInfo->inputs()->Get(i)->shape()->dims();
     Shape inputShape = ShapeHelper::createShape(*inputShapeIter, inputShapeIter->size());
-    inputOp->setOutputShape(0, inputShape);
+    auto inputOp = g->create<ops::VariableOp>("x" + std::to_string(i), inputShape);
 
     inputs.push_back(inputOp->getOutput(0));
   }
diff --git a/contrib/nnc/tests/soft_backend/CompileCPP.cpp b/contrib/nnc/tests/soft_backend/CompileCPP.cpp
index 270f5a2..ba49119 100644
--- a/contrib/nnc/tests/soft_backend/CompileCPP.cpp
+++ b/contrib/nnc/tests/soft_backend/CompileCPP.cpp
@@ -49,9 +49,7 @@ using namespace nnc::mir;
 void fillGraph(Graph &g)
 {
   Shape inputShape{1, 2, 3};
-  Operation* inputOp = g.create<ops::VariableOp>("in");
-  inputOp->setOutputShape(0, inputShape);
-
+  Operation* inputOp = g.create<ops::VariableOp>("in", inputShape);
   Operation* outputOp = g.create<ops::ReluOp>("out", inputOp->getOutput(0));
 
   g.markOutput(outputOp);
diff --git a/contrib/nnc/unittests/core/Graph.cpp b/contrib/nnc/unittests/core/Graph.cpp
index ec9c8ca..02734e6 100644
--- a/contrib/nnc/unittests/core/Graph.cpp
+++ b/contrib/nnc/unittests/core/Graph.cpp
@@ -32,7 +32,7 @@ public:
 TEST(Graph, ReplaceInputs) {
   auto g = new Graph;
 
-  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n1 = g->create<ops::VariableOp>("op1", Shape{});
   auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
   auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
   auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
@@ -40,6 +40,7 @@
              std::vector<IODescriptor>{n3->getOutput(0), n4->getOutput(0)}, 0);
 
+  n4->setOutputShape(0, Shape{});
   g->replaceInputNodes({"op1", "op4"});
 
   std::stringstream ss;
@@ -57,7 +58,7 @@
 TEST(Graph, ReplaceOutputs) {
   auto g = new Graph;
 
-  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n1 = g->create<ops::VariableOp>("op1", Shape{});
   auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
   auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
   auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
@@ -75,11 +76,12 @@
 TEST(Graph, ReplaceOutputNodeWithInput) {
   auto g = new Graph;
 
-  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n1 = g->create<ops::VariableOp>("op1", Shape{});
   auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
 
   g->markOutput(n2);
 
+  n2->setOutputShape(0, Shape{});
   auto in2 = g->replaceWithInputNode(n2);
 
   std::vector<ops::VariableOp*> expectedInputs{in2, n1};
diff --git a/contrib/nnc/unittests/core/NodeReplacer.cpp b/contrib/nnc/unittests/core/NodeReplacer.cpp
index a1235e4..f01b1bf 100644
--- a/contrib/nnc/unittests/core/NodeReplacer.cpp
+++ b/contrib/nnc/unittests/core/NodeReplacer.cpp
@@ -31,7 +31,7 @@ public:
 TEST(NodeMutatorTest, SimpleChainTest) {
   auto g = new Graph;
 
-  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n1 = g->create<ops::VariableOp>("op1", Shape{});
   auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
   auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
   auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
diff --git a/contrib/nnc/unittests/core/ShapeInference.cpp b/contrib/nnc/unittests/core/ShapeInference.cpp
index 84e654a..9fefe88 100644
--- a/contrib/nnc/unittests/core/ShapeInference.cpp
+++ b/contrib/nnc/unittests/core/ShapeInference.cpp
@@ -28,44 +28,46 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) {
   Graph g;
   ShapeInference si;
 
-  Shape resultShape{ 10, 1, 10 };
+  Shape input_shape{10, 2, 5};
+  Shape expected_shape{10, 1, 10};
 
-  auto input = g.create<ops::VariableOp>("input");
-  input->setOutputShape(0, Shape{ 10, 2, 5} );
-
-  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0));
-  op->setInputShape( 0, Shape{10, 2, 5} );
-  op->setOutputShape(0, Shape{10, 1, Shape::AUTO_DIM} );
+  auto input = g.create<ops::VariableOp>("input", input_shape);
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, 1, Shape::AUTO_DIM});
+  op->setInputShape(0, input_shape);
 
   si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
 
-  ASSERT_EQ(resultShape, op->getOutputShape(0));
+  ASSERT_EQ(expected_shape, op->getOutputShape(0));
 }
 
-TEST(ShapeInferenceTest, ReshapeAutoDimensionVaryRank) {
+TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) {
   Graph g;
   ShapeInference si;
 
-  Shape inputShape{10, 2, 10};
-  Shape resultShapeShrink{10, 20};
-  Shape resultShapeExpand{5, 10, 2, 2};
+  Shape input_shape{10, 2, 10};
+  Shape result_shape_shrink{10, 20};
 
-  auto input = g.create<ops::VariableOp>("input");
+  auto input = g.create<ops::VariableOp>("input", input_shape);
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, Shape::AUTO_DIM});
+  op->setInputShape(0, input_shape);
 
-  input->setOutputShape(0, inputShape);
+  si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
+  ASSERT_EQ(result_shape_shrink, op->getOutputShape(0));
+}
 
-  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0));
-  op->setInputShape( 0, inputShape);
+TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) {
+  Graph g;
+  ShapeInference si;
 
-  // test shrink
-  op->setOutputShape(0, Shape{10, Shape::AUTO_DIM});
-  si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
-  ASSERT_EQ(resultShapeShrink, op->getOutputShape(0));
+  Shape input_shape{10, 2, 10};
+  Shape result_shape_expand{5, 10, 2, 2};
+
+  auto input = g.create<ops::VariableOp>("input", input_shape);
+  auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{5, Shape::AUTO_DIM, 2, 2});
+  op->setInputShape(0, input_shape);
 
-  // test expansion
-  op->setOutputShape(0, Shape{5, Shape::AUTO_DIM, 2, 2});
   si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
-  ASSERT_EQ(resultShapeExpand, op->getOutputShape(0));
+  ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, SqueezeTestAllDims) {
@@ -75,9 +77,7 @@ TEST(ShapeInferenceTest, SqueezeTestAllDims) {
   Shape input_shape{1, 2, 1, 4};
   Shape expected_shape{2, 4};
 
-  auto input = g.create<ops::VariableOp>("input");
-  input->setOutputShape(0, input_shape);
-
+  auto input = g.create<ops::VariableOp>("input", input_shape);
   auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{});
 
   g.accept(&si);
@@ -92,10 +92,7 @@ TEST(ShapeInferenceTest, SqueezeTestSpecificDims) {
   Shape input_shape{1, 2, 1, 4};
   Shape expected_shape{1, 2, 4};
 
-  auto input = g.create<ops::VariableOp>("input");
-  input->setOutputShape(0, input_shape);
-
-
+  auto input = g.create<ops::VariableOp>("input", input_shape);
   auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{2});
 
   g.accept(&si);
@@ -110,10 +107,7 @@ TEST(ShapeInferenceTest, SqueezeTestScalarResult) {
   Shape input_shape{1, 1, 1, 1};
   Shape expected_shape{1};
 
-  auto input = g.create<ops::VariableOp>("input");
-  input->setOutputShape(0, input_shape);
-
-
+  auto input = g.create<ops::VariableOp>("input", input_shape);
   auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{});
 
   g.accept(&si);
diff --git a/contrib/nnc/unittests/core/operation.cpp b/contrib/nnc/unittests/core/operation.cpp
index b0fff70..d26d6b0 100644
--- a/contrib/nnc/unittests/core/operation.cpp
+++ b/contrib/nnc/unittests/core/operation.cpp
@@ -26,9 +26,9 @@ using namespace nnc::mir;
 
 TEST(Operation, ConnectionTest) {
 
-  auto op1 = new ops::VariableOp();
+  auto op1 = new ops::VariableOp(Shape{});
   op1->setId(0);
-  auto op2 = new ops::ReshapeOp(op1->getOutput(0));
+  auto op2 = new ops::ReshapeOp(op1->getOutput(0), Shape{});
   op2->setId(1);
 
   ASSERT_EQ(op1->getId(), op2->getPrevNodes()[0].op->getId());
@@ -41,7 +41,7 @@ TEST(Operation, InputOutputShapeTest) {
   Shape inShape{1,2,3};
   Shape outShape{3,2,1};
 
-  ops::VariableOp input;
+  ops::VariableOp input(Shape{});
   ops::SoftmaxOp op(input.getOutput(0), 0);
   op.setInputShape(0, inShape);
   op.setOutputShape(0, outShape);
@@ -53,7 +53,7 @@ TEST(Operation, SoftmaxAxisTest) {
   Shape inShape{1,2,3};
 
-  ops::VariableOp input;
+  ops::VariableOp input(Shape{});
   ops::SoftmaxOp op_1(input.getOutput(0), 1);
   op_1.setInputShape(0, inShape);
@@ -71,7 +71,7 @@ TEST(Operation, ConcatAxisTest) {
   Shape inShape{1,2,3};
 
-  ops::VariableOp input1, input2;
+  ops::VariableOp input1(Shape{}), input2(Shape{});
   ops::ConcatOp op_1({input1.getOutput(0), input2.getOutput(0)}, 1);
   op_1.setInputShape(0, inShape);
diff --git a/contrib/nnc/unittests/soft_backend/CPPOperations.cpp b/contrib/nnc/unittests/soft_backend/CPPOperations.cpp
index 3da4b24..f0b20f7 100644
--- a/contrib/nnc/unittests/soft_backend/CPPOperations.cpp
+++ b/contrib/nnc/unittests/soft_backend/CPPOperations.cpp
@@ -102,16 +102,13 @@ mir::Operation* fillGraph(mir::Graph& g,
                           function<mir::Operation*(mir::Graph& g,
                                                    const std::vector<mir::IODescriptor>& inputs)> opGen,
                           const vector<unique_ptr<mir::TensorVariant>>& inputNTensors)
 {
+  // Create inputs
   std::vector<mir::IODescriptor> inputs;
   int numInputs = inputNTensors.size();
   for (int i = 0; i < numInputs; ++i)
   {
-    // Create i-th input node
-    auto inputOp = g.create<mir::ops::VariableOp>("x" + std::to_string(i));
-
-    // Set input shape
-    inputOp->setOutputShape(0, inputNTensors[i]->getShape());
-
+    auto inputOp = g.create<mir::ops::VariableOp>("x" + std::to_string(i),
+                                                  inputNTensors[i]->getShape());
     inputs.push_back(inputOp->getOutput(0));
   }
 
@@ -648,9 +645,7 @@ TEST(cpp_operations_test, reshape)
   vector<unique_ptr<mir::TensorVariant>> inputNTensors(1);
   fillTensors(inputNTensors[0], aInputTensor, inputShapeData, 1.0f);
   auto opGenerator = [nOutputShape](mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
-    auto op = g.create<mir::ops::ReshapeOp>("y", inputs[0]);
-    op->setOutputShape(0, nOutputShape);
-    return op;
+    return g.create<mir::ops::ReshapeOp>("y", inputs[0], nOutputShape);
   };
 
   createAndRunTestGraph(opGenerator, reshape, inputNTensors, aInputTensor);
diff --git a/contrib/nnc/unittests/soft_backend/Generator.cpp b/contrib/nnc/unittests/soft_backend/Generator.cpp
index 6f64d2c..e226dd9 100644
--- a/contrib/nnc/unittests/soft_backend/Generator.cpp
+++ b/contrib/nnc/unittests/soft_backend/Generator.cpp
@@ -87,8 +87,7 @@ TEST(Generator, check_generator_call)
   cli::CommandLine::getParser()->parseCommandLine(argc, argv, false);
 
   nnc::mir::Graph g;
-  Operation* input = g.create<ops::VariableOp>("input");
-  input->setOutputShape(0, Shape({1,2,3,4}));
+  Operation* input = g.create<ops::VariableOp>("input", Shape{1, 2, 3, 4});
   Operation* output = g.create<ops::ReluOp>("output", input->getOutput(0));
 
   // test that generator creates output dir and files
diff --git a/contrib/nnc/unittests/soft_backend/ModelAnalyzer.cpp b/contrib/nnc/unittests/soft_backend/ModelAnalyzer.cpp
index f6f3f31..12dfa08 100644
--- a/contrib/nnc/unittests/soft_backend/ModelAnalyzer.cpp
+++ b/contrib/nnc/unittests/soft_backend/ModelAnalyzer.cpp
@@ -45,8 +45,7 @@ TEST(ModelAnalyzer, linearization) {
    *        \   /
    *       [join]
    */
-  Operation* input = g.create<ops::VariableOp>("input");
-  input->setOutputShape(0, {1,2,3});
+  Operation* input = g.create<ops::VariableOp>("input", Shape{1, 2, 3});
   Operation* head1 = g.create<ops::ReluOp>("head1", input->getOutput(0));
   Operation* head2 = g.create<ops::ReluOp>("head2", input->getOutput(0));
   Operation* tail1 = g.create<ops::ReluOp>("tail1", head1->getOutput(0));
-- 
2.7.4
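
Usage sketch (illustrative only, not part of the patch): how graph construction reads before and after this change. It assumes the Model IR types shown above (`Graph`, `Shape`, `ops::VariableOp`, `ops::ReshapeOp`); the node names and shape values are made up.

    // Before: the output shape was attached after creation, so the
    // operation briefly existed with an undefined output shape.
    //   auto input = g.create<ops::VariableOp>("input");
    //   input->setOutputShape(0, Shape{1, 2, 3});

    // After: VariableOp has no inputs and ReshapeOp's target shape comes
    // from external model data, so neither can infer its output shape;
    // both now receive it as a constructor argument.
    auto input = g.create<ops::VariableOp>("input", Shape{1, 2, 3});
    auto reshape = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{1, 6});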