Names on operations were difficult to maintain.
Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
virtual ~Graph();
- // TODO Remove after eliminating all uses.
- template <typename T, typename U, typename... Args>
- typename std::enable_if<std::is_convertible<U, std::string>::value, Operation *>::type
- create(U &&name, Args &&... args)
- {
- auto op = new T(std::forward<Args>(args)...);
- op->setId(_last_node_id++);
- op->setName(std::forward<U>(name));
- registerOp(op);
- return op;
- }
-
template <typename T, typename... Args> Operation *create(Args &&... args)
{
auto op = new T(std::forward<Args>(args)...);
assert(inputs.size() == old_op->getNumInputs());
auto op = old_op->copyWithInputs(inputs);
op->setId(_last_node_id++);
- op->setName(old_op->getName());
registerOp(op);
return op;
}
std::size_t getId() const { return _id; }
void setId(std::size_t id) { _id = id; }
- const std::string &getName() const { return _name; }
- void setName(const std::string &name) { _name = name; }
-
std::size_t getNumInputs() const { return _inputs.size(); }
std::size_t getNumOutputs() const { return _outputs.size(); }
private:
Type _type;
std::size_t _id = std::numeric_limits<std::size_t>::max();
- std::string _name;
std::deque<Input> _inputs;
std::deque<Output> _outputs;
};
DotIrNodeInfo() = default;
- DotIrNodeInfo &withType(const std::string &type_name, const std::string &node_name);
+ DotIrNodeInfo &withType(const std::string &type_name);
DotIrNodeInfo &withInShapes(Shapes &&in_shapes);
DotIrNodeInfo &withOutShapes(Shapes &&out_shapes);
void addPipeIfNeeded(std::stringstream &ss, bool needed, bool &need_pipe) const;
std::string _type_name;
- std::string _node_name;
Shapes _in_shapes;
Shapes _out_shapes;
assert(op->getNumOutputs() == 1 &&
"Only operations with single output value can be replaced with input node");
- auto in = create<ops::InputOp>(op->getName(), op->getOutputShape(0));
+ auto in = create<ops::InputOp>(op->getOutputShape(0));
replaceNode(op, in);
return dynamic_cast<ops::InputOp *>(in);
for (auto &op : _ops)
{
- if (new_input_set.count(op->getName()) != 0)
+ if (op->getNumOutputs() == 1 && new_input_set.count(op->getOutput(0)->getName()) != 0)
{
ops_to_replace.push_back(op);
}
void IrDotDumper::visit(ops::AvgPool2DOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("AvgPool2D", op.getName())
+ .withType("AvgPool2D")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withShape("Window size", Shape(op.getWindowSize()))
void IrDotDumper::visit(ops::CappedReluOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("CappedRelu", op.getName())
+ .withType("CappedRelu")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withMisc("Cap", op.getCap());
void IrDotDumper::visit(ops::ConcatOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("Concat", op.getName())
+ .withType("Concat")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withMisc("Axis", op.getAxis());
void IrDotDumper::visit(ops::Conv2DOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("Conv2D", op.getName())
+ .withType("Conv2D")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withStride(op.getStrides())
void IrDotDumper::visit(ops::DepthwiseConv2DOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("DepthwiseConv2D", op.getName())
+ .withType("DepthwiseConv2D")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withStride(op.getStrides())
void IrDotDumper::visit(ops::FullyConnectedOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("FullyConnected", op.getName())
+ .withType("FullyConnected")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::MaxPool2DOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("MaxPool2D", op.getName())
+ .withType("MaxPool2D")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withShape("Window size", Shape(op.getWindowSize()))
void IrDotDumper::visit(ops::SoftmaxOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("Softmax", op.getName())
+ .withType("Softmax")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withMisc("Axis", op.getAxis());
void IrDotDumper::visit(ops::PoolOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("Pool2D", op.getName())
+ .withType("Pool2D")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withShape("PoolWindow", op.getWindowShape())
void IrDotDumper::visit(ops::ReluOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("ReLU", op.getName())
+ .withType("ReLU")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::ReshapeOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("Reshape", op.getName())
+ .withType("Reshape")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::InputOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("Input", op.getName())
+ .withType("Input")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::ConstantOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("Constant", op.getName())
+ .withType("Constant")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::SliceOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("SliceOp", op.getName())
+ .withType("SliceOp")
.withInShapes(getInputShapes(op))
.withShape("Starts", op.getStarts())
.withShape("Sizes", op.getSizes())
void IrDotDumper::visit(ops::DeConv2DOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("DeConv2D", op.getName())
+ .withType("DeConv2D")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withPadType(op.getPaddingType())
void IrDotDumper::visit(ops::EluOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("EluOp", op.getName())
+ .withType("EluOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withMisc("Alpha", op.getAlpha());
void IrDotDumper::visit(ops::TanhOp &op)
{
auto nodeInfo = DotIrNodeInfo()
- .withType("TanhOp", op.getName())
+ .withType("TanhOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::SqueezeOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("SqueezeOp", op.getName())
+ .withType("SqueezeOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void mir::IrDotDumper::visit(ops::PadOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("PadOp", op.getName())
+ .withType("PadOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::SqrtOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("Sqrt", op.getName())
+ .withType("Sqrt")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
auto node_info =
DotIrNodeInfo()
- .withType("ReduceOp", op.getName())
+ .withType("ReduceOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims
{
auto node_info =
DotIrNodeInfo()
- .withType("ReduceMeanOp", op.getName())
+ .withType("ReduceMeanOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims
{ops::ResizeOp::ResizeMethod::nearestNeighbor, "nearestNeighbor"}};
auto node_info = DotIrNodeInfo()
- .withType("Resize", op.getName())
+ .withType("Resize")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withMisc("Mode", modes.at(op.getMode()));
void IrDotDumper::visit(ops::TransposeOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("TransposeOp", op.getName())
+ .withType("TransposeOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::GatherOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("GatherOp", op.getName())
+ .withType("GatherOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(ops::SigmoidOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("SigmoidOp", op.getName())
+ .withType("SigmoidOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op));
void IrDotDumper::visit(mir::ops::LeakyReluOp &op)
{
auto node_info = DotIrNodeInfo()
- .withType("LeakyReluOp", op.getName())
+ .withType("LeakyReluOp")
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withMisc("alpha", op.getAlpha());
void IrDotDumper::visit(ops::OutputOp &op)
{
- auto node_info =
- DotIrNodeInfo().withType("OutputOp", op.getName()).withInShapes(getInputShapes(op));
+ auto node_info = DotIrNodeInfo().withType("OutputOp").withInShapes(getInputShapes(op));
_dot_builder.updateWithOp(&op, node_info);
}
template <> DotIrNodeInfo::Stringable::Stringable(const char *val) : _val(val) {}
-DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &type_name, const std::string &node_name)
+DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &type_name)
{
this->_type_name = type_name;
- this->_node_name = node_name;
return *this;
}
{
std::stringstream ss;
- ss << "{";
-
- // Node type and name
- ss << (!_type_name.empty() ? _type_name : "TYPE_NOT_SET") << ": "
- << (!_node_name.empty() ? _node_name : "NAME_NOT_SET");
-
- if (_type_name.empty())
- {
- std::cout << "Warning: Node type is not set for "
- << (_node_name.empty() ? "one of the nodes" : "node " + _node_name) << std::endl;
- }
-
- ss << " | ";
+ ss << "{" << _type_name << " | ";
// Note inputs and output shapes
ss << "{{";
public:
DumpVisitor(std::ostream &s) : _s(s) {}
- void visit(ops::InputOp &op) override { _s << "i" << op.getName(); };
+ void visit(ops::InputOp &op) override { _s << "i" << std::to_string(op.getId()); };
- void visit(ops::ReluOp &op) override { _s << "r" << op.getName(); }
+ void visit(ops::ReluOp &op) override { _s << "r" << std::to_string(op.getId()); }
- void visit(ops::ConcatOp &op) override { _s << "c" << op.getName(); }
+ void visit(ops::ConcatOp &op) override { _s << "c" << std::to_string(op.getId()); }
std::ostream &_s;
};
{
auto g = new Graph;
- auto n1 = g->create<ops::InputOp>("op1", Shape{1});
- auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
- auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
- auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
+ auto n1 = g->create<ops::InputOp>(Shape{1});
+ auto n2 = g->create<ops::ReluOp>(n1->getOutput(0));
+ auto n3 = g->create<ops::ReluOp>(n2->getOutput(0));
+ auto n4 = g->create<ops::ReluOp>(n2->getOutput(0));
std::vector<Operation::Output *> concat_inputs{n3->getOutput(0), n4->getOutput(0)};
- auto n5 = g->create<ops::ConcatOp>("op5", concat_inputs, 0);
+ auto n5 = g->create<ops::ConcatOp>(concat_inputs, 0);
+ n1->getOutput(0)->setName("op1");
+ n4->getOutput(0)->setName("op4");
g->replaceInputNodes({"op1", "op4"});
g->accept(&d);
auto str = ss.str();
- ASSERT_TRUE(str == "iop1iop4rop2rop3cop5" || str == "iop4iop1rop2rop3cop5") << "str = " << str;
+ ASSERT_TRUE(str == "i5i6r1r2c4" || str == "i6i5r1r2c4") << "str = " << str;
delete g;
};
{
auto g = new Graph;
- auto n1 = g->create<ops::InputOp>("op1", Shape{});
- auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
- auto n3 = g->create<ops::OutputOp>("op3", n2->getOutput(0));
+ auto n1 = g->create<ops::InputOp>(Shape{});
+ auto n2 = g->create<ops::ReluOp>(n1->getOutput(0));
+ auto n3 = g->create<ops::OutputOp>(n2->getOutput(0));
auto in2 = g->replaceWithInputNode(n2);
ASSERT_EQ(g->getInputs(), expected_inputs);
delete g;
}
-}
+
+} // namespace
public:
DumpVisitor(std::ostream &s) : _s(s) {}
- void visit(ops::InputOp &op) override { _s << "i" << op.getName(); };
+ void visit(ops::InputOp &op) override { _s << "i" << std::to_string(op.getId()); };
- void visit(ops::ReluOp &op) override { _s << "r" << op.getName(); }
+ void visit(ops::ReluOp &op) override { _s << "r" << std::to_string(op.getId()); }
- void visit(ops::ConcatOp &op) override { _s << "c" << op.getName(); }
+ void visit(ops::ConcatOp &op) override { _s << "c" << std::to_string(op.getId()); }
std::ostream &_s;
};
TEST(NodeMutatorTest, SimpleChainTest)
{
auto g = new Graph;
- auto n1 = g->create<ops::InputOp>("op1", Shape{});
- auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
- auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
- auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
- auto n5 = g->create<ops::ReluOp>("op5", n1->getOutput(0));
+ auto n1 = g->create<ops::InputOp>(Shape{});
+ auto n2 = g->create<ops::ReluOp>(n1->getOutput(0));
+ auto n3 = g->create<ops::ReluOp>(n2->getOutput(0));
+ auto n4 = g->create<ops::ReluOp>(n2->getOutput(0));
+ auto n5 = g->create<ops::ReluOp>(n1->getOutput(0));
g->replaceNode(n2, n5);
g->accept(&d);
auto str = ss.str();
- ASSERT_TRUE(str == "iop1rop5rop3rop4" || str == "iop1rop5rop4rop3") << "str = " << str;
+ ASSERT_TRUE(str == "i0r4r2r3" || str == "i0r4r3r2") << "str = " << str;
delete g;
}
-}
+
+} // namespace
Shape input_shape{10, 2, 5};
Shape expected_shape{10, 1, 10};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, 1, Shape::autoDim});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto op = g.create<ops::ReshapeOp>(input->getOutput(0), Shape{10, 1, Shape::autoDim});
ASSERT_EQ(expected_shape, op->getOutputShape(0));
}
Shape result_shape{2, 10, 10, 3};
- auto input = g.create<ops::InputOp>("input", Shape{1, 5, 5, 3});
+ auto input = g.create<ops::InputOp>(Shape{1, 5, 5, 3});
- auto op = g.create<ops::ResizeOp>("Resize", input->getOutput(0),
+ auto op = g.create<ops::ResizeOp>(input->getOutput(0),
ops::ResizeOp::ResizeMethod::nearestNeighbor, result_shape);
ASSERT_EQ(result_shape, op->getOutputShape(0));
Shape result_shape{1, 30, 10, 3};
- auto input = g.create<ops::InputOp>("input", Shape{1, 5, 5, 3});
+ auto input = g.create<ops::InputOp>(Shape{1, 5, 5, 3});
- auto op = g.create<ops::ResizeOp>("Resize", input->getOutput(0),
- ops::ResizeOp::ResizeMethod::nearestNeighbor,
- std::vector<float>{1, 6, 2, 1});
+ auto op =
+ g.create<ops::ResizeOp>(input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor,
+ std::vector<float>{1, 6, 2, 1});
ASSERT_EQ(result_shape, op->getOutputShape(0));
}
Shape resultShape{10, 10};
- auto input = g.create<ops::InputOp>("input", Shape{10, 2, 10, 9});
+ auto input = g.create<ops::InputOp>(Shape{10, 2, 10, 9});
- auto n =
- g.create<ops::ReduceMeanOp>("reduce", input->getOutput(0), std::vector<int32_t>{1, 3}, false);
+ auto n = g.create<ops::ReduceMeanOp>(input->getOutput(0), std::vector<int32_t>{1, 3}, false);
ASSERT_EQ(resultShape, n->getOutputShape(0));
}
Shape input_shape{10, 2, 10};
Shape result_shape_shrink{10, 20};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, Shape::autoDim});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto op = g.create<ops::ReshapeOp>(input->getOutput(0), Shape{10, Shape::autoDim});
ASSERT_EQ(result_shape_shrink, op->getOutputShape(0));
}
Shape input_shape{10, 2, 10};
Shape result_shape_expand{5, 10, 2, 2};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto op =
- g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{5, Shape::autoDim, 2, 2});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto op = g.create<ops::ReshapeOp>(input->getOutput(0), Shape{5, Shape::autoDim, 2, 2});
ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
}
Shape input_shape{10, 2, 10};
Shape result_shape_expand{1, 10, 2, 1, 10, 1};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0),
- Shape{1, Shape::autoDim, 2, 1, 10, 1});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto op = g.create<ops::ReshapeOp>(input->getOutput(0), Shape{1, Shape::autoDim, 2, 1, 10, 1});
ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
}
Shape input_shape{1, 2, 1, 4};
Shape expected_shape{2, 4};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto sq1 = g.create<ops::SqueezeOp>(input->getOutput(0), std::vector<int32_t>{});
ASSERT_EQ(sq1->getOutputShape(0), expected_shape);
}
Shape input_shape{1, 10, 10, 1};
Shape input2_shape{1, 1, 10, 10};
- auto input = g.create<ops::InputOp>("input1", input_shape);
- auto input2 = g.create<ops::InputOp>("input2", input2_shape);
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto input2 = g.create<ops::InputOp>(input2_shape);
- auto add = g.create<ops::AddOp>("add_1", input->getOutput(0), input2->getOutput(0));
+ auto add = g.create<ops::AddOp>(input->getOutput(0), input2->getOutput(0));
ASSERT_EQ(add->getOutputShape(0), Shape({1, 10, 10, 10}));
}
Shape input_shape{1, 2, 1, 4};
Shape expected_shape{1, 2, 4};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{2});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto sq1 = g.create<ops::SqueezeOp>(input->getOutput(0), std::vector<int32_t>{2});
ASSERT_EQ(sq1->getOutputShape(0), expected_shape);
}
Shape input_shape{1, 1, 1, 1};
Shape expected_shape{1};
- auto input = g.create<ops::InputOp>("input", input_shape);
- auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{});
+ auto input = g.create<ops::InputOp>(input_shape);
+ auto sq1 = g.create<ops::SqueezeOp>(input->getOutput(0), std::vector<int32_t>{});
ASSERT_EQ(sq1->getOutputShape(0), expected_shape);
}