MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
llvm::Value* matrix, int64 matrix_size_along_minor_dim,
llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
- : vsl_(vsl) {
+ : vsl_(vsl), ir_builder_(ir_builder) {
pointers_.reserve(tile_size_along_major_dim);
for (int64 i = 0; i < tile_size_along_major_dim; i++) {
llvm::Value* total_offset = ir_builder->CreateMul(
    ir_builder->getInt64(matrix_size_along_minor_dim),
    ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
pointers_.push_back(vsl->ComputeOffsetPointer(matrix, total_offset));
}
}
- // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at
- // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the
- // minor dimension.
+ // Loads a tile consisting of `tile_size_along_major_dim` vectors from position
+ // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
+ //
+ // Note: `major_dim_offset` is a parameter to the constructor.
std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
std::vector<llvm::Value*> result;
result.reserve(pointers_.size());
for (int64 i = 0; i < pointers_.size(); i++) {
  result.push_back(vsl_->LoadVector(pointers_[i], minor_dim_offset));
}
return result;
}
+ // Stores `tile` to position {major: `major_dim_offset`, minor:
+ // `minor_dim_offset`}.
+ //
+ // Note: `major_dim_offset` is a parameter to the constructor.
+ void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
+ llvm::Value* minor_dim_offset) const {
+ CHECK_EQ(tile.size(), pointers_.size());
+ for (int64 i = 0; i < pointers_.size(); i++) {
+ vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
+ }
+ }
+
+ // Loads a tile of size [`tile_size_along_major_dim`,
+ // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
+ // minor: `minor_dim_offset`} and then broadcasts each element into a vector
+ // of size vsl_->vector_size(). The (i,j)'th element of the return value is
+ // the (i,j)'th element in the tile broadcasted into an LLVM vector.
+ //
+ // Note: `major_dim_offset` is a parameter to the constructor.
+ std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
+ llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const {
+ std::vector<std::vector<llvm::Value*>> result;
+ result.resize(pointers_.size());
+ for (int64 i = 0; i < pointers_.size(); i++) {
+ for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
+ result[i].push_back(vsl_->LoadBroadcast(
+ pointers_[i], ir_builder_->CreateAdd(minor_dim_offset,
+ ir_builder_->getInt64(j))));
+ }
+ }
+ return result;
+ }
+
private:
VectorSupportLibrary* vsl_;
+ llvm::IRBuilder<>* ir_builder_;
std::vector<llvm::Value*> pointers_;
};
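+
+// Illustrative sketch (for exposition only; `mat` is a hypothetical row-major
+// matrix of shape [major, minor]): the tile above boils down to one
+// precomputed row base pointer per tile row,
+//
+//   pointers_[i]   ~ &mat[major_dim_offset + i][0]
+//   LoadTile(j)[i] ~ mat[major_dim_offset + i][j .. j + vector_size)
+//
+// so a LoadTile/StoreTile pair touches a [tile_size_along_major_dim,
+// vector_size] rectangle of `mat`.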
// `min_vectorization_width` is the smallest vector width the emitter will use
// -- below that it will devolve to using a scalar loop.
//
- // `k_tiling_factor` is the number of elements along the reduction dimensions
- // that we will attempt to process at once.
+ // The innermost reduction loop executes the matrix multiply in tiles of size
+ // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
+ // <vectorization width>] from the RHS.
class Config {
public:
explicit Config(PrimitiveType scalar_type, Dimensions dims,
int64 max_vectorization_width,
- int64 min_vectorization_width, int64 k_tiling_factor)
+ int64 min_vectorization_width, int64 tile_size_m,
+ int64 tile_size_k)
: scalar_type_(scalar_type),
dims_(dims),
max_vectorization_width_(max_vectorization_width),
min_vectorization_width_(min_vectorization_width),
- k_tiling_factor_(k_tiling_factor) {}
+ tile_size_m_(tile_size_m),
+ tile_size_k_(tile_size_k) {}
string GetCacheKey() const {
return tensorflow::strings::StrCat(
"gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
"_", max_vectorization_width(), "_", min_vectorization_width(), "_",
- k_tiling_factor());
+ tile_size_m(), "_", tile_size_k());
}
PrimitiveType scalar_type() const { return scalar_type_; }
Dimensions dims() const { return dims_; }
int64 max_vectorization_width() const { return max_vectorization_width_; }
int64 min_vectorization_width() const { return min_vectorization_width_; }
- int64 k_tiling_factor() const { return k_tiling_factor_; }
+
+ int64 tile_size_m() const { return tile_size_m_; }
+ int64 tile_size_k() const { return tile_size_k_; }
private:
PrimitiveType scalar_type_;
Dimensions dims_;
int64 max_vectorization_width_;
int64 min_vectorization_width_;
- int64 k_tiling_factor_;
+ int64 tile_size_m_;
+ int64 tile_size_k_;
};
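+
+// For example, with the values used when this emitter is instantiated below
+// (max_vectorization_width=8, min_vectorization_width=4, tile_size_m=3,
+// tile_size_k=8), GetCacheKey() ends in "_8_4_3_8"; the segment before that
+// comes from Dimensions::ToString().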
// Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
CHECK(min_vectorization_width() > 0 &&
IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
- CHECK_GT(k_tiling_factor(), 0);
+ CHECK_GT(tile_size_k(), 0);
}
void Emit();
private:
- // We can only iterate the `n` dimension for an extent that is divisible by
- // the vectorization width. So we emit an outer loop that first processes the
- // largest extent in `n` that is divisible by max_vectorization_width, then
- // the largest remaining extent that is divisible by max_vectorization_width /
- // 2 etc. This function emits that outermost loop.
- void EmitChunkedLoopOverN();
+ // This emits a loop that loops over the `n` dimension in multiples of
+ // `max_vectorization_width` as much as possible and then emits a remainder
+ // epilogue.
+ void EmitLoopOverN();
// This emits a loop that loops over the `k` dimension in multiples of
- // `k_tiling_factor` as much as possible and then emits a remainder epilogue.
+ // `tile_size_k` as much as possible and then emits a remainder epilogue.
void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
llvm::Value* n_end);
- // This emits the inner reduction loop. This inner reduction loop processes
- // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension
- // and [`n_start`, `n_end`) in the `n` dimension.
- void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start,
- llvm::Value* k_end, llvm::Value* n_start,
- llvm::Value* n_end, VectorSupportLibrary* vsl);
+ // This emits a loop that loops over the `m` dimension in multiples of
+ // `tile_size_m` as much as possible and then emits a remainder epilogue.
+ void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end);
+
+ // This emits the inner reduction loop, which multiplies a tile from the LHS
+ // of size [`tile_size_m`, `tile_size_k`] with a tile from the RHS of size
+ // [`tile_size_k`, vsl->vector_size()] to update a tile of size
+ // [`tile_size_m`, vsl->vector_size()] in the result.
+ void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end,
+ int64 tile_size_m, llvm::Value* m_start,
+ llvm::Value* m_end);
llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); }
int64 min_vectorization_width() const {
return config().min_vectorization_width();
}
- int64 k_tiling_factor() const { return config().k_tiling_factor(); }
+ int64 tile_size_m() const { return config().tile_size_m(); }
+ int64 tile_size_k() const { return config().tile_size_k(); }
PrimitiveType scalar_type() const { return config().scalar_type(); }
llvm::Value* lhs_;
KernelSupportLibrary ksl_;
};
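+
+// For orientation, the emission call tree is:
+//
+//   Emit()
+//     EmitLoopOverN()            // chunks `n` by decreasing vector widths
+//       EmitLoopOverK()          // splits `k` into tile_size_k tiles + remainder
+//         EmitLoopOverM()        // splits `m` into tile_size_m tiles + remainder
+//           EmitTiledReductionLoop()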
-void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
+void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); }
+
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
+ // We can only iterate the `n` dimension for an extent that is divisible by
+ // the vectorization width. So we emit an outer loop that first processes the
+ // largest extent in `n` that is divisible by max_vectorization_width, then
+ // the largest remaining extent that is divisible by max_vectorization_width /
+ // 2, and so on.
-void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
int64 current_vectorization_width = max_vectorization_width();
int64 n_start = 0;
while (n_start != dims().n() &&
llvm::Value* n_start,
llvm::Value* n_end) {
int64 k_start = 0;
- int64 k_end = dims().k() - (dims().k() % k_tiling_factor());
+ int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
- EmitInnerLoop(k_tiling_factor(), GetInt64(k_start), GetInt64(k_end),
- n_start, n_end, vsl);
+ EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
+ n_start, n_end);
k_start = k_end;
}
if (k_start != dims().k()) {
- EmitInnerLoop(dims().k() - k_start, GetInt64(k_start), GetInt64(dims().k()),
- n_start, n_end, vsl);
+ EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start),
+ GetInt64(dims().k()), n_start, n_end);
+ }
+}
+
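+// A worked example of the `k` split above (illustrative numbers): with
+// dims().k() = 100 and tile_size_k() = 8, the first EmitLoopOverM call covers
+// k in [0, 96) with full 8-wide tiles and the second call handles the
+// [96, 100) epilogue with a 4-wide tile.
+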
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
+ VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+ llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
+ const int64 m_end = dims().m() - dims().m() % tile_size_m();
+ EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+ tile_size_m(), GetInt64(0), GetInt64(m_end));
+
+ if (m_end != dims().m()) {
+ EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+ dims().m() - m_end, GetInt64(m_end),
+ GetInt64(dims().m()));
}
}
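+
+// In scalar terms, the nest emitted by EmitLoopOverN/K/M together with
+// EmitTiledReductionLoop below is roughly (a sketch; the emitted IR vectorizes
+// the `n` dimension and peels remainders as described above):
+//
+//   for (m = m_start; m < m_end; m += tile_size_m)
+//     for (k = k_start; k < k_end; k += tile_size_k)
+//       for (n = n_start; n < n_end; n += vector_size)
+//         for (r_m = 0; r_m < tile_size_m; r_m++)
+//           for (r_k = 0; r_k < tile_size_k; r_k++)
+//             result[m + r_m][n .. n + vector_size) +=
+//                 lhs[m + r_m][k + r_k] * rhs[k + r_k][n .. n + vector_size);
+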
//
// Let the LHS be:
//
-// +---+---+---+
-// | a | b | c | .
-// +---+---+---+ .
-// | | | | .
-// +---+---+---+
+// +----+----+----+
+// | a0 | b0 | c0 | .
+// +----+----+----+ .
+// | a1 | b1 | c1 | .
+// +----+----+----+
// .. ..
//
// and the RHS be:
// +----+----+----+----+ .
// ...... ......
//
-// and let k_tiling_factor be 3 and the vector width (implicitly denoted by
-// `vsl`) be 4.
+// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
+// by `vsl`) be 4. Then we matrix-multiply these two tiles to get a [2,4]
+// matrix that we add into the corresponding [2,4] block of the result.
//
-// Then we
+// First broadcast each of the two LHS rows into 3 vectors of width 4, giving
+// us a rank-3 array, L, of dimension [2,3,4]:
//
-// 1. broadcast the first row in LHS to 3 vectors of width 4
-// 2. elementwise multiply the RHS rows with these broadcasted vectors
-// 3. elementwise add them:
+// L[0,_,_] * L[1,_,_]
+// *
+// +----+----+----+----+ * +----+----+----+----+
+// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 |
+// +----+----+----+----+ * +----+----+----+----+
+// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 |
+// +----+----+----+----+ * +----+----+----+----+
+// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 |
+// +----+----+----+----+ * +----+----+----+----+
//
-// +---+---+---+---+ +----+----+----+----+
-// | a | a | a | a | * | p0 | p1 | p2 | p3 | +
-// +---+---+---+---+ +----+----+----+----+
//
-// +---+---+---+---+ +----+----+----+----+
-// | b | b | b | b | * | q0 | q1 | q2 | q3 | +
-// +---+---+---+---+ +----+----+----+----+
+// Then we FMA L[0,_,_] with the RHS to get the first row of the result and
+// L[1,_,_] with the RHS to get the second row of the result. For example, the
+// first row of the result is computed as:
//
-// +---+---+---+---+ +----+----+----+----+
-// | c | c | c | c | * | r0 | r1 | r2 | r3 |
-// +---+---+---+---+ +----+----+----+----+
+// +----+----+----+----+ +----+----+----+----+
+// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | +
+// +----+----+----+----+ +----+----+----+----+
//
-// to get:
+// +----+----+----+----+ +----+----+----+----+
+// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | +
+// +----+----+----+----+ +----+----+----+----+
//
-// +----------------+----------------+----------------+----------------+
-// | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 |
-// +----------------+----------------+----------------+----------------+
+// +----+----+----+----+ +----+----+----+----+
+// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
+// +----+----+----+----+ +----+----+----+----+
//
-// which we increment into the appropriate region in the result.
-void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
- int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
- ksl_.For("dot.m", 0, dims().m(), 1, [&](llvm::Value* m_i) {
- // This outer loop iterates over all of the M dimension
- llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
- result_, /*offset_elements=*/m_i, /*scale=*/dims().n());
- llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer(
- lhs_, /*offset_elements=*/m_i, /*scale=*/dims().k());
-
- ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
- // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
- // <b,b,b,b> etc. in the diagram.
- std::vector<llvm::Value*> broadcasted_a;
- broadcasted_a.reserve(k_tiling_factor);
- for (int i = 0; i < k_tiling_factor; i++) {
- broadcasted_a.push_back(vsl->LoadBroadcast(
- lhs_row_begin, ir_builder_->CreateAdd(GetInt64(i), k_i)));
- }
-
- // rhs_loader will be used to load the tile off of the RHS, denoted as
- // <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
- MemoryTile rhs_loader(vsl, ir_builder_, rhs_, dims().n(), k_i,
- k_tiling_factor);
+// to get:
+//
+// +-------------------+-------------------+-------------------+---------
+// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
+// +-------------------+-------------------+-------------------+---------
+void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop(
+ VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+ llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
+ int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
+ ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
+ MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_,
+ /*matrix_size_along_minor_dim=*/dims().n(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+ MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/dims().k(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+
+ ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
+ MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i,
+ tile_size_k);
+ std::vector<std::vector<llvm::Value*>> lhs_tile =
+ lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
ksl_.For(
"dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
- // This loop iterates over the N dimension. It loads the tile from
- // RHS, does the FMA resulting in the
- // <a*p0+b*q0+c*r0,a*p1+b*q1+c*r1,...> in the diagram and increments
- // the result.
- std::vector<llvm::Value*> tile = rhs_loader.LoadTile(n_i);
- llvm::Value* result_accumulator =
- vsl->LoadVector(result_row_begin, n_i);
- for (int i = 0; i < tile.size(); i++) {
- result_accumulator =
- vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator);
+ std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
+ std::vector<llvm::Value*> result_tile =
+ result_memory_tile.LoadTile(n_i);
+ for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
+ for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
+ result_tile[r_m_i] =
+ vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
+ result_tile[r_m_i]);
+ }
}
- vsl->StoreVector(result_accumulator, result_row_begin, n_i);
+ result_memory_tile.StoreTile(result_tile, n_i);
});
});
});
/*scalar_type=*/primitive_type,
MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
/*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
- /*k_tiling_factor=*/8);
+ /*tile_size_m=*/3, /*tile_size_k=*/8);
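+  // With these tile sizes, each "dot.n" step in the emitted inner loop issues
+  // tile_size_m * tile_size_k = 3 * 8 = 24 vector FMAs against 3 result
+  // accumulator vectors (see EmitTiledReductionLoop above).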
const bool enable_fast_math =
hlo_module_config_.debug_options().xla_enable_fast_math();