[XLA] Fix handling of CustomCall's window and dnums.
authorJustin Lebar <jlebar@google.com>
Fri, 1 Jun 2018 00:19:25 +0000 (17:19 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 1 Jun 2018 00:25:30 +0000 (17:25 -0700)
CustomCall can have a window and convolution-dimension-numbers, so
HloInstruction needs to handle this in Clone() and Identical().

PiperOrigin-RevId: 198805211

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction_test.cc

index aa41631..2b14b63 100644 (file)
@@ -426,6 +426,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/compiler/xla/tools/parser:hlo_parser",
index a68075e..4095b3d 100644 (file)
@@ -1330,6 +1330,14 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       break;
     case HloOpcode::kCustomCall:
       clone = CreateCustomCall(shape, new_operands, custom_call_target_);
+      if (window_ != nullptr) {
+        clone->window_ = MakeUnique<Window>(*window_);
+      }
+      if (convolution_dimension_numbers_ != nullptr) {
+        clone->convolution_dimension_numbers_ =
+            MakeUnique<ConvolutionDimensionNumbers>(
+                *convolution_dimension_numbers_);
+      }
       break;
     case HloOpcode::kHostCompute:
       clone = CreateHostCompute(shape, new_operands, channel_name_,
@@ -1882,6 +1890,19 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kMap:
       return eq_computations(to_apply(), other.to_apply());
     case HloOpcode::kCustomCall:
+      if ((window_ == nullptr) != (other.window_ == nullptr) ||
+          (window_ != nullptr &&
+           !protobuf_util::ProtobufEquals(window(), other.window()))) {
+        return false;
+      }
+      if ((convolution_dimension_numbers_ == nullptr) !=
+              (other.convolution_dimension_numbers_ == nullptr) ||
+          (convolution_dimension_numbers_ != nullptr &&
+           !protobuf_util::ProtobufEquals(
+               convolution_dimension_numbers(),
+               other.convolution_dimension_numbers()))) {
+        return false;
+      }
       return custom_call_target_ == other.custom_call_target_;
     case HloOpcode::kReverse:
       return dimensions() == other.dimensions();
index d1b6bc7..a1a8814 100644 (file)
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
 
 namespace xla {
 namespace {
@@ -1558,5 +1559,54 @@ TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
   EXPECT_FALSE(add1->Identical(*add2));
 }
 
+TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
+  auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                 /*operands=*/{},
+                                                 /*custom_call_target=*/"foo");
+  auto instr2 = instr1->Clone();
+  EXPECT_TRUE(instr1->Identical(*instr2));
+
+  Window w = window_util::MakeWindow({1, 2, 3});
+  instr1->set_window(w);
+  EXPECT_FALSE(instr1->Identical(*instr2));
+}
+
+TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
+  auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                 /*operands=*/{},
+                                                 /*custom_call_target=*/"foo");
+  auto instr2 = instr1->Clone();
+  EXPECT_TRUE(instr1->Identical(*instr2));
+
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_output_batch_dimension(42);
+  instr1->set_convolution_dimension_numbers(dnums);
+  EXPECT_FALSE(instr1->Identical(*instr2));
+}
+
+TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
+  auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                /*operands=*/{},
+                                                /*custom_call_target=*/"foo");
+  Window w = window_util::MakeWindow({1, 2, 3});
+  instr->set_window(w);
+  auto clone = instr->Clone();
+  EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w))
+      << clone->window().DebugString();
+}
+
+TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
+  auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                /*operands=*/{},
+                                                /*custom_call_target=*/"foo");
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_output_batch_dimension(42);
+  instr->set_convolution_dimension_numbers(dnums);
+  auto clone = instr->Clone();
+  EXPECT_TRUE(protobuf_util::ProtobufEquals(
+      clone->convolution_dimension_numbers(), dnums))
+      << clone->convolution_dimension_numbers().DebugString();
+}
+
 }  // namespace
 }  // namespace xla