Tile on the M dimension in GEBP.
authorSanjoy Das <sanjoy@google.com>
Tue, 29 May 2018 18:14:43 +0000 (11:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 18:20:44 +0000 (11:20 -0700)
After this change the inner reduction loop in GEBP multiplies a tile from the
LHS and a tile from the RHS to get a result tile.

PiperOrigin-RevId: 198425769

tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc

index 48bea7c..c704105 100644 (file)
@@ -52,7 +52,7 @@ class MemoryTile {
   MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
              llvm::Value* matrix, int64 matrix_size_along_minor_dim,
              llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
-      : vsl_(vsl) {
+      : vsl_(vsl), ir_builder_(ir_builder) {
     pointers_.reserve(tile_size_along_major_dim);
     for (int64 i = 0; i < tile_size_along_major_dim; i++) {
       llvm::Value* total_offset = ir_builder->CreateMul(
@@ -62,9 +62,10 @@ class MemoryTile {
     }
   }
 
-  // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at
-  // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the
-  // minor dimension.
+  // Load a tile consisting of `tile_size_along_major_dim` vectors from position
+  // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
+  //
+  // Note: `major_dim_offset` is a parameter to the constructor.
   std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
     std::vector<llvm::Value*> result;
     result.reserve(pointers_.size());
@@ -74,8 +75,42 @@ class MemoryTile {
     return result;
   }
 
+  // Stores `tile` to position {major: `major_dim_offset`, minor:
+  // `minor_dim_offset`}.
+  //
+  // Note: `major_dim_offset` is a parameter to the constructor.
+  void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
+                 llvm::Value* minor_dim_offset) const {
+    CHECK_EQ(tile.size(), pointers_.size());
+    for (int64 i = 0; i < pointers_.size(); i++) {
+      vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
+    }
+  }
+
+  // Loads a tile of size [`tile_size_along_major_dim`,
+  // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
+  // minor: `minor_dim_offset`} and then broadcasts each element into a vector
+  // of size vsl_.vector_size().  The (i,j)'th element of the return value is
+  // the (i,j)'th element in the tile broadcasted into an LLVM vector.
+  //
+  // Note: `major_dim_offset` is a parameter to the constructor.
+  std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
+      llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const {
+    std::vector<std::vector<llvm::Value*>> result;
+    result.resize(pointers_.size());
+    for (int64 i = 0; i < pointers_.size(); i++) {
+      for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
+        result[i].push_back(vsl_->LoadBroadcast(
+            pointers_[i], ir_builder_->CreateAdd(minor_dim_offset,
+                                                 ir_builder_->getInt64(j))));
+      }
+    }
+    return result;
+  }
+
  private:
   VectorSupportLibrary* vsl_;
+  llvm::IRBuilder<>* ir_builder_;
   std::vector<llvm::Value*> pointers_;
 };
 
@@ -633,38 +668,44 @@ class MatrixMatrixBlockPanelEmitter {
   // `min_vectorization_width` is the smallest vector width the emitter will use
   // -- below that it will devolve to using a scalar loop.
   //
-  // `k_tiling_factor` is the number of elements along the reduction dimensions
-  // that we will attempt to process at once.
+  // The innermost reduction loop executes the matrix multiply in tiles of size
+  // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
+  // <vectorization width>] in the RHS.
   class Config {
    public:
     explicit Config(PrimitiveType scalar_type, Dimensions dims,
                     int64 max_vectorization_width,
-                    int64 min_vectorization_width, int64 k_tiling_factor)
+                    int64 min_vectorization_width, int64 tile_size_m,
+                    int64 tile_size_k)
         : scalar_type_(scalar_type),
           dims_(dims),
           max_vectorization_width_(max_vectorization_width),
           min_vectorization_width_(min_vectorization_width),
-          k_tiling_factor_(k_tiling_factor) {}
+          tile_size_m_(tile_size_m),
+          tile_size_k_(tile_size_k) {}
 
     string GetCacheKey() const {
       return tensorflow::strings::StrCat(
           "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
           "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
-          k_tiling_factor());
+          tile_size_m(), "_", tile_size_k());
     }
 
     PrimitiveType scalar_type() const { return scalar_type_; }
     Dimensions dims() const { return dims_; }
     int64 max_vectorization_width() const { return max_vectorization_width_; }
     int64 min_vectorization_width() const { return min_vectorization_width_; }
-    int64 k_tiling_factor() const { return k_tiling_factor_; }
+
+    int64 tile_size_m() const { return tile_size_m_; }
+    int64 tile_size_k() const { return tile_size_k_; }
 
    private:
     PrimitiveType scalar_type_;
     Dimensions dims_;
     int64 max_vectorization_width_;
     int64 min_vectorization_width_;
-    int64 k_tiling_factor_;
+    int64 tile_size_m_;
+    int64 tile_size_k_;
   };
 
   // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
@@ -682,30 +723,37 @@ class MatrixMatrixBlockPanelEmitter {
           IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
     CHECK(min_vectorization_width() > 0 &&
           IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
-    CHECK_GT(k_tiling_factor(), 0);
+    CHECK_GT(tile_size_k(), 0);
   }
 
   void Emit();
 
  private:
-  // We can only iterate the `n` dimension for an extent that is divisible by
-  // the vectorization width.  So we emit an outer loop that first processes the
-  // largest extent in `n` that is divisible by max_vectorization_width, then
-  // the largest remaining extent that is divisible by max_vectorization_width /
-  // 2 etc.  This function emits that outermost loop.
-  void EmitChunkedLoopOverN();
+  // This emits a loop that loops over the `n` dimension in multiples of
+  // `max_vectorization_width` as much as possible and then emits a remainder
+  // epilogue.
+  void EmitLoopOverN();
 
   // This emits a loop that loops over the `k` dimension in multiples of
-  // `k_tiling_factor` as much as possible and then emits a remainder epilogue.
+  // `tile_size_k` as much as possible and then emits a remainder epilogue.
   void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
                      llvm::Value* n_end);
 
-  // This emits the inner reduction loop.  This inner reduction loop processes
-  // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension
-  // and [`n_start`, `n_end`) in the `n` dimension.
-  void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start,
-                     llvm::Value* k_end, llvm::Value* n_start,
-                     llvm::Value* n_end, VectorSupportLibrary* vsl);
+  // This emits a loop that loops over the `m` dimension in multiples of
+  // `tile_size_m` as much as possible and then emits a remainder epilogue.
+  void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k,
+                     llvm::Value* k_start, llvm::Value* k_end,
+                     llvm::Value* n_start, llvm::Value* n_end);
+
+  // This emits the inner reduction loop.  This inner reduction loop multiplies
+  // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the
+  // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size
+  // [`tile_size_m`, vls->vector_width()] in the result.
+  void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k,
+                              llvm::Value* k_start, llvm::Value* k_end,
+                              llvm::Value* n_start, llvm::Value* n_end,
+                              int64 tile_size_m, llvm::Value* m_start,
+                              llvm::Value* m_end);
 
   llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); }
 
