From: Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 Date: Fri, 30 Aug 2019 07:06:58 +0000 (+0900) Subject: [mir] Remove operation names (#7016) X-Git-Tag: accepted/tizen/unified/20190903.052428~33 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d3dbdf415a3c2d6231c5096410adcf7e81d07fbe;p=platform%2Fcore%2Fml%2Fnnfw.git [mir] Remove operation names (#7016) Names on operations were difficult to maintain. Signed-off-by: Sergei Barannikov --- diff --git a/compiler/mir/include/mir/Graph.h b/compiler/mir/include/mir/Graph.h index 7bc152a..8c79ad4 100644 --- a/compiler/mir/include/mir/Graph.h +++ b/compiler/mir/include/mir/Graph.h @@ -40,18 +40,6 @@ public: virtual ~Graph(); - // TODO Remove after eliminating all uses. - template - typename std::enable_if::value, Operation *>::type - create(U &&name, Args &&... args) - { - auto op = new T(std::forward(args)...); - op->setId(_last_node_id++); - op->setName(std::forward(name)); - registerOp(op); - return op; - } - template Operation *create(Args &&... args) { auto op = new T(std::forward(args)...); @@ -68,7 +56,6 @@ public: assert(inputs.size() == old_op->getNumInputs()); auto op = old_op->copyWithInputs(inputs); op->setId(_last_node_id++); - op->setName(old_op->getName()); registerOp(op); return op; } diff --git a/compiler/mir/include/mir/Operation.h b/compiler/mir/include/mir/Operation.h index 972b00a..246b706 100644 --- a/compiler/mir/include/mir/Operation.h +++ b/compiler/mir/include/mir/Operation.h @@ -130,9 +130,6 @@ public: std::size_t getId() const { return _id; } void setId(std::size_t id) { _id = id; } - const std::string &getName() const { return _name; } - void setName(const std::string &name) { _name = name; } - std::size_t getNumInputs() const { return _inputs.size(); } std::size_t getNumOutputs() const { return _outputs.size(); } @@ -185,7 +182,6 @@ protected: private: Type _type; std::size_t _id = std::numeric_limits::max(); - std::string _name; std::deque _inputs; std::deque _outputs; }; diff --git a/compiler/mir/include/mir/ir_dot_node_info.h b/compiler/mir/include/mir/ir_dot_node_info.h index 5320bcf..d8bda1a 100644 --- a/compiler/mir/include/mir/ir_dot_node_info.h +++ b/compiler/mir/include/mir/ir_dot_node_info.h @@ -59,7 +59,7 @@ public: DotIrNodeInfo() = default; - DotIrNodeInfo &withType(const std::string &type_name, const std::string &node_name); + DotIrNodeInfo &withType(const std::string &type_name); DotIrNodeInfo &withInShapes(Shapes &&in_shapes); DotIrNodeInfo &withOutShapes(Shapes &&out_shapes); @@ -90,7 +90,6 @@ private: void addPipeIfNeeded(std::stringstream &ss, bool needed, bool &need_pipe) const; std::string _type_name; - std::string _node_name; Shapes _in_shapes; Shapes _out_shapes; diff --git a/compiler/mir/src/Graph.cpp b/compiler/mir/src/Graph.cpp index f7de3a3..4cf57dd 100644 --- a/compiler/mir/src/Graph.cpp +++ b/compiler/mir/src/Graph.cpp @@ -121,7 +121,7 @@ ops::InputOp *Graph::replaceWithInputNode(Operation *op) assert(op->getNumOutputs() == 1 && "Only operations with single output value can be replaced with input node"); - auto in = create(op->getName(), op->getOutputShape(0)); + auto in = create(op->getOutputShape(0)); replaceNode(op, in); return dynamic_cast(in); @@ -135,7 +135,7 @@ void Graph::replaceInputNodes(const std::vector &new_inputs) for (auto &op : _ops) { - if (new_input_set.count(op->getName()) != 0) + if (op->getNumOutputs() == 1 && new_input_set.count(op->getOutput(0)->getName()) != 0) { ops_to_replace.push_back(op); } diff --git a/compiler/mir/src/IrDotDumper.cpp b/compiler/mir/src/IrDotDumper.cpp index 4c50a5a..404bf92 100644 --- a/compiler/mir/src/IrDotDumper.cpp +++ b/compiler/mir/src/IrDotDumper.cpp @@ -45,7 +45,7 @@ static std::vector getOutputShapes(const Operation &op) void IrDotDumper::visit(ops::AvgPool2DOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("AvgPool2D", op.getName()) + .withType("AvgPool2D") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withShape("Window size", Shape(op.getWindowSize())) @@ -59,7 +59,7 @@ void IrDotDumper::visit(ops::AvgPool2DOp &op) void IrDotDumper::visit(ops::CappedReluOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("CappedRelu", op.getName()) + .withType("CappedRelu") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withMisc("Cap", op.getCap()); @@ -70,7 +70,7 @@ void IrDotDumper::visit(ops::CappedReluOp &op) void IrDotDumper::visit(ops::ConcatOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("Concat", op.getName()) + .withType("Concat") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withMisc("Axis", op.getAxis()); @@ -81,7 +81,7 @@ void IrDotDumper::visit(ops::ConcatOp &op) void IrDotDumper::visit(ops::Conv2DOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("Conv2D", op.getName()) + .withType("Conv2D") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withStride(op.getStrides()) @@ -94,7 +94,7 @@ void IrDotDumper::visit(ops::Conv2DOp &op) void IrDotDumper::visit(ops::DepthwiseConv2DOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("DepthwiseConv2D", op.getName()) + .withType("DepthwiseConv2D") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withStride(op.getStrides()) @@ -107,7 +107,7 @@ void IrDotDumper::visit(ops::DepthwiseConv2DOp &op) void IrDotDumper::visit(ops::FullyConnectedOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("FullyConnected", op.getName()) + .withType("FullyConnected") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -117,7 +117,7 @@ void IrDotDumper::visit(ops::FullyConnectedOp &op) void IrDotDumper::visit(ops::MaxPool2DOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("MaxPool2D", op.getName()) + .withType("MaxPool2D") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withShape("Window size", Shape(op.getWindowSize())) @@ -131,7 +131,7 @@ void IrDotDumper::visit(ops::MaxPool2DOp &op) void IrDotDumper::visit(ops::SoftmaxOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("Softmax", op.getName()) + .withType("Softmax") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withMisc("Axis", op.getAxis()); @@ -142,7 +142,7 @@ void IrDotDumper::visit(ops::SoftmaxOp &op) void IrDotDumper::visit(ops::PoolOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("Pool2D", op.getName()) + .withType("Pool2D") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withShape("PoolWindow", op.getWindowShape()) @@ -157,7 +157,7 @@ void IrDotDumper::visit(ops::PoolOp &op) void IrDotDumper::visit(ops::ReluOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("ReLU", op.getName()) + .withType("ReLU") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -167,7 +167,7 @@ void IrDotDumper::visit(ops::ReluOp &op) void IrDotDumper::visit(ops::ReshapeOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("Reshape", op.getName()) + .withType("Reshape") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -177,7 +177,7 @@ void IrDotDumper::visit(ops::ReshapeOp &op) void IrDotDumper::visit(ops::InputOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("Input", op.getName()) + .withType("Input") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -187,7 +187,7 @@ void IrDotDumper::visit(ops::InputOp &op) void IrDotDumper::visit(ops::ConstantOp &op) { auto node_info = DotIrNodeInfo() - .withType("Constant", op.getName()) + .withType("Constant") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -197,7 +197,7 @@ void IrDotDumper::visit(ops::ConstantOp &op) void IrDotDumper::visit(ops::SliceOp &op) { auto node_info = DotIrNodeInfo() - .withType("SliceOp", op.getName()) + .withType("SliceOp") .withInShapes(getInputShapes(op)) .withShape("Starts", op.getStarts()) .withShape("Sizes", op.getSizes()) @@ -209,7 +209,7 @@ void IrDotDumper::visit(ops::SliceOp &op) void IrDotDumper::visit(ops::DeConv2DOp &op) { auto node_info = DotIrNodeInfo() - .withType("DeConv2D", op.getName()) + .withType("DeConv2D") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withPadType(op.getPaddingType()) @@ -221,7 +221,7 @@ void IrDotDumper::visit(ops::DeConv2DOp &op) void IrDotDumper::visit(ops::EluOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("EluOp", op.getName()) + .withType("EluOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withMisc("Alpha", op.getAlpha()); @@ -230,7 +230,7 @@ void IrDotDumper::visit(ops::EluOp &op) void IrDotDumper::visit(ops::TanhOp &op) { auto nodeInfo = DotIrNodeInfo() - .withType("TanhOp", op.getName()) + .withType("TanhOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -240,7 +240,7 @@ void IrDotDumper::visit(ops::TanhOp &op) void IrDotDumper::visit(ops::SqueezeOp &op) { auto node_info = DotIrNodeInfo() - .withType("SqueezeOp", op.getName()) + .withType("SqueezeOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -255,7 +255,7 @@ void IrDotDumper::visit(ops::SqueezeOp &op) void mir::IrDotDumper::visit(ops::PadOp &op) { auto node_info = DotIrNodeInfo() - .withType("PadOp", op.getName()) + .withType("PadOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -265,7 +265,7 @@ void mir::IrDotDumper::visit(ops::PadOp &op) void IrDotDumper::visit(ops::SqrtOp &op) { auto node_info = DotIrNodeInfo() - .withType("Sqrt", op.getName()) + .withType("Sqrt") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -279,7 +279,7 @@ void IrDotDumper::visit(ops::ReduceOp &op) auto node_info = DotIrNodeInfo() - .withType("ReduceOp", op.getName()) + .withType("ReduceOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims @@ -293,7 +293,7 @@ void IrDotDumper::visit(ops::ReduceMeanOp &op) { auto node_info = DotIrNodeInfo() - .withType("ReduceMeanOp", op.getName()) + .withType("ReduceMeanOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims @@ -308,7 +308,7 @@ void IrDotDumper::visit(ops::ResizeOp &op) {ops::ResizeOp::ResizeMethod::nearestNeighbor, "nearestNeighbor"}}; auto node_info = DotIrNodeInfo() - .withType("Resize", op.getName()) + .withType("Resize") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withMisc("Mode", modes.at(op.getMode())); @@ -320,7 +320,7 @@ void IrDotDumper::visit(ops::ResizeOp &op) void IrDotDumper::visit(ops::TransposeOp &op) { auto node_info = DotIrNodeInfo() - .withType("TransposeOp", op.getName()) + .withType("TransposeOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -330,7 +330,7 @@ void IrDotDumper::visit(ops::TransposeOp &op) void IrDotDumper::visit(ops::GatherOp &op) { auto node_info = DotIrNodeInfo() - .withType("GatherOp", op.getName()) + .withType("GatherOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -340,7 +340,7 @@ void IrDotDumper::visit(ops::GatherOp &op) void IrDotDumper::visit(ops::SigmoidOp &op) { auto node_info = DotIrNodeInfo() - .withType("SigmoidOp", op.getName()) + .withType("SigmoidOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)); @@ -350,7 +350,7 @@ void IrDotDumper::visit(ops::SigmoidOp &op) void IrDotDumper::visit(mir::ops::LeakyReluOp &op) { auto node_info = DotIrNodeInfo() - .withType("LeakyReluOp", op.getName()) + .withType("LeakyReluOp") .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) .withMisc("alpha", op.getAlpha()); @@ -360,8 +360,7 @@ void IrDotDumper::visit(mir::ops::LeakyReluOp &op) void IrDotDumper::visit(ops::OutputOp &op) { - auto node_info = - DotIrNodeInfo().withType("OutputOp", op.getName()).withInShapes(getInputShapes(op)); + auto node_info = DotIrNodeInfo().withType("OutputOp").withInShapes(getInputShapes(op)); _dot_builder.updateWithOp(&op, node_info); } diff --git a/compiler/mir/src/ir_dot_node_info.cpp b/compiler/mir/src/ir_dot_node_info.cpp index bfe3376..a365c6a 100644 --- a/compiler/mir/src/ir_dot_node_info.cpp +++ b/compiler/mir/src/ir_dot_node_info.cpp @@ -26,10 +26,9 @@ template <> DotIrNodeInfo::Stringable::Stringable(std::string val) : _val(std::m template <> DotIrNodeInfo::Stringable::Stringable(const char *val) : _val(val) {} -DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &type_name, const std::string &node_name) +DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &type_name) { this->_type_name = type_name; - this->_node_name = node_name; return *this; } @@ -85,19 +84,7 @@ std::string DotIrNodeInfo::getLabel() const { std::stringstream ss; - ss << "{"; - - // Node type and name - ss << (!_type_name.empty() ? _type_name : "TYPE_NOT_SET") << ": " - << (!_node_name.empty() ? _node_name : "NAME_NOT_SET"); - - if (_type_name.empty()) - { - std::cout << "Warning: Node type is not set for " - << (_node_name.empty() ? "one of the nodes" : "node " + _node_name) << std::endl; - } - - ss << " | "; + ss << "{" << _type_name << " | "; // Note inputs and output shapes ss << "{{"; diff --git a/compiler/mir/unittests/Graph.cpp b/compiler/mir/unittests/Graph.cpp index 4484017..68a2324 100644 --- a/compiler/mir/unittests/Graph.cpp +++ b/compiler/mir/unittests/Graph.cpp @@ -16,11 +16,11 @@ class DumpVisitor : public Visitor public: DumpVisitor(std::ostream &s) : _s(s) {} - void visit(ops::InputOp &op) override { _s << "i" << op.getName(); }; + void visit(ops::InputOp &op) override { _s << "i" << std::to_string(op.getId()); }; - void visit(ops::ReluOp &op) override { _s << "r" << op.getName(); } + void visit(ops::ReluOp &op) override { _s << "r" << std::to_string(op.getId()); } - void visit(ops::ConcatOp &op) override { _s << "c" << op.getName(); } + void visit(ops::ConcatOp &op) override { _s << "c" << std::to_string(op.getId()); } std::ostream &_s; }; @@ -29,12 +29,14 @@ TEST(Graph, ReplaceInputs) { auto g = new Graph; - auto n1 = g->create("op1", Shape{1}); - auto n2 = g->create("op2", n1->getOutput(0)); - auto n3 = g->create("op3", n2->getOutput(0)); - auto n4 = g->create("op4", n2->getOutput(0)); + auto n1 = g->create(Shape{1}); + auto n2 = g->create(n1->getOutput(0)); + auto n3 = g->create(n2->getOutput(0)); + auto n4 = g->create(n2->getOutput(0)); std::vector concat_inputs{n3->getOutput(0), n4->getOutput(0)}; - auto n5 = g->create("op5", concat_inputs, 0); + auto n5 = g->create(concat_inputs, 0); + n1->getOutput(0)->setName("op1"); + n4->getOutput(0)->setName("op4"); g->replaceInputNodes({"op1", "op4"}); @@ -43,7 +45,7 @@ TEST(Graph, ReplaceInputs) g->accept(&d); auto str = ss.str(); - ASSERT_TRUE(str == "iop1iop4rop2rop3cop5" || str == "iop4iop1rop2rop3cop5") << "str = " << str; + ASSERT_TRUE(str == "i5i6r1r2c4" || str == "i6i5r1r2c4") << "str = " << str; delete g; }; @@ -51,9 +53,9 @@ TEST(Graph, ReplaceOutputNodeWithInput) { auto g = new Graph; - auto n1 = g->create("op1", Shape{}); - auto n2 = g->create("op2", n1->getOutput(0)); - auto n3 = g->create("op3", n2->getOutput(0)); + auto n1 = g->create(Shape{}); + auto n2 = g->create(n1->getOutput(0)); + auto n3 = g->create(n2->getOutput(0)); auto in2 = g->replaceWithInputNode(n2); @@ -61,4 +63,5 @@ TEST(Graph, ReplaceOutputNodeWithInput) ASSERT_EQ(g->getInputs(), expected_inputs); delete g; } -} + +} // namespace diff --git a/compiler/mir/unittests/NodeReplacer.cpp b/compiler/mir/unittests/NodeReplacer.cpp index ca78150..f2e702f 100644 --- a/compiler/mir/unittests/NodeReplacer.cpp +++ b/compiler/mir/unittests/NodeReplacer.cpp @@ -15,11 +15,11 @@ class DumpVisitor : public Visitor public: DumpVisitor(std::ostream &s) : _s(s) {} - void visit(ops::InputOp &op) override { _s << "i" << op.getName(); }; + void visit(ops::InputOp &op) override { _s << "i" << std::to_string(op.getId()); }; - void visit(ops::ReluOp &op) override { _s << "r" << op.getName(); } + void visit(ops::ReluOp &op) override { _s << "r" << std::to_string(op.getId()); } - void visit(ops::ConcatOp &op) override { _s << "c" << op.getName(); } + void visit(ops::ConcatOp &op) override { _s << "c" << std::to_string(op.getId()); } std::ostream &_s; }; @@ -27,11 +27,11 @@ public: TEST(NodeMutatorTest, SimpleChainTest) { auto g = new Graph; - auto n1 = g->create("op1", Shape{}); - auto n2 = g->create("op2", n1->getOutput(0)); - auto n3 = g->create("op3", n2->getOutput(0)); - auto n4 = g->create("op4", n2->getOutput(0)); - auto n5 = g->create("op5", n1->getOutput(0)); + auto n1 = g->create(Shape{}); + auto n2 = g->create(n1->getOutput(0)); + auto n3 = g->create(n2->getOutput(0)); + auto n4 = g->create(n2->getOutput(0)); + auto n5 = g->create(n1->getOutput(0)); g->replaceNode(n2, n5); @@ -40,7 +40,8 @@ TEST(NodeMutatorTest, SimpleChainTest) g->accept(&d); auto str = ss.str(); - ASSERT_TRUE(str == "iop1rop5rop3rop4" || str == "iop1rop5rop4rop3") << "str = " << str; + ASSERT_TRUE(str == "i0r4r2r3" || str == "i0r4r3r2") << "str = " << str; delete g; } -} + +} // namespace diff --git a/compiler/mir/unittests/ShapeInference.cpp b/compiler/mir/unittests/ShapeInference.cpp index 3a5459d..4d7017f 100644 --- a/compiler/mir/unittests/ShapeInference.cpp +++ b/compiler/mir/unittests/ShapeInference.cpp @@ -35,8 +35,8 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) Shape input_shape{10, 2, 5}; Shape expected_shape{10, 1, 10}; - auto input = g.create("input", input_shape); - auto op = g.create("reshape", input->getOutput(0), Shape{10, 1, Shape::autoDim}); + auto input = g.create(input_shape); + auto op = g.create(input->getOutput(0), Shape{10, 1, Shape::autoDim}); ASSERT_EQ(expected_shape, op->getOutputShape(0)); } @@ -47,9 +47,9 @@ TEST(ShapeInferenceTest, ResizeWithShape) Shape result_shape{2, 10, 10, 3}; - auto input = g.create("input", Shape{1, 5, 5, 3}); + auto input = g.create(Shape{1, 5, 5, 3}); - auto op = g.create("Resize", input->getOutput(0), + auto op = g.create(input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor, result_shape); ASSERT_EQ(result_shape, op->getOutputShape(0)); @@ -61,11 +61,11 @@ TEST(ShapeInferenceTest, ResizeWithScale) Shape result_shape{1, 30, 10, 3}; - auto input = g.create("input", Shape{1, 5, 5, 3}); + auto input = g.create(Shape{1, 5, 5, 3}); - auto op = g.create("Resize", input->getOutput(0), - ops::ResizeOp::ResizeMethod::nearestNeighbor, - std::vector{1, 6, 2, 1}); + auto op = + g.create(input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor, + std::vector{1, 6, 2, 1}); ASSERT_EQ(result_shape, op->getOutputShape(0)); } @@ -76,10 +76,9 @@ TEST(ShapeInferenceTest, ReduceChangeRank) Shape resultShape{10, 10}; - auto input = g.create("input", Shape{10, 2, 10, 9}); + auto input = g.create(Shape{10, 2, 10, 9}); - auto n = - g.create("reduce", input->getOutput(0), std::vector{1, 3}, false); + auto n = g.create(input->getOutput(0), std::vector{1, 3}, false); ASSERT_EQ(resultShape, n->getOutputShape(0)); } @@ -91,8 +90,8 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) Shape input_shape{10, 2, 10}; Shape result_shape_shrink{10, 20}; - auto input = g.create("input", input_shape); - auto op = g.create("reshape", input->getOutput(0), Shape{10, Shape::autoDim}); + auto input = g.create(input_shape); + auto op = g.create(input->getOutput(0), Shape{10, Shape::autoDim}); ASSERT_EQ(result_shape_shrink, op->getOutputShape(0)); } @@ -104,9 +103,8 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) Shape input_shape{10, 2, 10}; Shape result_shape_expand{5, 10, 2, 2}; - auto input = g.create("input", input_shape); - auto op = - g.create("reshape", input->getOutput(0), Shape{5, Shape::autoDim, 2, 2}); + auto input = g.create(input_shape); + auto op = g.create(input->getOutput(0), Shape{5, Shape::autoDim, 2, 2}); ASSERT_EQ(result_shape_expand, op->getOutputShape(0)); } @@ -118,9 +116,8 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionUnsqueeze) Shape input_shape{10, 2, 10}; Shape result_shape_expand{1, 10, 2, 1, 10, 1}; - auto input = g.create("input", input_shape); - auto op = g.create("reshape", input->getOutput(0), - Shape{1, Shape::autoDim, 2, 1, 10, 1}); + auto input = g.create(input_shape); + auto op = g.create(input->getOutput(0), Shape{1, Shape::autoDim, 2, 1, 10, 1}); ASSERT_EQ(result_shape_expand, op->getOutputShape(0)); } @@ -132,8 +129,8 @@ TEST(ShapeInferenceTest, SqueezeTestAllDims) Shape input_shape{1, 2, 1, 4}; Shape expected_shape{2, 4}; - auto input = g.create("input", input_shape); - auto sq1 = g.create("squeeze_1", input->getOutput(0), std::vector{}); + auto input = g.create(input_shape); + auto sq1 = g.create(input->getOutput(0), std::vector{}); ASSERT_EQ(sq1->getOutputShape(0), expected_shape); } @@ -144,10 +141,10 @@ TEST(ShapeInferenceTest, ElementwiseBC) Shape input_shape{1, 10, 10, 1}; Shape input2_shape{1, 1, 10, 10}; - auto input = g.create("input1", input_shape); - auto input2 = g.create("input2", input2_shape); + auto input = g.create(input_shape); + auto input2 = g.create(input2_shape); - auto add = g.create("add_1", input->getOutput(0), input2->getOutput(0)); + auto add = g.create(input->getOutput(0), input2->getOutput(0)); ASSERT_EQ(add->getOutputShape(0), Shape({1, 10, 10, 10})); } @@ -159,8 +156,8 @@ TEST(ShapeInferenceTest, SqueezeTestSpecificDims) Shape input_shape{1, 2, 1, 4}; Shape expected_shape{1, 2, 4}; - auto input = g.create("input", input_shape); - auto sq1 = g.create("squeeze_1", input->getOutput(0), std::vector{2}); + auto input = g.create(input_shape); + auto sq1 = g.create(input->getOutput(0), std::vector{2}); ASSERT_EQ(sq1->getOutputShape(0), expected_shape); } @@ -172,8 +169,8 @@ TEST(ShapeInferenceTest, SqueezeTestScalarResult) Shape input_shape{1, 1, 1, 1}; Shape expected_shape{1}; - auto input = g.create("input", input_shape); - auto sq1 = g.create("squeeze_1", input->getOutput(0), std::vector{}); + auto input = g.create(input_shape); + auto sq1 = g.create(input->getOutput(0), std::vector{}); ASSERT_EQ(sq1->getOutputShape(0), expected_shape); }