HloInstruction::CreateParameter(0, lhs_shape, "param0"));
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
- auto result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_shape, HloOpcode::kDot, dot_lhs, dot_rhs));
+ auto result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloInstruction::CreateParameter(1, lhs_shape, "param1"));
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
- auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs));
- auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs));
+ auto dot_a_result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+ auto dot_b_result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
builder.AddInstruction(HloInstruction::CreateBinary(
result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
HloInstruction::CreateParameter(1, lhs_b_shape, "param1"));
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
- auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_a_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs));
- auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_b_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs));
+ auto dot_a_result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+ auto dot_b_result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
auto tuple_result = builder.AddInstruction(
HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape)));
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, rhs_shape, "param0"));
- auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_shape, HloOpcode::kDot, dot_lhs, dot_rhs));
+ auto dot_result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
- auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary(
- result_shape, HloOpcode::kDot, dot_lhs, dot_rhs));
+ auto dot_result = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
+ CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+ CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+ instruction->AppendOperand(lhs);
+ instruction->AppendOperand(rhs);
+ instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
+ instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
+ instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,