From fb6d927a06a1cff15a71f6b47c207fafbaad6a57 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 7 May 2018 12:10:15 -0700 Subject: [PATCH] [XLA] Add FusionKind matcher to pattern_matcher.h. PiperOrigin-RevId: 195700319 --- tensorflow/compiler/xla/service/pattern_matcher.h | 34 ++++++++++++++++++++++ .../compiler/xla/service/pattern_matcher_test.cc | 23 +++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 586f6ef..d3bc47e 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -702,6 +702,30 @@ class HloInstructionPatternOperandImpl { HloInstructionPattern operand_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a fusion node with a particular kind. +template +class HloInstructionPatternFusionKindImpl { + public: + explicit constexpr HloInstructionPatternFusionKindImpl( + const Previous& previous, ::xla::HloInstruction::FusionKind kind) + : previous_(previous), kind_(kind) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + private: + Previous previous_; + ::xla::HloInstruction::FusionKind kind_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -807,6 +831,16 @@ class HloInstructionPattern { matched_inst_); } + // Modifies the pattern to match only if the instruction is a fusion node with + // the given kind. + constexpr HloInstructionPattern> + WithFusionKind(HloInstruction::FusionKind kind) const { + return HloInstructionPattern>( + HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index c88157c..204e8c9 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -170,5 +170,28 @@ TEST(PatternMatcherTest, TupleShape) { Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); } +TEST(PatternMatcherTest, FusionKind) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + fused_computation { + ROOT fp0 = f32[] parameter(0) + } + + ENTRY while.v11 { + p0 = f32[] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))); + EXPECT_FALSE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput))); + EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind( + HloInstruction::FusionKind::kLoop))); +} + } // namespace } // namespace xla -- 2.7.4