[XLA] Parallelize HloEvaluator::HandleConvolution
author Michael Kuperstein <mkuper@google.com>
Sun, 8 Apr 2018 22:37:26 +0000 (15:37 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sun, 8 Apr 2018 22:39:50 +0000 (15:39 -0700)
This adds a parallel version of Literal::Populate and uses it in the embarrassingly parallel convolution computation in HloEvaluator::HandleConvolution. A usage sketch follows the list of affected files below.

PiperOrigin-RevId: 192065277

tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/literal_util_test.cc
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/shape_util.h
tensorflow/compiler/xla/shape_util_test.cc
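
A minimal usage sketch of the new entry point (illustrative only, not part of
the change; it assumes the declarations added to literal_util.h below). The
generator reads nothing but its index argument, so concurrent invocations
cannot race, which is all PopulateParallel requires:

    // Fill an F32 literal in parallel with a pure, thread-safe generator.
    Shape shape = ShapeUtil::MakeShape(F32, {128, 256});
    auto literal = Literal::CreateFromShape(shape);
    TF_CHECK_OK(literal->PopulateParallel<float>(
        [](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
          // The value is derived only from the element's own index.
          return static_cast<float>(indexes[0] * 256 + indexes[1]);
        }));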

diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index a96a76f..33abbdb 100644
@@ -587,6 +587,12 @@ class Literal {
   template <typename NativeT, typename FnType>
   Status Populate(const FnType& generator);
 
+  // A parallel version of Populate(). This can be used if the generator is
+  // thread-safe and the values for the shape's different elements are
+  // independent.
+  template <typename NativeT, typename FnType>
+  Status PopulateParallel(const FnType& generator);
+
   // Fills this literal with the given value.
   template <typename NativeT>
   void PopulateWithValue(NativeT value);
@@ -785,6 +791,10 @@ class Literal {
   // buffer).
   void DeallocateBuffers();
 
+  // Implementation details shared between Populate() and PopulateParallel().
+  template <typename NativeT, typename FnType>
+  Status PopulateInternal(const FnType& generator, bool parallel);
+
   Shape shape_;
   ShapeTree<Piece> pieces_;
 
@@ -1276,7 +1286,7 @@ void Literal::PopulateSparse(SparseIndexArray indices,
 }
 
 template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
+Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
   const Shape& this_shape = shape();
   const int64 rank = ShapeUtil::Rank(this_shape);
   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
@@ -1286,11 +1296,11 @@ Status Literal::Populate(const FnType& generator) {
   if (rank > 0) {
     StrideConfig stride_config(this_shape, this_shape,
                                AsInt64Slice(this_shape.dimensions()));
-    DimensionVector minor_scan_indexes(rank, 0);
     int64 minor_dimension_size =
         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
 
     auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+      DimensionVector minor_scan_indexes(rank, 0);
       const int64 index =
           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
@@ -1298,17 +1308,35 @@ Status Literal::Populate(const FnType& generator) {
         minor_scan_indexes[stride_config.minor_dimension] = i;
         literal_data.at(index + i) = generator(minor_scan_indexes);
       }
-      return true;
     };
-    ShapeUtil::ForEachIndex(this_shape, stride_config.base,
-                            stride_config.dimensions, stride_config.step,
-                            init_function);
+    if (parallel) {
+      ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
+                                      stride_config.dimensions,
+                                      stride_config.step, init_function);
+    } else {
+      ShapeUtil::ForEachIndex(
+          this_shape, stride_config.base, stride_config.dimensions,
+          stride_config.step,
+          [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+            init_function(indexes);
+            return true;
+          });
+    }
   } else {
     // For scalars.
     literal_data.at(0) = generator({});
   }
   return Status::OK();
 }
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
+  return PopulateInternal<NativeT>(generator, /*parallel=*/false);
+}
+
+template <typename NativeT, typename FnType>
+Status Literal::PopulateParallel(const FnType& generator) {
+  return PopulateInternal<NativeT>(generator, /*parallel=*/true);
+}
 
 template <typename NativeT>
 void Literal::PopulateWithValue(NativeT value) {
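
Since both entry points now funnel into PopulateInternal, a pure generator
must produce identical literals either way; a quick equivalence check in the
style of the unit tests (hypothetical, not part of the change):

    auto gen = [](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
      return static_cast<uint32>(indexes[0] + 17);
    };
    auto serial = Literal::CreateFromShape(ShapeUtil::MakeShape(U32, {16}));
    auto parallel = Literal::CreateFromShape(ShapeUtil::MakeShape(U32, {16}));
    TF_CHECK_OK(serial->Populate<uint32>(gen));
    TF_CHECK_OK(parallel->PopulateParallel<uint32>(gen));
    CHECK(*serial == *parallel);  // Same values; only the schedule differs.
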
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 7627762..8b000f4 100644
@@ -1090,6 +1090,48 @@ TEST_F(LiteralUtilTest, Populate) {
   }
 }
 
+TEST_F(LiteralUtilTest, PopulateParallel) {
+  struct PopulateData {
+    std::vector<int64> dimensions;
+    std::vector<int64> layout;
+  } populate_data[] = {
+      {{}, {}},
+      {{0}, {0}},
+      {{16}, {0}},
+      {{2, 0}, {1, 0}},
+      {{4, 16}, {1, 0}},
+      {{21, 12}, {0, 1}},
+      {{6, 11, 17}, {2, 0, 1}},
+      {{6, 11, 5, 17}, {3, 2, 0, 1}},
+  };
+  for (const auto& data : populate_data) {
+    Shape shape = ShapeUtil::MakeShapeWithLayout(
+        primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
+        data.layout);
+    auto literal = Literal::CreateFromShape(shape);
+    auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
+      // Offset from the linear index just to avoid R0 literals being
+      // initialized with zero.
+      return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+                                                           indexes) +
+             17;
+    };
+    TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+
+    std::vector<int64> zero_base(data.dimensions.size(), 0);
+    std::vector<int64> step(data.dimensions.size(), 1);
+    bool matched = true;
+    auto check_function = [&](ArraySlice<int64> indexes) {
+      auto value = literal->Get<uint32>(indexes);
+      matched = matched && (value == generator(indexes));
+      return matched;
+    };
+    ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+                            check_function);
+    EXPECT_TRUE(matched);
+  }
+}
+
 TEST_F(LiteralUtilTest, ConvertR4) {
   // clang-format off
   auto original = Literal::CreateR4WithLayout<int8>({{
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 53ad890..b24757c 100644
@@ -998,18 +998,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
 
-    // Dimension number applicable for input (lhs).
-    const int64 input_batch_dim = dnums.input_batch_dimension();
-    const int64 input_z_dim = dnums.input_feature_dimension();
-    // Dimension number applicable for kernel (rhs).
-    const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
-    const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
-    // Dimension number applicable for output.
-    const int64 output_batch_dim = dnums.output_batch_dimension();
-    const int64 output_z_dim = dnums.output_feature_dimension();
-
-    const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
-
     std::vector<int64> window_dimension_sizes;
     for (auto i : dnums.kernel_spatial_dimensions()) {
       window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
@@ -1021,14 +1009,27 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
     DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
 
-    DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
-
     auto lhs_literal_data = lhs_literal.data<ReturnT>();
     auto rhs_literal_data = rhs_literal.data<ReturnT>();
 
-    auto func = [&](ArraySlice<int64> out_index) {
+    auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
+                 &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
+                 rhs_literal_data](ArraySlice<int64> out_index) {
+      // Dimension number applicable for input (lhs).
+      const int64 input_batch_dim = dnums.input_batch_dimension();
+      const int64 input_z_dim = dnums.input_feature_dimension();
+      // Dimension number applicable for kernel (rhs).
+      const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
+      const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
+      // Dimension number applicable for output.
+      const int64 output_batch_dim = dnums.output_batch_dimension();
+      const int64 output_z_dim = dnums.output_feature_dimension();
+
+      const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+
       ElementwiseT result_val = static_cast<ElementwiseT>(0);
-      std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
+      DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
+                                        0);
 
       // Convolve input feature with kernel.
       do {
@@ -1100,7 +1101,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     };
 
     auto result = Literal::CreateFromShape(result_shape);
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+    TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
 
     parent_->evaluated_[conv] = std::move(result);
     return Status::OK();
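
The lambda rewrite above is what makes the PopulateParallel call at the end
legal: all mutable state (rhs_spatial_index, result_val) became per-invocation
locals, and the explicit capture list pins down that everything captured is
only read. Condensed to its skeleton (identifiers as in the diff, the
convolution body elided):

    auto func = [&dnums, lhs_literal_data, rhs_literal_data](
                    ArraySlice<int64> out_index) {
      // Per-call scratch: safe even when many threads run func concurrently.
      DimensionVector rhs_spatial_index(
          dnums.kernel_spatial_dimensions_size(), 0);
      ElementwiseT result_val = static_cast<ElementwiseT>(0);
      // ... read-only accumulation over lhs_literal_data/rhs_literal_data ...
      return static_cast<ReturnT>(result_val);
    };
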
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 3e130a0..b9becf6 100644
@@ -28,8 +28,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -583,34 +585,7 @@ class ShapeUtil {
                                        tensorflow::gtl::ArraySlice<int64> count,
                                        tensorflow::gtl::ArraySlice<int64> incr,
                                        const FnType& visitor_function) {
-    if (ShapeUtil::HasZeroElements(shape)) {
-      return Status::OK();
-    }
-    CHECK_EQ(Rank(shape), base.size());
-    CHECK_EQ(incr.size(), base.size());
-    CHECK_EQ(count.size(), base.size());
-    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
-    // Allows handling R0 arrays, such that the visitor function will be called
-    // once with the proper empty indexes.
-    int64 n = -1;
-    std::vector<int64> indexes(base.begin(), base.end());
-    while (n < rank) {
-      TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
-      if (!should_continue) {
-        break;
-      }
-      // Increments dimensions in minor to major order.
-      for (n = 0; n < rank; ++n) {
-        int64 dim = LayoutUtil::Minor(shape.layout(), n);
-        indexes[dim] += incr[dim];
-        if (indexes[dim] < base[dim] + count[dim]) {
-          break;
-        }
-        indexes[dim] = base[dim];
-      }
-    }
-
-    return Status::OK();
+    return ForEachIndexInternal(shape, base, count, incr, visitor_function);
   }
 
   // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
@@ -642,11 +617,79 @@ class ShapeUtil {
         .IgnoreError();
   }
 
+  // A parallel version of ForEachIndex(WithStatus). This can only be used if
+  // the visitor_function is thread-safe and the order of iteration does not
+  // matter.
+  //
+  // visitor_function must be a callable of type
+  // void(ArraySlice<int64>) or compatible.
+  template <typename FnType>
+  static void ForEachIndexParallel(const Shape& shape,
+                                   tensorflow::gtl::ArraySlice<int64> base,
+                                   tensorflow::gtl::ArraySlice<int64> count,
+                                   tensorflow::gtl::ArraySlice<int64> incr,
+                                   const FnType& visitor_function) {
+    const int kNumThreads = tensorflow::port::NumSchedulableCPUs();
+    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test",
+                                        kNumThreads);
+    // If a pool is provided, ForEachIndexInternal can never fail.
+    CHECK(ForEachIndexInternal(
+              shape, base, count, incr,
+              [&visitor_function](tensorflow::gtl::ArraySlice<int64> indexes)
+                  -> StatusOr<bool> {
+                visitor_function(indexes);
+                return true;
+              },
+              &pool)
+              .ok());
+  }
+
  private:
   // Validates all of the non-layout properties of the shape -- this is a helper
   // used by both the layout-optional and layout-required public methods.
   static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
 
+  template <typename FnType>
+  static Status ForEachIndexInternal(
+      const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
+      tensorflow::gtl::ArraySlice<int64> count,
+      tensorflow::gtl::ArraySlice<int64> incr, const FnType& visitor_function,
+      tensorflow::thread::ThreadPool* pool = nullptr) {
+    if (ShapeUtil::HasZeroElements(shape)) {
+      return Status::OK();
+    }
+    CHECK_EQ(Rank(shape), base.size());
+    CHECK_EQ(incr.size(), base.size());
+    CHECK_EQ(count.size(), base.size());
+    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
+    // Allows handling R0 arrays, such that the visitor function will be called
+    // once with the proper empty indexes.
+    int64 n = -1;
+    std::vector<int64> indexes(base.begin(), base.end());
+    while (n < rank) {
+      if (pool != nullptr) {
+        pool->Schedule(
+            [indexes, visitor_function] { visitor_function(indexes); });
+      } else {
+        TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
+        if (!should_continue) {
+          break;
+        }
+      }
+      // Increments dimensions in minor to major order.
+      for (n = 0; n < rank; ++n) {
+        int64 dim = LayoutUtil::Minor(shape.layout(), n);
+        indexes[dim] += incr[dim];
+        if (indexes[dim] < base[dim] + count[dim]) {
+          break;
+        }
+        indexes[dim] = base[dim];
+      }
+    }
+
+    return Status::OK();
+  }
+
   TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
 };
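
One property worth calling out: the ThreadPool above is local to
ForEachIndexParallel, and tensorflow::thread::ThreadPool's destructor waits
for all scheduled work, so the call is synchronous from the caller's point of
view despite scheduling one closure per index. A sketch (hypothetical,
mirroring the test added below; assumes <atomic>):

    std::atomic<int64> visited(0);
    ShapeUtil::ForEachIndexParallel(
        ShapeUtil::MakeShape(F32, {10, 10}), /*base=*/{0, 0},
        /*count=*/{10, 10}, /*incr=*/{1, 1},
        [&visited](tensorflow::gtl::ArraySlice<int64> indexes) {
          visited.fetch_add(1);  // The visitor must be thread-safe.
        });
    CHECK_EQ(visited.load(), 100);  // All closures finished before return.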
 
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 424cfe3..13582a2 100644
@@ -624,6 +624,24 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) {
   EXPECT_EQ(invocations, 5);
 }
 
+TEST(ShapeUtilTest, ForEachIndexParallel) {
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
+  int64 output[10][10];
+  int init = 5;
+  auto set_func = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+    output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1];
+  };
+
+  ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 10},
+                                  /*incr=*/{1, 1}, set_func);
+
+  for (int i = 0; i < 10; ++i) {
+    for (int j = 0; j < 10; ++j) {
+      EXPECT_EQ(output[i][j], init + i + j);
+    }
+  }
+}
+
 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
   // All output dimensions should be unmodified. One of the input dimensions is
   // modified because the input rank is larger by one.