Add a type-erased broadcast implementation to xla::Literal
author    Sanjoy Das <sanjoy@google.com>
Sat, 26 May 2018 00:46:19 +0000 (17:46 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sat, 26 May 2018 00:48:50 +0000 (17:48 -0700)
And use this in the HLO evaluator.  Since broadcast only moves bytes around, we
don't need a type-specialized implementation.

I'll use this in a later change.

PiperOrigin-RevId: 198128524
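
Since broadcast is pure byte movement, the core idea can be shown outside of
XLA. The following is a minimal standalone sketch (all names here are
hypothetical, not part of the XLA API): one memcpy of the element's byte size
per output element, so a single implementation covers every primitive type.

// broadcast_sketch.cc -- hypothetical illustration, not part of XLA.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Row-major linearization of a multidimensional index.
int64_t LinearIndex(const std::vector<int64_t>& dims,
                    const std::vector<int64_t>& index) {
  int64_t linear = 0;
  for (size_t i = 0; i < dims.size(); ++i) linear = linear * dims[i] + index[i];
  return linear;
}

// Broadcasts `source` (shape `source_dims`) into `dest` (shape `result_dims`).
// dimensions[i] names the result dimension that source dimension i maps to.
// Only bytes move, so one function serves every element type of a given size.
void TypeErasedBroadcast(const char* source,
                         const std::vector<int64_t>& source_dims, char* dest,
                         const std::vector<int64_t>& result_dims,
                         const std::vector<int64_t>& dimensions,
                         int64_t primitive_size) {
  int64_t num_elements = 1;
  for (int64_t d : result_dims) num_elements *= d;
  std::vector<int64_t> output_index(result_dims.size(), 0);
  std::vector<int64_t> source_index(source_dims.size(), 0);
  for (int64_t n = 0; n < num_elements; ++n) {
    // Project the output index onto the source shape...
    for (size_t i = 0; i < dimensions.size(); ++i)
      source_index[i] = output_index[dimensions[i]];
    // ...and copy one element's worth of bytes, never inspecting the type.
    std::memcpy(
        dest + primitive_size * LinearIndex(result_dims, output_index),
        source + primitive_size * LinearIndex(source_dims, source_index),
        primitive_size);
    // Advance output_index in row-major order (last dimension fastest).
    for (int64_t i = static_cast<int64_t>(output_index.size()) - 1; i >= 0;
         --i) {
      if (++output_index[i] < result_dims[i]) break;
      output_index[i] = 0;
    }
  }
}

int main() {
  // Broadcast the vector {1, 2} into a 2x2 matrix along result dimension 0;
  // expected row-major result: 1 1 2 2, i.e. {{1, 1}, {2, 2}}.
  std::vector<int64_t> src = {1, 2}, dst(4);
  TypeErasedBroadcast(reinterpret_cast<const char*>(src.data()), {2},
                      reinterpret_cast<char*>(dst.data()), {2, 2}, {0},
                      sizeof(int64_t));
  for (int64_t v : dst) std::cout << v << ' ';  // prints: 1 1 2 2
  std::cout << '\n';
}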

tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/literal_util_test.cc
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_evaluator.h
tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
tensorflow/compiler/xla/shape_util.h

diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 4c56076..7563cc1 100644
@@ -807,6 +807,47 @@ std::unique_ptr<Literal> LiteralBase::Relayout(
   return result;
 }
 
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+    const Shape& result_shape,
+    tensorflow::gtl::ArraySlice<int64> dimensions) const {
+  if (!ShapeUtil::IsArray(shape())) {
+    return InvalidArgument("Broadcast only supports arrays.");
+  }
+
+  for (int64 i = 0; i < dimensions.size(); i++) {
+    TF_RET_CHECK(shape().dimensions(i) ==
+                 result_shape.dimensions(dimensions[i]));
+  }
+
+  std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+
+  // scratch_source_index is temporary storage space for the computed index into
+  // the input literal.  We put it here to avoid allocating an std::vector in
+  // every iteration of ShapeUtil::ForEachIndex.
+  std::vector<int64> scratch_source_index(shape().dimensions_size());
+
+  char* dest_data = static_cast<char*>(result->untyped_data());
+  const char* source_data = static_cast<const char*>(untyped_data());
+  const int64 primitive_size =
+      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+  ShapeUtil::ForEachIndex(
+      result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+        for (int64 i = 0; i < dimensions.size(); ++i) {
+          scratch_source_index[i] = output_index[dimensions[i]];
+        }
+        int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+            result_shape, output_index);
+        int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+            shape(), scratch_source_index);
+        memcpy(dest_data + primitive_size * dest_index,
+               source_data + primitive_size * source_index, primitive_size);
+        return true;
+      });
+
+  return std::move(result);
+}
+
 StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
     tensorflow::gtl::ArraySlice<int64> dimensions) const {
   if (!ShapeUtil::IsArray(shape())) {
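
To see the index mapping concretely: when broadcasting a shape-[2] literal into
shape [2,2] with dimensions={0}, the output index (1,0) yields
scratch_source_index = {output_index[dimensions[0]]} = {1}, so the bytes of
source element 1 are copied into output element (1,0); every output position
along dimension 1 receives the same source element.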
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 609dc7a..2ca9060 100644
@@ -277,6 +277,12 @@ class LiteralBase {
   StatusOr<std::unique_ptr<Literal>> Reshape(
       tensorflow::gtl::ArraySlice<int64> dimensions) const;
 
+  // Creates a new literal of shape `result_shape` by broadcasting this
+  // literal into it: entry i of `dimensions` gives the dimension of
+  // `result_shape` that dimension i of this literal is mapped to.
+  StatusOr<std::unique_ptr<Literal>> Broadcast(
+      const Shape& result_shape,
+      tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
   // Creates a new literal by reordering the dimensions of this literal.
   // The given `permutation` must be a permutation of the dimension numbers
   // in the original literal, and it specifies the order of the new dimensions
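
A hypothetical call site (mirroring the tests added in this change; error
handling elided):

std::unique_ptr<Literal> vec = Literal::CreateR1<int64>({1, 2});
StatusOr<std::unique_ptr<Literal>> mat =
    vec->Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
                   /*dimensions=*/{0});
// On success, *mat.ValueOrDie() equals *Literal::CreateR2<int64>({{1, 1}, {2, 2}}).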
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 77f979a..f127cee 100644
@@ -1810,5 +1810,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
       tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
 }
 
+TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
+  std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> broadcasted_literal,
+      literal->Broadcast(
+          /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+          /*dimensions=*/{0}));
+  EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 1}, {2, 2}}));
+}
+
+TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
+  std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> broadcasted_literal,
+      literal->Broadcast(
+          /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+          /*dimensions=*/{1}));
+  EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 2}, {1, 2}}));
+}
+
+TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
+  std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(9);
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> broadcasted_literal,
+      literal->Broadcast(
+          /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+          /*dimensions=*/{}));
+  EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int32>({{9, 9}, {9, 9}}));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index fa59a5f..2a8de02 100644
@@ -859,6 +859,28 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
   return Status::OK();
 }
 
+Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
+  const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
+
+  TF_RET_CHECK(broadcast->dimensions().size() ==
+               ShapeUtil::Rank(operand.shape()))
+      << "broadcast dimensions is of size: " << broadcast->dimensions().size()
+      << " and rank of operand_to_broadcast is: "
+      << ShapeUtil::Rank(operand.shape());
+  // Checks that operand's dimensions are the same as the broadcast's
+  // dimensions along the dimensions to be broadcasted.
+  for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
+    TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
+                 operand.shape().dimensions(i));
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      evaluated_[broadcast],
+      operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
+
+  return Status::OK();
+}
+
 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   const auto result_shape = get_tuple_element->shape();
   const int64 index = get_tuple_element->tuple_index();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 566d53a..2b72ff1 100644
@@ -166,6 +166,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 
   Status HandleSelect(HloInstruction* select) override;
 
+  Status HandleBroadcast(HloInstruction* broadcast) override;
+
   // Returns the already-evaluated literal result for the instruction.
   // A Constant instruction is considered evaluated and its literal will be
   // returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index e37d651..82ee77e 100644
@@ -161,36 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     return HandleRound<ReturnT>(round);
   }
 
-  Status HandleBroadcast(HloInstruction* broadcast) override {
-    const Literal& operand_to_broadcast =
-        parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
-    std::vector<int64> broadcast_indices(
-        ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
-
-    TF_RET_CHECK(broadcast->dimensions().size() ==
-                 ShapeUtil::Rank(operand_to_broadcast.shape()))
-        << "broadcast dimensions is of size: " << broadcast->dimensions().size()
-        << " and rank of operand_to_broadcast is: "
-        << ShapeUtil::Rank(operand_to_broadcast.shape());
-    // Checks that operand's dimensions are the same as the broadcast's
-    // dimensions along the dimensions to be broadcasted.
-    for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
-      TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
-                   operand_to_broadcast.shape().dimensions(i));
-    }
-
-    auto output = MakeUnique<Literal>(broadcast->shape());
-    TF_RETURN_IF_ERROR(output->Populate<ReturnT>(
-        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
-          for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
-            broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
-          }
-          return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
-        }));
-    parent_->evaluated_[broadcast] = std::move(output);
-    return Status::OK();
-  }
-
   template <
       typename NativeT,
       typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 6f57658..cf40068 100644
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -629,6 +630,28 @@ class ShapeUtil {
         .IgnoreError();
   }
 
+  // These convenience wrappers don't take `base`, `count` and `incr`
+  // explicitly, but iterate over every element in `shape` instead.
+
+  template <typename FnType>
+  static Status ForEachIndexWithStatus(const Shape& shape,
+                                       const FnType& visitor_function) {
+    std::vector<int64> base(shape.dimensions_size());
+    std::vector<int64> incr(shape.dimensions_size(), 1);
+    return ForEachIndexWithStatus(shape, base,
+                                  /*count=*/AsInt64Slice(shape.dimensions()),
+                                  incr, visitor_function);
+  }
+
+  template <typename FnType>
+  static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
+    ForEachIndexWithStatus(shape,
+                           [&](tensorflow::gtl::ArraySlice<int64> indices) {
+                             return StatusOr<bool>(visitor_function(indices));
+                           })
+        .IgnoreError();
+  }
+
   // A parallel version of ForEachIndex(WithStatus). This can only be used if
   // the visitor_function is thread-safe and the order of iteration does not
   // matter.
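
As a hypothetical example of the new whole-shape overload, counting the
elements of a 2x3 array (the lambda returns true to continue iterating, as in
the Broadcast implementation above):

Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
int64 num_visited = 0;
ShapeUtil::ForEachIndex(
    shape, [&](tensorflow::gtl::ArraySlice<int64> index) {
      ++num_visited;  // called once per element
      return true;    // false would stop the iteration early
    });
// num_visited == 6 afterwards.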