@@ -718,7 +766,8 @@ class MatrixMatrixBlockPanelEmitter {
   int64 min_vectorization_width() const {
     return config().min_vectorization_width();
   }
-  int64 k_tiling_factor() const { return config().k_tiling_factor(); }
+  int64 tile_size_m() const { return config().tile_size_m(); }
+  int64 tile_size_k() const { return config().tile_size_k(); }
   PrimitiveType scalar_type() const { return config().scalar_type(); }
 
   llvm::Value* lhs_;
@@ -730,9 +779,15 @@ class MatrixMatrixBlockPanelEmitter {
   KernelSupportLibrary ksl_;
 };
 
-void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
+void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); }
+
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
+  // We can only iterate the `n` dimension for an extent that is divisible by
+  // the vectorization width.  So we emit an outer loop that first processes the
+  // largest extent in `n` that is divisible by max_vectorization_width, then
+  // the largest remaining extent that is divisible by max_vectorization_width /
+  // 2 etc.
 
-void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
   int64 current_vectorization_width = max_vectorization_width();
   int64 n_start = 0;
   while (n_start != dims().n() &&
@@ -761,16 +816,30 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
                                                   llvm::Value* n_start,
                                                   llvm::Value* n_end) {
   int64 k_start = 0;
-  int64 k_end = dims().k() - (dims().k() % k_tiling_factor());
+  int64 k_end = dims().k() - (dims().k() % tile_size_k());
   if (k_end != k_start) {
-    EmitInnerLoop(k_tiling_factor(), GetInt64(k_start), GetInt64(k_end),
-                  n_start, n_end, vsl);
+    EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
+                  n_start, n_end);
     k_start = k_end;
   }
 
   if (k_start != dims().k()) {
-    EmitInnerLoop(dims().k() - k_start, GetInt64(k_start), GetInt64(dims().k()),
-                  n_start, n_end, vsl);
+    EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start),
+                  GetInt64(dims().k()), n_start, n_end);
+  }
+}
+
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
+    VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
+  const int64 m_end = dims().m() - dims().m() % tile_size_m();
+  EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+                         tile_size_m(), GetInt64(0), GetInt64(m_end));
+
+  if (m_end != dims().m()) {
+    EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+                           dims().m() - m_end, GetInt64(m_end),
+                           GetInt64(dims().m()));
   }
 }
 
