BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
const LogicalBuffer& buffer,
- bool mandatory)
- : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) {
+ bool mandatory, bool dfs)
+ : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
OperandLayoutConstraint::OperandLayoutConstraint(
const ShapeLayout& shape_layout, const HloInstruction* instruction,
- int64 operand_no, bool mandatory)
- : LayoutConstraint(mandatory),
+ int64 operand_no, bool mandatory, bool dfs)
+ : LayoutConstraint(mandatory, dfs),
shape_layout_(shape_layout),
instruction_(instruction),
operand_no_(operand_no) {
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
const LogicalBuffer& buffer,
- bool mandatory) {
+ bool mandatory, bool dfs) {
VLOG(3) << "SetBufferLayout : " << buffer << " : "
<< LayoutUtil::HumanString(layout);
if (!overwrite) {
iter = buffer_constraints_
.insert(std::make_pair(
- &buffer, BufferLayoutConstraint(layout, buffer, mandatory)))
+ &buffer,
+ BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
.first;
} else {
- iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true);
+ iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
const HloInstruction* instruction,
- int64 operand_no, bool mandatory) {
+ int64 operand_no, bool mandatory,
+ bool dfs) {
VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
<< operand_no << " : "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
if (iter == operand_constraints_.end()) {
auto pair = std::make_pair(
key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
- instruction, operand_no, mandatory));
+ instruction, operand_no, mandatory, dfs));
iter = operand_constraints_.insert(pair).first;
} else {
iter->second =
OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
- operand_no, /*mandatory=*/true);
+ operand_no, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
Status LayoutConstraints::SetArrayOperandLayout(
const Layout& layout, const HloInstruction* instruction, int64 operand_no,
- bool mandatory) {
+ bool mandatory, bool dfs) {
const HloInstruction* operand = instruction->operand(operand_no);
TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
Shape shape(operand->shape());
*shape.mutable_layout() = layout;
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
- return SetOperandLayout(shape, instruction, operand_no, mandatory);
+ return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
}
-Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
+Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
+ bool dfs) {
VLOG(3) << "SetResultLayout : "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
}
result_constraint_.reset(
- new ResultLayoutConstraint(ShapeLayout(shape_with_layout)));
+ new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
added_constraints_.push_back(result_constraint_.get());
return Status::OK();
}
Status LayoutConstraints::SetInstructionLayout(
- const Shape& shape_with_layout, const HloInstruction* instruction) {
+ const Shape& shape_with_layout, const HloInstruction* instruction,
+ bool mandatory, bool dfs) {
VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
// instruction.
return ShapeUtil::ForEachSubshapeWithStatus(
shape_with_layout,
- [this, instruction](const Shape& subshape,
- const ShapeIndex& index) -> Status {
+ [this, instruction, mandatory](const Shape& subshape,
+ const ShapeIndex& index) -> Status {
// The precondition for this method is that the instruction defines all
// buffers in its output.
auto buffers =
CHECK_EQ(buffers[0]->instruction(), instruction);
if (ShapeUtil::IsArray(subshape)) {
- return SetBufferLayout(subshape.layout(), *buffers[0]);
+ return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
} else {
return Status::OK();
}
// Constrain the input to the Outfeed instruction to be the expected
// layout of the Outfeed.
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- instruction->outfeed_shape(), instruction, 0,
- /*mandatory=*/true));
+ instruction->outfeed_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kParameter) {
// Parameter layouts must match the respective layout in
// ComputationLayout.
{0}));
Shape new_shape = channel_constraints->LayoutShapeForChannel(
recv_buffer_shape, instruction->channel_id());
- TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
- new_shape.layout(), *buffer, /*mandatory=*/true));
+ TF_RETURN_IF_ERROR(
+ constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
}
}
for (int64 i = 0; i < instruction->operand_count(); ++i) {
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
called_computation_layout.parameter_layout(i).shape(), instruction,
- i, /*mandatory=*/true));
+ i));
}
} else if (instruction->opcode() == HloOpcode::kWhile) {
// Layout of input and output of kWhile instruction must be equal and must
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
body_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- body_layout.result_shape(), instruction, 0,
- /*mandatory=*/true));
+ body_layout.result_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
if (!CustomCallRequiresMajorFirstLayout(instruction)) {
continue;
operand_shape.element_type(),
AsInt64Slice(operand_shape.dimensions()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- row_major_operand_shape, instruction, i, /*mandatory=*/true));
+ row_major_operand_shape, instruction, i));
}
}
}
auto add_new_constraints_to_worklist = [constraints, &worklist]() {
// Add constraints to the front of the deque for DFS ordering.
for (auto* constraint : constraints->ConsumeAddedConstraints()) {
- worklist.push_front(constraint);
+ if (constraint->dfs()) {
+ worklist.push_front(constraint);
+ } else {
+ worklist.push_back(constraint);
+ }
}
};
add_new_constraints_to_worklist();
// Add any backend-specific constraints.
TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
- // Propagates layouts from an HLO to its neighbors.
+ // Propagates layouts from mandatory and backend constraints.
TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
// While any unconstrained buffers remain, pick an arbitrary buffer, give it a
// Assign layouts to computations in an order such that a callee computation
// is handled before its caller computation. This ensures that the layout of
// all callers of a computation will agree.
+ std::list<HloComputation*> computation_post_order =
+ module->MakeComputationPostOrder();
for (auto* computation : module->MakeComputationPostOrder()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
// Clear existing layouts of the instructions. All layouts must be assigned
// by the LayoutAssignment pass, except for those on infeeds, parameters,
// and the computation result. The latter two are specified in
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
-
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(
*entry_computation_layout_, *points_to_analysis,
module->entry_computation(), channel_layout_constraints_));
- } else if (computation->IsFusionComputation()) {
- continue;
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());
// Setting all embedded computations to the default layout is potentially
// gathered together in LayoutConstraints object.
class LayoutConstraint {
public:
- LayoutConstraint(bool mandatory) : mandatory_(mandatory) {}
+ LayoutConstraint(bool mandatory, bool dfs)
+ : mandatory_(mandatory), dfs_(dfs) {}
virtual ~LayoutConstraint() = default;
virtual string ToString() const = 0;
// True if this constraint cannot be overwritten by a different constraint.
bool mandatory() const { return mandatory_; }
+ // When true, propagate in DFS. When false, constraint will propagate in BFS.
+ bool dfs() const { return dfs_; }
+
private:
bool mandatory_;
+ bool dfs_;
};
std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
class BufferLayoutConstraint : public LayoutConstraint {
public:
BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
- bool mandatory);
+ bool mandatory, bool dfs);
const LogicalBuffer& buffer() const { return *buffer_; }
const Layout& layout() const { return layout_; }
public:
OperandLayoutConstraint(const ShapeLayout& shape_layout,
const HloInstruction* instruction, int64 operand_no,
- bool mandatory);
+ bool mandatory, bool dfs);
const ShapeLayout& shape_layout() const { return shape_layout_; }
const HloInstruction* instruction() const { return instruction_; }
// Constraint on the layout of the result of the entry computation.
class ResultLayoutConstraint : public LayoutConstraint {
public:
- explicit ResultLayoutConstraint(const ShapeLayout& shape_layout)
- : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {}
+ explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
+ bool dfs = false)
+ : LayoutConstraint(/*mandatory=*/true, dfs),
+ shape_layout_(shape_layout) {}
const ShapeLayout& shape_layout() const { return shape_layout_; }
string ToString() const override;
// operand of the instruction, or the layout of the result of the computation,
// respectively.
Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
- bool mandatory = true);
+ bool mandatory = true, bool dfs = true);
Status SetOperandLayout(const Shape& shape_with_layout,
const HloInstruction* instruction, int64 operand_no,
- bool mandatory = true);
- Status SetResultLayout(const Shape& shape_with_layout);
+ bool mandatory = true, bool dfs = true);
+ Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true);
// Convenience wrapper around SetOperandLayout for setting the layout of a
// operand using a Layout object. The operand must be array-shaped.
Status SetArrayOperandLayout(const Layout& layout,
const HloInstruction* instruction,
- int64 operand_no, bool mandatory = true);
+ int64 operand_no, bool mandatory = true,
+ bool dfs = true);
// Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
// created by the instruction to the layouts in the given shape. The
// instruction must define every logical buffer in its output.
Status SetInstructionLayout(const Shape& shape_with_layout,
- const HloInstruction* instruction);
+ const HloInstruction* instruction,
+ bool mandatory = true, bool dfs = true);
// Returns true if any buffer in the given operand is forwarded to the output
// of the given instruction. For example, the Tuple instruction forwards the