namespace xla {
namespace cpu {
-Status CpuLayoutAssignment::AddBackendConstraints(
- LayoutConstraints* constraints) {
- auto row_major_shape = [](const Shape& old_shape) {
- Shape new_shape(old_shape);
- std::vector<int64> dimension_order(new_shape.dimensions_size());
- std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
- *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
- return new_shape;
- };
- auto col_major_shape = [](const Shape& old_shape) {
- Shape new_shape(old_shape);
- std::vector<int64> dimension_order(new_shape.dimensions_size());
- std::iota(dimension_order.begin(), dimension_order.end(), 0);
- *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
- return new_shape;
- };
-
- // We want to change the layout of constant arrays to be column major when all
- // of their users are dot operations that can be made faster with the flipped
- // layout. To avoid going quadriatic over the # of instructions, we cache
- // this property in should_make_rhs_col_major -- it maps a constant to true if
- // all of the users of said constant are dot operations that can be sped up.
- // This cache is populated lazily as we encounter dot operations traversing
- // the instruction stream.
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>
- should_make_rhs_col_major_cache;
- auto should_make_rhs_col_major = [&](const HloInstruction& instruction) {
- if (!ProfitableToMakeDotRhsColumnMajor(instruction)) {
- return false;
- }
+// We want to change the layout of constant arrays to be column major when all
+// of their users are dot operations that can be made faster with the flipped
+// layout. To avoid going quadratic over the # of instructions, we cache this
+// property in a ShouldMakeRhsColMajorCache -- it maps a constant to true if
+// all of the users of said constant are dot operations that can be sped up.
+// The cache is populated lazily by ShouldMakeRhsColMajor as we encounter dot
+// operations while traversing the instruction stream.
+
+namespace {
+using ShouldMakeRhsColMajorCache =
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
+}  // namespace
- const auto* rhs = instruction.operand(1);
- if (rhs->opcode() != HloOpcode::kConstant) {
- return false;
- }
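+// Returns true if the RHS operand of `instruction` is a constant whose layout
+// we should flip to column major. The result for each constant is memoized in
+// `cache`.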
+static bool ShouldMakeRhsColMajor(ShouldMakeRhsColMajorCache* cache,
+ const HloInstruction& instruction) {
+ if (!ProfitableToMakeDotRhsColumnMajor(instruction)) {
+ return false;
+ }
- auto it = should_make_rhs_col_major_cache.find(rhs);
- if (it != should_make_rhs_col_major_cache.end()) {
- return it->second;
- }
+ const auto* rhs = instruction.operand(1);
+ if (rhs->opcode() != HloOpcode::kConstant) {
+ return false;
+ }
+
+ auto it = cache->find(rhs);
+ if (it != cache->end()) {
+ return it->second;
+ }
+
+ bool result = std::all_of(rhs->users().begin(), rhs->users().end(),
+ [&](HloInstruction* user) {
+ return ProfitableToMakeDotRhsColumnMajor(*user) &&
+ user->operand(0) != rhs;
+ });
- bool result = std::all_of(
- rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) {
- return ProfitableToMakeDotRhsColumnMajor(*user) &&
- user->operand(0) != rhs;
- });
+ InsertOrDie(cache, rhs, result);
+ return result;
+}
+
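+// Returns a copy of `old_shape` with a row-major layout.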
+static Shape RowMajorShape(const Shape& old_shape) {
+ Shape new_shape(old_shape);
+ std::vector<int64> dimension_order(new_shape.dimensions_size());
+ std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
+ *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+ return new_shape;
+}
- InsertOrDie(&should_make_rhs_col_major_cache, rhs, result);
- return result;
- };
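+// Returns a copy of `old_shape` with a column-major layout.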
+static Shape ColMajorShape(const Shape& old_shape) {
+ Shape new_shape(old_shape);
+ std::vector<int64> dimension_order(new_shape.dimensions_size());
+ std::iota(dimension_order.begin(), dimension_order.end(), 0);
+ *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+ return new_shape;
+}
+
+Status CpuLayoutAssignment::AddBackendConstraints(
+ LayoutConstraints* constraints) {
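+  // Memoizes, per constant, whether all of its users are dot operations that
+  // benefit from a column-major RHS; filled in lazily by
+  // ShouldMakeRhsColMajor.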
+ ShouldMakeRhsColMajorCache cache;
const HloComputation* computation = constraints->computation();
for (auto* instruction : computation->instructions()) {
//
// These constraints are not hard constraints. Ideally, we should decide
// which layouts to choose according to some cost model.
- Shape output_shape(row_major_shape(convolution->shape()));
- Shape input_shape(row_major_shape(lhs_instruction->shape()));
- Shape filter_shape(row_major_shape(rhs_instruction->shape()));
+ Shape output_shape(RowMajorShape(convolution->shape()));
+ Shape input_shape(RowMajorShape(lhs_instruction->shape()));
+ Shape filter_shape(RowMajorShape(rhs_instruction->shape()));
// Set layouts of the instructions' shapes.
TF_RETURN_IF_ERROR(
constraints->SetOperandLayout(filter_shape, convolution, 1));
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(output_shape, convolution));
- } else if (should_make_rhs_col_major(*instruction)) {
+ } else if (ShouldMakeRhsColMajor(&cache, *instruction)) {
auto* dot = instruction;
const auto& rhs_shape = dot->operand(1)->shape();
TF_RETURN_IF_ERROR(
- constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1));
+ constraints->SetOperandLayout(ColMajorShape(rhs_shape), dot, 1));
} else if (PotentiallyImplementedAsEigenDot(*instruction)) {
const HloInstruction* dot = instruction;
// In order to implement `dot` with Eigen dot, the layouts of the lhs,
//
// These constraints are not hard constraints. Ideally, we should decide
// which layouts to choose according to some cost model.
- Shape output_shape(row_major_shape(dot->shape()));
+ Shape output_shape(RowMajorShape(dot->shape()));
const HloInstruction* lhs_instruction = dot->operand(0);
- Shape lhs_shape(row_major_shape(lhs_instruction->shape()));
+ Shape lhs_shape(RowMajorShape(lhs_instruction->shape()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0));
// dot is a kDot or a kTransposeDot fusion node. In the latter case, if
// it represents X @ X, it may have just one operand.
if (dot->operand_count() > 1) {
const HloInstruction* rhs_instruction = dot->operand(1);
- Shape rhs_shape(row_major_shape(rhs_instruction->shape()));
+ Shape rhs_shape(RowMajorShape(rhs_instruction->shape()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1));
}
continue;
}
Shape operand_shape(
- row_major_shape(instruction->operand(operand_no)->shape()));
+ RowMajorShape(instruction->operand(operand_no)->shape()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
operand_shape, instruction, operand_no));
}