@@ -778,11 +847,11 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
 //
 // Let the LHS be:
 //
-//   +---+---+---+
-//   | a | b | c | .
-//   +---+---+---+ .
-//   |   |   |   | .
-//   +---+---+---+
+//   +----+----+----+
+//   | a0 | b0 | c0 | .
+//   +----+----+----+ .
+//   | a1 | b1 | c1 | .
+//   +----+----+----+
 //     ..     ..
 //
 // and the RHS be:
@@ -796,72 +865,77 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
 //   +----+----+----+----+ .
 //     ......    ......
 //
-// and let k_tiling_factor be 3 and the vector width (implicitly denoted by
-// `vsl`) be 4.
+// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
+// by `vsl`) be 4.  Then we want to matrix multiply this tile to get a [2,4]
+// matrix that we can increment the result matrix by.
 //
-// Then we
+// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank
+// 3 array, L, of dimension [2,3,4]:
 //
-//  1. broadcast the first row in LHS to 3 vectors of width 4
-//  2. elementwise multiply the RHS rows with these broadcasted vectors
-//  3. elementwise add them:
+//       L[0,_,_]           *      L[1,_,_]
+//                          *
+//   +----+----+----+----+  *  +----+----+----+----+
+//   | a0 | a0 | a0 | a0 |  *  | a1 | a1 | a1 | a1 |
+//   +----+----+----+----+  *  +----+----+----+----+
+//   | b0 | b0 | b0 | b0 |  *  | b1 | b1 | b1 | b1 |
+//   +----+----+----+----+  *  +----+----+----+----+
+//   | c0 | c0 | c0 | c0 |  *  | c1 | c1 | c1 | c1 |
+//   +----+----+----+----+  *  +----+----+----+----+
 //
-//   +---+---+---+---+   +----+----+----+----+
-//   | a | a | a | a | * | p0 | p1 | p2 | p3 |   +
-//   +---+---+---+---+   +----+----+----+----+
 //
-//   +---+---+---+---+   +----+----+----+----+
-//   | b | b | b | b | * | q0 | q1 | q2 | q3 |   +
-//   +---+---+---+---+   +----+----+----+----+
+// Then we FMA L[0,_,_] with the RHS to get the first row of the result and
+// L[1,_,_] with the RHS to get the second row of the result.  For example,
+// L[0,_,_] is computed as:
 //
