[TF:XLA:GPU] Allow the use of linear address when there are size one dimensions
authorBixia Zheng <bixia@google.com>
Mon, 7 May 2018 19:15:52 +0000 (12:15 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 7 May 2018 23:55:04 +0000 (16:55 -0700)
in a tensor.

The current implementation of EmitArrayElementAddress incorrectly concludes
that having a size one dimension in a tensor indicates broadcasting is needed
and the linear address can't be used to access the tensor. We fix this by
leaving LinearValidOnShape to decide whether the linear address can be used to
access the tensor. This enables the vectorization of loads/stores in unrolled
elementwise op kernels when other criteria are met.

Add a test case.

PiperOrigin-RevId: 195701194

tensorflow/compiler/xla/service/llvm_ir/ir_array.cc

index 3312a88..7323abe 100644 (file)
@@ -333,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(
   }
   CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_));
 
-  std::vector<llvm::Value*> actual_index;
-  bool is_implicit_broadcast = false;
-  // We perform broadcasting when the operand shape has dimension(s) of size
-  // 1. In this case we fix the index value for that dimension to zero. This
-  // effectively broadcasts along this dimension.
-  for (int64 i = 0; i < index.size(); ++i) {
-    auto dim = shape_->dimensions(i);
-    actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
-    is_implicit_broadcast |= dim == 1;
-  }
-
-  if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
+  if (index.LinearValidOnShape(*shape_)) {
     llvm::Module* module =
         ir_builder->GetInsertBlock()->getParent()->getParent();
     return ir_builder->CreateInBoundsGEP(
@@ -354,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress(
         {index.linear()}, llvm_ir::AsStringRef(name));
   }
 
+  std::vector<llvm::Value*> actual_index;
+  for (int64 i = 0; i < index.size(); ++i) {
+    // When dimension i is of size 1, LLVM optimization is able to replace
+    // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
+    // produce better code in some cases.
+    auto dim = shape_->dimensions(i);
+    actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
+  }
+
   // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
   // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
   // should be computed by