[XLA:CPU] Teach the CPU layout assignment about dot dimension numbers
author    Sanjoy Das <sanjoy@google.com>
          Tue, 12 Dec 2017 04:34:08 +0000 (20:34 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 12 Dec 2017 04:38:19 +0000 (20:38 -0800)
There is no great need for this yet, but I noticed that the test cases were
broken (they were constructing dots with unset dimension numbers), and one thing
led to another.

PiperOrigin-RevId: 178713597
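
For reference, the new helper is shorthand for the existing CreateDot overload with the canonical dimension numbers filled in. A minimal sketch, assuming the usual test scaffolding from cpu_layout_assignment_test.cc (an HloComputation::Builder named `builder`) and illustrative F32 shapes:

  Shape lhs_shape = ShapeUtil::MakeShape(F32, {1, 12});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 24});

  auto* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));

  // Spelling the dimension numbers out by hand: contract dimension 1 of the
  // LHS against dimension 0 of the RHS, with no batch dimensions.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto* dot_explicit = builder.AddInstruction(
      HloInstruction::CreateDot(result_shape, lhs, rhs, dnums));

  // Equivalent construction via the new helper.
  auto* dot_canonical = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_shape, lhs, rhs));

With the tests constructing dots this way, the layout-assignment and dot_op_emitter code below can rely on dot_dimension_numbers() being set on every kDot instruction.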

tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h

diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 401cf50717959da95f48963c3c83b3036a80eb1b..5d37a41571bdd23ad3b724707012353e56ff235b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -61,8 +61,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
-  auto result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_shape, HloOpcode::kDot, dot_lhs, dot_rhs));
+  auto result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -98,10 +98,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
       HloInstruction::CreateParameter(1, lhs_shape, "param1"));
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
-  auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs));
-  auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs));
+  auto dot_a_result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+  auto dot_b_result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
   builder.AddInstruction(HloInstruction::CreateBinary(
       result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
 
@@ -142,10 +142,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
       HloInstruction::CreateParameter(1, lhs_b_shape, "param1"));
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
-  auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_a_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs));
-  auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_b_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs));
+  auto dot_a_result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+  auto dot_b_result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
   auto tuple_result = builder.AddInstruction(
       HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
 
@@ -180,8 +180,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
       HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape)));
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, rhs_shape, "param0"));
-  auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_shape, HloOpcode::kDot, dot_lhs, dot_rhs));
+  auto dot_result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -220,8 +220,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
-  auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary(
-      result_shape, HloOpcode::kDot, dot_lhs, dot_rhs));
+  auto dot_result = builder.AddInstruction(
+      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 7f0bf2c8e4e26511e2e69121042540120c281c62..296e018c6fc46d6113f37654eff16415ef9bb598 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -1048,7 +1048,8 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
 // column major.
 bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo) {
   return hlo.opcode() == HloOpcode::kDot &&
-         hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1;
+         hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1 &&
+         hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0;
 }
 
 bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 784930195796220646e80cc1cd7a1b342083acfc..10ac665083e21c95caa38d1c6d2a87bc7f4da9d1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -347,6 +347,20 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
   return instruction;
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
+    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
+  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+
+  auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+  instruction->AppendOperand(lhs);
+  instruction->AppendOperand(rhs);
+  instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
+  instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
+  instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
+  return instruction;
+}
+
 /* static */ std::unique_ptr<HloInstruction>
 HloInstruction::CreateReducePrecision(const Shape& shape,
                                       HloInstruction* operand,
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 03cf9aaf907e7437596b9cc1f093fd79d22963b9..092105582e09889091b90eae522489b3732f199c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -166,6 +166,12 @@ class HloInstruction {
       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
       const DotDimensionNumbers& dimension_numbers);
 
+  // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
+  // of the LHS with dimension 0 of the RHS with no batch dimensions.  Both LHS
+  // and the RHS must be of rank 2.
+  static std::unique_ptr<HloInstruction> CreateCanonicalDot(
+      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
+
   // Creates a reduce-precision op, where operand is the data to reduce in
   // precision, and exponent_bits and mantissa_bits describe the precision to
   // reduce it to.