From: Justin Lebar Date: Fri, 1 Jun 2018 00:19:25 +0000 (-0700) Subject: [XLA] Fix handling of CustomCall's window and dnums. X-Git-Tag: upstream/v1.9.0_rc1~26^2~6^2~3 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c3b62c38ebd73c98ffa5613865f4c01fa5ff6ae7;p=platform%2Fupstream%2Ftensorflow.git [XLA] Fix handling of CustomCall's window and dnums. CustomCall can have a window and convolution-dimension-numbers, so HloInstruction needs to handle this in Clone() and Identical(). PiperOrigin-RevId: 198805211 --- diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aa41631..2b14b63 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -426,6 +426,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tools/parser:hlo_parser", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a68075e..4095b3d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1330,6 +1330,14 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); + if (window_ != nullptr) { + clone->window_ = MakeUnique(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + clone->convolution_dimension_numbers_ = + MakeUnique( + *convolution_dimension_numbers_); + } break; case HloOpcode::kHostCompute: clone = CreateHostCompute(shape, new_operands, channel_name_, @@ -1882,6 +1890,19 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: + if ((window_ == nullptr) != (other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(window(), other.window()))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()))) { + return false; + } return custom_call_target_ == other.custom_call_target_; case HloOpcode::kReverse: return dimensions() == other.dimensions(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d1b6bc7..a1a8814 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -1558,5 +1559,54 @@ TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { EXPECT_FALSE(add1->Identical(*add2)); } +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + Window w = window_util::MakeWindow({1, 2, 3}); + instr1->set_window(w); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr1->set_convolution_dimension_numbers(dnums); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + Window w = window_util::MakeWindow({1, 2, 3}); + instr->set_window(w); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w)) + << clone->window().DebugString(); +} + +TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr->set_convolution_dimension_numbers(dnums); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + clone->convolution_dimension_numbers(), dnums)) + << clone->convolution_dimension_numbers().DebugString(); +} + } // namespace } // namespace xla