case HloOpcode::kBroadcast:
return [this, hlo, &operand_to_generator](
const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+ const HloInstruction* operand = hlo->operand(0);
// The `dimensions` member of the broadcast instruction maps from
// input dimensions to output dimensions.
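+ // For example, broadcasting a [3] operand into a [2, 3, 4] result with
+ // dimensions={1} makes result element (i, j, k) read operand element (j).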
- const HloInstruction* operand = hlo->operand(0);
- int64 rank = ShapeUtil::Rank(operand->shape());
- IrArray::Index source_index(rank);
- for (int64 i = 0; i < rank; ++i) {
- source_index[i] = target_index[hlo->dimensions(i)];
- }
- return operand_to_generator.at(operand)(source_index);
+ return operand_to_generator.at(
+ operand)(target_index.SourceIndexOfBroadcast(
+ hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_));
};
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
return Index(multi_index, linear_index, operand_shape);
}
+IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
+ const Shape& shape, const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const {
+ int64 rank = ShapeUtil::Rank(operand_shape);
+ std::vector<llvm::Value*> source_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ source_index[i] = multidim_[dimension_mapping[i]];
+ }
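+ // Without a linear index, or without layouts on both shapes, the linear
+ // form cannot be reused; fall back to the dimension-mapped multidim index.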
+ if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
+ !LayoutUtil::HasLayout(shape)) {
+ return Index(source_index);
+ }
+ // High-level idea: we can reuse the linear index if the broadcast
+ // dimensions are physically contiguous in the output layout and, restricted
+ // to those dimensions, the broadcast is a bitcast of the operand. The
+ // remaining dimensions can then be stripped off with a div and a mod.
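+ // For example, broadcasting a [3] operand into a row-major [2, 3, 4] output
+ // with dimension_mapping = {1}: the output's linear index for element
+ // (i, j, k) is (i * 3 + j) * 4 + k. Dividing by 4 strips k, and taking the
+ // quotient mod 3 leaves j, which is the operand's linear index.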
+ std::vector<int64> logical_to_physical =
+ LayoutUtil::MakeLogicalToPhysical(shape.layout());
+ int64 output_rank = ShapeUtil::Rank(shape);
+ // The minimum physical dimension that is broadcasted.
+ int64 min_broadcasted_dimension = output_rank;
+ // The maximum physical dimension that is broadcasted.
+ int64 max_broadcasted_dimension = -1;
+ for (int64 i = 0; i < rank; ++i) {
+ int64 physical_dim = logical_to_physical[dimension_mapping[i]];
+ min_broadcasted_dimension =
+ std::min(min_broadcasted_dimension, physical_dim);
+ max_broadcasted_dimension =
+ std::max(max_broadcasted_dimension, physical_dim);
+ }
+ bool contiguous_broadcast_dimensions =
+ max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
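+ // E.g., broadcasting a [2, 4] operand into a row-major [2, 3, 4] output
+ // with dimension_mapping = {0, 2} touches physical dimensions 0 and 2 with
+ // dimension 1 in between, so the linear index cannot be reused.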
+ if (!contiguous_broadcast_dimensions) {
+ return Index(source_index);
+ }
+ // Check that, within the contiguous block, the physical order of the
+ // broadcast dimensions matches the operand's layout, i.e. the block is a
+ // bitcast of the operand.
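+ // E.g., a [3, 4] operand with column-major layout broadcast into a
+ // row-major [2, 3, 4] output with dimension_mapping = {1, 2} has a
+ // contiguous block whose physical order is transposed relative to the
+ // operand, so the linear index cannot be reused.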
+ std::vector<int64> operand_logical_to_physical =
+ LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
+ for (int64 i = 0; i < rank; ++i) {
+ if (operand_logical_to_physical[i] !=
+ logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
+ return Index(source_index);
+ }
+ }
+ llvm::Value* linear = linear_;
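+ // The physical dimensions more minor than the broadcast block occupy the
+ // low-order part of the linear index; divide them out.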
+ int64 divisor = 1;
+ for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
+ divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
+ }
+ if (divisor > 1) {
+ linear = builder->CreateUDiv(linear, builder->getInt64(divisor));
+ }
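+ // The physical dimensions more major than the broadcast block contribute
+ // multiples of the block's total size; mask them out with a mod.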
+ if (min_broadcasted_dimension > 0) {
+ int64 mod = 1;
+ for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
+ ++i) {
+ mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
+ }
+ linear = builder->CreateURem(linear, builder->getInt64(mod));
+ }
+ return Index(source_index, linear, operand_shape);
+}
+
llvm::Value* IrArray::Index::Linearize(
tensorflow::gtl::ArraySlice<int64> dimensions,
llvm::IRBuilder<>* builder) const {
llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a bitcast from `operand_shape`
- // to `shape` with the given dimension mapping, returns the source index.
+ // to `shape`, returns the source index.
Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
llvm::IRBuilder<>* builder) const;
+ // Given that "this" is the target index of a broadcast from `operand_shape`
+ // to `shape` with the given dimension mapping, returns the source index.
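+ // For example, a broadcast from [3] to [2, 3, 4] with dimension mapping
+ // {1} maps target index (i, j, k) to source index (j).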
+ Index SourceIndexOfBroadcast(
+ const Shape& shape, const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
+
// Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
// returns the index into the sole dimension 0 of the new shape.
llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,