return true;
}
-void HloCustomCallMatcher::DescribeTo(std::ostream* os) const {
- HloMatcher::DescribeTo(os);
- *os << " with call target that "
- << ::testing::DescribeMatcher<string>(call_target_matcher_);
-}
-
-bool HloCustomCallMatcher::MatchAndExplain(
- const HloInstruction* instruction,
- ::testing::MatchResultListener* listener) const {
- if (!HloMatcher::MatchAndExplain(instruction, listener)) {
- return false;
- }
- ::testing::StringMatchResultListener sub_listener;
- bool result = ExplainMatchResult(
- call_target_matcher_, instruction->custom_call_target(), &sub_listener);
- if (sub_listener.str().empty()) {
- sub_listener << " that "
- << ::testing::DescribeMatcher<string>(call_target_matcher_,
- /*negation=*/!result);
- }
- *listener << "custom-call with call target" << sub_listener.str();
- return result;
-}
-
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
// index to match.
class HloGetTupleElementMatcher : public HloMatcher {
public:
- HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
- int64 tuple_index)
+ explicit HloGetTupleElementMatcher(
+ ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index)
: HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
tuple_index_(tuple_index) {}
int64 tuple_index_;
};
-// Custom matcher for custom-call instructions, which accepts a matcher for its
-// call target.
-class HloCustomCallMatcher : public HloMatcher {
- public:
- HloCustomCallMatcher(
- ::testing::Matcher<string> call_target_matcher,
- std::vector<::testing::Matcher<const HloInstruction*>> operands)
- : HloMatcher(HloOpcode::kCustomCall, operands),
- call_target_matcher_(call_target_matcher) {}
-
- bool MatchAndExplain(const HloInstruction* instruction,
- ::testing::MatchResultListener* listener) const override;
- void DescribeTo(std::ostream* os) const override;
-
- private:
- ::testing::Matcher<string> call_target_matcher_;
-};
-
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
HLO_MATCHER(Convolution);
HLO_MATCHER(Copy);
HLO_MATCHER(CrossReplicaSum);
+HLO_MATCHER(CustomCall);
HLO_MATCHER(Divide);
HLO_MATCHER(Dot);
HLO_MATCHER(DynamicSlice);
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
}
-// - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
-// target T and the given operands.
-//
-// - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
-// given operands.
-//
-// - CustomCall() matches any CustomCall HLO at all.
-template <typename... M>
-inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
- ::testing::Matcher<string> call_target_matcher, M... operands) {
- return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
- call_target_matcher, {operands...}));
-}
-// This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
-// ::testing::Matcher<string>. In that case, we want to prefer the overload
-// above.
-template <typename FirstM, typename... M,
- typename Dummy = typename std::enable_if<
- !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
- void>::type*>
-inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
- FirstM operands_first, M... operands_rest) {
- return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
- HloOpcode::kCustomCall, {operands_first, operands_rest...}));
-}
-inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
- return ::testing::MakeMatcher(
- new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
-}
-
#undef HLO_MATCHER
} // namespace opcode_matchers
namespace xla {
namespace {
-string DescribeHloMatcher(const ::testing::Matcher<const HloInstruction*>& m) {
- std::stringstream ss;
- m.DescribeTo(&ss);
- return ss.str();
-}
-
template <typename M, typename T>
string Explain(const T& t, const M& m) {
::testing::StringMatchResultListener listener;
"add"));
}
-TEST(HloMatchersTest, CustomCallMatcher) {
- auto c1 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
- auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}));
- auto call = HloInstruction::CreateCustomCall(
- ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target");
-
- EXPECT_THAT(call.get(), op::CustomCall());
- EXPECT_THAT(call.get(), op::CustomCall(c1.get(), c2.get()));
- EXPECT_THAT(call.get(), op::CustomCall("foo_target"));
- EXPECT_THAT(call.get(), op::CustomCall("foo_target", c1.get(), c2.get()));
- EXPECT_THAT(call.get(), op::CustomCall(::testing::StartsWith("foo")));
- EXPECT_THAT(call.get(),
- op::CustomCall(::testing::Not(::testing::StartsWith("bar"))));
-
- // Wrong number of operands.
- EXPECT_THAT(call.get(), ::testing::Not(op::CustomCall(c1.get())));
-
- // Call target does not match.
- EXPECT_THAT(call.get(),
- ::testing::Not(op::CustomCall(::testing::StartsWith("bar"))));
-
- EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")),
- R"(custom-call with call target that isn't equal to "bar")");
- EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")),
- R"(custom-call with call target that is equal to "foo_target")");
-}
-
} // namespace
} // namespace xla