":fusion_merger",
":instruction_fusion",
"//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
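+ // Note: both kLoop and kInput fusion users are acceptable merge targets
+ // here (see the WillMergeIntoInputFusion test added in this change).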
- if (!c_all_of(fusion->users(), [](const HloInstruction* instruction) {
- return instruction->opcode() == HloOpcode::kFusion &&
- instruction->fusion_kind() == HloInstruction::FusionKind::kLoop;
+ if (!c_all_of(fusion->users(), [](const HloInstruction* user) {
+ return user->opcode() == HloOpcode::kFusion &&
+ (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
+ user->fusion_kind() == HloInstruction::FusionKind::kInput);
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Some of its users are not loop/input fusion kernels.";
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace gpu {
namespace {
+namespace op = xla::testing::opcode_matchers;
+
class FusionMergerTest : public HloTestBase {
protected:
FusionMergerTest() : module_(CreateNewModule()) {}
EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie());
}
+// Check that we're willing to merge f1_computation into f2_computation, even
+// though f2 is an input fusion node.
+TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
+ const char* const kModule = R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[10]{0} parameter(0)
+ ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[10]{0} parameter(0)
+ f2_mul = f32[10]{0} multiply(f2_p0, f2_p0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[10]{0} parameter(0)
+ f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ }
+ )";
+ auto module = tools::Parse(kModule).ValueOrDie();
+ EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
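+ // f1 should have been merged into f2 and removed, leaving a single fusion
+ // at the entry root that reads the parameter directly.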
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Fusion(op::Parameter()));
+}
+
} // namespace
} // namespace gpu
} // namespace xla
// If no operand has a compatible shape, prefer an operand that has
// the same rank at least.
for (const HloInstruction* operand : operands) {
+ // Skip tuple-shaped operands; calling ShapeUtil::Rank on a tuple-shaped
+ // Shape is illegal. It might be more correct to recurse into them, but
+ // this code is temporary anyway.
+ // TODO(kramerb): Remove this code after assigning layouts to fusion
+ // nodes.
+ if (ShapeUtil::IsTuple(operand->shape())) {
+ continue;
+ }
if (ShapeUtil::Rank(*input_shape) ==
ShapeUtil::Rank(operand->shape())) {
// Do not use CopyLayoutBetweenShapes because input_shape and