Reuse the linear index when broadcasting a contiguous range of dimensions.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 14:26:13 +0000 (07:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 14:30:43 +0000 (07:30 -0700)
This potentially allows us to get rid of additional mod and div operations.

PiperOrigin-RevId: 188719238

tensorflow/compiler/xla/service/elemental_ir_emitter.cc
tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
tensorflow/compiler/xla/service/llvm_ir/ir_array.h

index 111c295..b6a0903 100644 (file)
@@ -1522,15 +1522,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     case HloOpcode::kBroadcast:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+        const HloInstruction* operand = hlo->operand(0);
         // The `dimensions` member of the broadcast instruction maps from
         // input dimensions to output dimensions.
-        const HloInstruction* operand = hlo->operand(0);
-        int64 rank = ShapeUtil::Rank(operand->shape());
-        IrArray::Index source_index(rank);
-        for (int64 i = 0; i < rank; ++i) {
-          source_index[i] = target_index[hlo->dimensions(i)];
-        }
-        return operand_to_generator.at(operand)(source_index);
+        return operand_to_generator.at(
+            operand)(target_index.SourceIndexOfBroadcast(
+            hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_));
       };
     case HloOpcode::kSlice:
       return [this, hlo, &operand_to_generator](
index d444c1d..3312a88 100644 (file)
@@ -241,6 +241,69 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast(
   return Index(multi_index, linear_index, operand_shape);
 }
 
+IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
+    const Shape& shape, const Shape& operand_shape,
+    tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+    llvm::IRBuilder<>* builder) const {
+  int64 rank = ShapeUtil::Rank(operand_shape);
+  std::vector<llvm::Value*> source_index(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    source_index[i] = multidim_[dimension_mapping[i]];
+  }
+  if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
+      !LayoutUtil::HasLayout(shape)) {
+    return Index(source_index);
+  }
+  // High-level idea: we can reuse the linear index if the broadcasted
+  // dimensions are contiguous, and this part of the operation is a bitcast.
+  // The other dimensions can be masked out with a div and a mod operation.
+  std::vector<int64> logical_to_physical =
+      LayoutUtil::MakeLogicalToPhysical(shape.layout());
+  int64 output_rank = ShapeUtil::Rank(shape);
+  // The minimum physical dimension that is broadcasted.
+  int64 min_broadcasted_dimension = output_rank;
+  // The maximum physical dimension that is broadcasted.
+  int64 max_broadcasted_dimension = -1;
+  for (int64 i = 0; i < rank; ++i) {
+    int64 physical_dim = logical_to_physical[dimension_mapping[i]];
+    min_broadcasted_dimension =
+        std::min(min_broadcasted_dimension, physical_dim);
+    max_broadcasted_dimension =
+        std::max(max_broadcasted_dimension, physical_dim);
+  }
+  bool contiguous_broadcast_dimensions =
+      max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
+  if (!contiguous_broadcast_dimensions) {
+    return Index(source_index);
+  }
+  // Check if the mapped dimensions are a bitcast.
+  std::vector<int64> operand_logical_to_physical =
+      LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
+  for (int64 i = 0; i < rank; ++i) {
+    if (operand_logical_to_physical[i] !=
+        logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
+      return Index(source_index);
+    }
+  }
+  llvm::Value* linear = linear_;
+  int64 divisor = 1;
+  for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
+    divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
+  }
+  if (divisor > 1) {
+    linear = builder->CreateUDiv(linear, builder->getInt64(divisor));
+  }
+  if (min_broadcasted_dimension > 0) {
+    int64 mod = 1;
+    for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
+         ++i) {
+      mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
+    }
+    linear = builder->CreateURem(linear, builder->getInt64(mod));
+  }
+  return Index(source_index, linear, operand_shape);
+}
+
 llvm::Value* IrArray::Index::Linearize(
     tensorflow::gtl::ArraySlice<int64> dimensions,
     llvm::IRBuilder<>* builder) const {
index faa92d6..06cfb2a 100644 (file)
@@ -134,10 +134,17 @@ class IrArray {
         llvm::IRBuilder<>* builder) const;
 
     // Given that "this" is the target index of a bitcast from `operand_shape`
-    // to `shape` with the given dimension mapping, returns the source index.
+    // to `shape`, returns the source index.
     Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
                                llvm::IRBuilder<>* builder) const;
 
+    // Given that "this" is the target index of a broadcast from `operand_shape`
+    // to `shape` with the given dimension mapping, returns the source index.
+    Index SourceIndexOfBroadcast(
+        const Shape& shape, const Shape& operand_shape,
+        tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+        llvm::IRBuilder<>* builder) const;
+
     // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
     // returns the index into the sole dimension 0 of the new shape.
     llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,