From 4a920dcea404a34d9edc558e1360fe775dcabbe1 Mon Sep 17 00:00:00 2001 From: Sanjoy Das <sanjoy@google.com> Date: Tue, 29 May 2018 11:14:43 -0700 Subject: [PATCH] Tile on the M dimension in GEBP. After this change the inner reduction loop in GEBP multiplies a tile from the LHS and a tile from the RHS to get a result tile. PiperOrigin-RevId: 198425769 --- .../compiler/xla/service/cpu/dot_op_emitter.cc | 262 +++++++++++++-------- 1 file changed, 168 insertions(+), 94 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 48bea7c..c704105 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -52,7 +52,7 @@ class MemoryTile { MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { + : vsl_(vsl), ir_builder_(ir_builder) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { llvm::Value* total_offset = ir_builder->CreateMul( @@ -62,9 +62,10 @@ class MemoryTile { } } - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const { std::vector<llvm::Value*> result; result.reserve(pointers_.size()); @@ -74,8 +75,42 @@ class MemoryTile { return result; } + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. 
+ void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector<std::vector<llvm::Value*>> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector<std::vector<llvm::Value*>> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, + ir_builder_->getInt64(j)))); + } + } + return result; + } + private: VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* ir_builder_; std::vector<llvm::Value*> pointers_; }; @@ -633,38 +668,44 @@ class MatrixMatrixBlockPanelEmitter { // `min_vectorization_width` is the smallest vector width the emitter will use // -- below that it will devolve to using a scalar loop. // - // `k_tiling_factor` is the number of elements along the reduction dimensions - // that we will attempt to process at once. + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // <vectorization width>] in the RHS. 
class Config { public: explicit Config(PrimitiveType scalar_type, Dimensions dims, int64 max_vectorization_width, - int64 min_vectorization_width, int64 k_tiling_factor) + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) : scalar_type_(scalar_type), dims_(dims), max_vectorization_width_(max_vectorization_width), min_vectorization_width_(min_vectorization_width), - k_tiling_factor_(k_tiling_factor) {} + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} string GetCacheKey() const { return tensorflow::strings::StrCat( "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), "_", max_vectorization_width(), "_", min_vectorization_width(), "_", - k_tiling_factor()); + tile_size_m(), "_", tile_size_k()); } PrimitiveType scalar_type() const { return scalar_type_; } Dimensions dims() const { return dims_; } int64 max_vectorization_width() const { return max_vectorization_width_; } int64 min_vectorization_width() const { return min_vectorization_width_; } - int64 k_tiling_factor() const { return k_tiling_factor_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } private: PrimitiveType scalar_type_; Dimensions dims_; int64 max_vectorization_width_; int64 min_vectorization_width_; - int64 k_tiling_factor_; + int64 tile_size_m_; + int64 tile_size_k_; }; // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies @@ -682,30 +723,37 @@ class MatrixMatrixBlockPanelEmitter { IsPowerOfTwo(static_cast<uint64>(max_vectorization_width()))); CHECK(min_vectorization_width() > 0 && IsPowerOfTwo(static_cast<uint64>(min_vectorization_width()))); - CHECK_GT(k_tiling_factor(), 0); + CHECK_GT(tile_size_k(), 0); } void Emit(); private: - // We can only iterate the `n` dimension for an extent that is divisible by - // the vectorization width. 
So we emit an outer loop that first processes the - // largest extent in `n` that is divisible by max_vectorization_width, then - // the largest remaining extent that is divisible by max_vectorization_width / - // 2 etc. This function emits that outermost loop. - void EmitChunkedLoopOverN(); + // This emits a loop that loops over the `n` dimension in multiples of + // `max_vectorization_width` as much as possible and then emits a remainder + // epilogue. + void EmitLoopOverN(); // This emits a loop that loops over the `k` dimension in multiples of - // `k_tiling_factor` as much as possible and then emits a remainder epilogue. + // `tile_size_k` as much as possible and then emits a remainder epilogue. void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, llvm::Value* n_end); - // This emits the inner reduction loop. This inner reduction loop processes - // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension - // and [`n_start`, `n_end`) in the `n` dimension. - void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, - llvm::Value* n_end, VectorSupportLibrary* vsl); + // This emits a loop that loops over the `m` dimension in multiples of + // `tile_size_m` as much as possible and then emits a remainder epilogue. + void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits the inner reduction loop. This inner reduction loop multiplies + // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the + // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size + // [`tile_size_m`, vls->vector_width()] in the result. 
+ void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } @@ -718,7 +766,8 @@ class MatrixMatrixBlockPanelEmitter { int64 min_vectorization_width() const { return config().min_vectorization_width(); } - int64 k_tiling_factor() const { return config().k_tiling_factor(); } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } PrimitiveType scalar_type() const { return config().scalar_type(); } llvm::Value* lhs_; @@ -730,9 +779,15 @@ class MatrixMatrixBlockPanelEmitter { KernelSupportLibrary ksl_; }; -void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); } +void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); } + +void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. 
-void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { int64 current_vectorization_width = max_vectorization_width(); int64 n_start = 0; while (n_start != dims().n() && @@ -761,16 +816,30 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, llvm::Value* n_end) { int64 k_start = 0; - int64 k_end = dims().k() - (dims().k() % k_tiling_factor()); + int64 k_end = dims().k() - (dims().k() % tile_size_k()); if (k_end != k_start) { - EmitInnerLoop(k_tiling_factor(), GetInt64(k_start), GetInt64(k_end), - n_start, n_end, vsl); + EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); k_start = k_end; } if (k_start != dims().k()) { - EmitInnerLoop(dims().k() - k_start, GetInt64(k_start), GetInt64(dims().k()), - n_start, n_end, vsl); + EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void MatrixMatrixBlockPanelEmitter::EmitLoopOverM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, + tile_size_m(), GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), + GetInt64(dims().m())); } } @@ -778,11 +847,11 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, // // Let the LHS be: // -// +---+---+---+ -// | a | b | c | . -// +---+---+---+ . -// | | | | . -// +---+---+---+ +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ // .. .. // // and the RHS be: @@ -796,72 +865,77 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, // +----+----+----+----+ . // ...... ...... 
// -// and let k_tiling_factor be 3 and the vector width (implicitly denoted by -// `vsl`) be 4. +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. // -// Then we +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: // -// 1. broadcast the first row in LHS to 3 vectors of width 4 -// 2. elementwise multiply the RHS rows with these broadcasted vectors -// 3. elementwise add them: +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ // -// +---+---+---+---+ +----+----+----+----+ -// | a | a | a | a | * | p0 | p1 | p2 | p3 | + -// +---+---+---+---+ +----+----+----+----+ // -// +---+---+---+---+ +----+----+----+----+ -// | b | b | b | b | * | q0 | q1 | q2 | q3 | + -// +---+---+---+---+ +----+----+----+----+ +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. 
For example, + L[0,_,_] is computed as: // -// +----+----+----+----+ +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + -// +----+----+----+----+ +----+----+----+----+ // -// +----+----+----+----+ +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + -// +----+----+----+----+ +----+----+----+----+ +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ // -// to get: +// -// +----------------+----------------+----------------+----------------+ -// | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 | -// +----------------+----------------+----------------+----------------+ +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ // -// which we increment into the appropriate region in the result. -void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( - int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) { - ksl_.For("dot.m", 0, dims().m(), 1, [&](llvm::Value* m_i) { - // This outer loop iterates over all of the M dimension - llvm::Value* result_row_begin = vsl->ComputeOffsetPointer( - result_, /*offset_elements=*/m_i, /*scale=*/dims().n()); - llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer( - lhs_, /*offset_elements=*/m_i, /*scale=*/dims().k()); - - ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) { - // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>, - // <b,b,b,b> etc. in the diagram. - std::vector<llvm::Value*> broadcasted_a; - broadcasted_a.reserve(k_tiling_factor); - for (int i = 0; i < k_tiling_factor; i++) { - broadcasted_a.push_back(vsl->LoadBroadcast( - lhs_row_begin, ir_builder_->CreateAdd(GetInt64(i), k_i))); - } - - // rhs_loader will be used to load the tile off of the RHS, denoted as - // <<p0, p1, p2, p3>, ...> in the diagram. 
- MemoryTile rhs_loader(vsl, ir_builder_, rhs_, dims().n(), k_i, - k_tiling_factor); +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + + ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i, + tile_size_k); + std::vector<std::vector<llvm::Value*>> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); ksl_.For( "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - // This loop iterates over the N dimension. It loads the tile from - // RHS, does the FMA resulting in the - // <a*p0+b*q0+c*r0, ...> in the diagram and increments - // the result. 
- std::vector<llvm::Value*> tile = rhs_loader.LoadTile(n_i); - llvm::Value* result_accumulator = - vsl->LoadVector(result_row_begin, n_i); - for (int i = 0; i < tile.size(); i++) { - result_accumulator = - vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator); + std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i); + std::vector<llvm::Value*> result_tile = + result_memory_tile.LoadTile(n_i); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } } - vsl->StoreVector(result_accumulator, result_row_begin, n_i); + result_memory_tile.StoreTile(result_tile, n_i); }); }); }); @@ -955,7 +1029,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( /*scalar_type=*/primitive_type, MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, /*max_vectorization_width=*/8, /*min_vectorization_width=*/4, - /*k_tiling_factor=*/8); + /*tile_size_m=*/3, /*tile_size_k=*/8); const bool enable_fast_math = hlo_module_config_.debug_options().xla_enable_fast_math(); -- 2.7.4