-//   +---+---+---+---+   +----+----+----+----+
-//   | c | c | c | c | * | r0 | r1 | r2 | r3 |
-//   +---+---+---+---+   +----+----+----+----+
+//   +----+----+----+----+   +----+----+----+----+
+//   | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 |   +
+//   +----+----+----+----+   +----+----+----+----+
 //
-// to get:
+//   +----+----+----+----+   +----+----+----+----+
+//   | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 |   +
+//   +----+----+----+----+   +----+----+----+----+
 //
-//   +----------------+----------------+----------------+----------------+
-//   | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 |
-//   +----------------+----------------+----------------+----------------+
+//   +----+----+----+----+   +----+----+----+----+
+//   | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
+//   +----+----+----+----+   +----+----+----+----+
 //
-// which we increment into the appropriate region in the result.
-void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
-    int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
-    llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
-  ksl_.For("dot.m", 0, dims().m(), 1, [&](llvm::Value* m_i) {
-    // This outer loop iterates over all of the M dimension
-    llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
-        result_, /*offset_elements=*/m_i, /*scale=*/dims().n());
-    llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer(
-        lhs_, /*offset_elements=*/m_i, /*scale=*/dims().k());
-
-    ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
-      // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
-      // <b,b,b,b> etc. in the diagram.
-      std::vector<llvm::Value*> broadcasted_a;
-      broadcasted_a.reserve(k_tiling_factor);
-      for (int i = 0; i < k_tiling_factor; i++) {
-        broadcasted_a.push_back(vsl->LoadBroadcast(
-            lhs_row_begin, ir_builder_->CreateAdd(GetInt64(i), k_i)));
-      }
-
-      // rhs_loader will be used to load the tile off of the RHS, denoted as
-      // <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
-      MemoryTile rhs_loader(vsl, ir_builder_, rhs_, dims().n(), k_i,
-                            k_tiling_factor);
+// to get:
+//
+//   +-------------------+-------------------+-------------------+---------
+//   | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 |  ...
+//   +-------------------+-------------------+-------------------+---------
+void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop(
+    VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
+    int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
+  ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
+    MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_,
+                                  /*matrix_size_along_minor_dim=*/dims().n(),
+                                  /*major_dim_offset=*/m_i,
+                                  /*tile_size_along_major_dim=*/tile_size_m);
+    MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
+                               /*matrix_size_along_minor_dim=*/dims().k(),
+                               /*major_dim_offset=*/m_i,
+                               /*tile_size_along_major_dim=*/tile_size_m);
+
+    ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
+      MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i,
+                                 tile_size_k);
+      std::vector<std::vector<llvm::Value*>> lhs_tile =
+          lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
       ksl_.For(
           "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
-            // This loop iterates over the N dimension.  It loads the tile from
-            // RHS, does the FMA resulting in the
-            // <a*p0+b*q0+c*r0,a*p1+b*q1+c*r1,...> in the diagram and increments
-            // the result.
-            std::vector<llvm::Value*> tile = rhs_loader.LoadTile(n_i);
-            llvm::Value* result_accumulator =
-                vsl->LoadVector(result_row_begin, n_i);
-            for (int i = 0; i < tile.size(); i++) {
-              result_accumulator =
-                  vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator);
+            std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
+            std::vector<llvm::Value*> result_tile =
+                result_memory_tile.LoadTile(n_i);
+            for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
+              for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
+                result_tile[r_m_i] =
+                    vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
+                                result_tile[r_m_i]);
+              }
             }
-            vsl->StoreVector(result_accumulator, result_row_begin, n_i);
+            result_memory_tile.StoreTile(result_tile, n_i);
           });
     });
   });
@@ -955,7 +1029,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
       /*scalar_type=*/primitive_type,
       MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
       /*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
-      /*k_tiling_factor=*/8);
+      /*tile_size_m=*/3, /*tile_size_k=*/8);
 
   const bool enable_fast_math =
       hlo_module_config_.debug_options().xla_enable_fast_math();