Operation *head2 = g.create<ops::ReluOp>(input->getOutput(0));
Operation *tail1 = g.create<ops::ReluOp>(head1->getOutput(0));
Operation *tail2 = g.create<ops::ReluOp>(head2->getOutput(0));
- std::vector<mir::Operation::Output *> concat_inputs{tail1->getOutput(0), tail2->getOutput(0)};
+ vector<mir::Operation::Output *> concat_inputs{tail1->getOutput(0), tail2->getOutput(0)};
Operation *join = g.create<ops::ConcatOp>(concat_inputs, 0);
input->getOutput(0)->setName("input");
head1->getOutput(0)->setName("head1");
ma.analyze(&g);
const auto &seq = ma.getInferenceSequence();
ASSERT_EQ(seq.size(), 6u);
- auto it = seq.begin();
- ASSERT_EQ(getCall(*(it++))->mirOp, input);
- ASSERT_EQ(getCall(*(it++))->mirOp, head1);
- ASSERT_EQ(getCall(*(it++))->mirOp, tail1);
- ASSERT_EQ(getCall(*(it++))->mirOp, head2);
- ASSERT_EQ(getCall(*(it++))->mirOp, tail2);
- ASSERT_EQ(getCall(*(it++))->mirOp, join);
+
+ vector<Operation *> op_seq(seq.size());
+ transform(seq.cbegin(), seq.cend(), op_seq.begin(),
+ [](const unique_ptr<sir::Action> &action) { return getCall(action)->mirOp; });
+
+ vector<Operation *> valid_seq1{input, head1, tail1, head2, tail2, join};
+ vector<Operation *> valid_seq2{input, head2, tail2, head1, tail1, join};
+ ASSERT_TRUE(op_seq == valid_seq1 || op_seq == valid_seq2);
}