std::vector<llvm::Value*> pointers_;
};
+// Base class for the classes representing the GEMV emitter configurations.
+//
+// The IR emitted by the row-major and column-major GEMV emitters (modulo the
+// LLVM values representing the input and output buffers) should be a function
+// of their configuration alone.  This matters because the configuration is
+// used as the key under which the generated IR is cached.
+class GemvConfig {
+ public:
+ // CRTP mixin for convenience: a derived class T that exposes a config()
+ // accessor inherits the forwarding getters below.
+ template <typename T>
+ struct User {
+ public:
+ PrimitiveType scalar_type() const {
+ return derived().config().scalar_type();
+ }
+ int64 tile_rows() const { return derived().config().tile_rows(); }
+ int64 tile_cols() const { return derived().config().tile_cols(); }
+ int64 m() const { return derived().config().m(); }
+ int64 k() const { return derived().config().k(); }
+ bool has_addend() const { return derived().config().has_addend(); }
+
+ private:
+ const T& derived() const { return *static_cast<const T*>(this); }
+ };
+
+ PrimitiveType scalar_type() const { return scalar_type_; }
+ int64 tile_rows() const { return tile_rows_; }
+ int64 tile_cols() const { return tile_cols_; }
+ int64 m() const { return m_; }
+ int64 k() const { return k_; }
+ bool has_addend() const { return has_addend_; }
+
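+ // GetCacheKey() below is the key under which the generated kernel is
+ // cached.  For example (illustrative values, not from any call site), a
+ // column-major F32 GEMV with 8x4 tiles, m = 128, k = 256 and an addend
+ // yields "col_major_gemv_F32_8_4_128_256_with_addend".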
+ string GetCacheKey() const {
+ return tensorflow::strings::StrCat(
+ name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_",
+ tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : "");
+ }
+
+ protected:
+ explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows,
+ int64 tile_cols, int64 m, int64 k, bool has_addend)
+ : name_(std::move(name)),
+ scalar_type_(scalar_type),
+ tile_rows_(tile_rows),
+ tile_cols_(tile_cols),
+ m_(m),
+ k_(k),
+ has_addend_(has_addend) {}
+
+ private:
+ string name_;
+ PrimitiveType scalar_type_;
+ int64 tile_rows_;
+ int64 tile_cols_;
+ int64 m_;
+ int64 k_;
+ bool has_addend_;
+};
+
// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
// layout of the vector does not matter). This implementation uses a tiling
// scheme to improve performance.
// TODO(sanjoy): We should investigate if using gather loads and scatter stores
// can be used here to have the same inner loop for both column-major and
// row-major matrix-vector products.
-class ColumnMajorMatrixVectorProductEmitter {
+class ColumnMajorMatrixVectorProductEmitter
+ : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
public:
- ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
- int64 tile_rows, int64 tile_cols,
- int64 m, int64 k, llvm::Value* lhs,
+ class Config : public GemvConfig {
+ public:
+ explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+ int64 m, int64 k, bool has_addend)
+ : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
+ /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+ /*k=*/k, /*has_addend=*/has_addend) {}
+ };
+
+ ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result,
llvm::IRBuilder<>* ir_builder)
- : scalar_type_(scalar_type),
- tile_rows_(tile_rows),
- tile_cols_(tile_cols),
- m_(m),
- k_(k),
+ : config_(config),
lhs_(lhs),
rhs_(rhs),
addend_(addend),
result_(result),
ir_builder_(ir_builder),
ksl_(ir_builder_),
- vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
- CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
+ vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(),
+ ir_builder_, "") {
+ CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
+ CHECK(!has_addend() || addend != nullptr);
}
void Emit();
+ const Config& config() const { return config_; }
+
private:
void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
bool is_first_column);
TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
- /*matrix_size_along_minor_dim=*/m_,
+ /*matrix_size_along_minor_dim=*/m(),
/*major_dim_offset=*/column_start,
/*tile_size_along_major_dim=*/column_count);
}
void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
bool is_first_tiled_column);
- PrimitiveType scalar_type_;
- int64 tile_rows_;
- int64 tile_cols_;
- int64 m_;
- int64 k_;
+ Config config_;
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* addend_;
void ColumnMajorMatrixVectorProductEmitter::Emit() {
// See the comment on the class declaration for the algorithm used here.
- int64 column_remainder = k_ % tile_cols_;
- int64 column_limit = k_ - column_remainder;
+ int64 column_remainder = k() % tile_cols();
+ int64 column_limit = k() - column_remainder;
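+ // E.g. (illustrative values) with k = 130 and tile_cols() = 4, this gives
+ // column_remainder = 2 and column_limit = 128: the tiled loop below handles
+ // columns [0, 128) in steps of 4, and the remaining 2 columns fall to the
+ // non-tiled path.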
ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
+ /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
[&](llvm::Value* column, bool is_first_column) {
- EmitOuterLoopBody(column, tile_cols_, is_first_column);
+ EmitOuterLoopBody(column, tile_cols(), is_first_column);
});
if (column_remainder != 0) {
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
int64 columns, bool is_first_column) {
- int64 row_limit = m_ - (m_ % tile_rows_);
+ int64 row_limit = m() - (m() % tile_rows());
ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
- /*step=*/tile_rows_, [&](llvm::Value* row) {
+ /*step=*/tile_rows(), [&](llvm::Value* row) {
std::vector<llvm::Value*> lhs_tile =
lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
llvm::Value* accumulator =
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
- int64 row_start = m_ - (m_ % tile_rows_);
- if (row_start == m_) {
+ int64 row_start = m() - (m() % tile_rows());
+ if (row_start == m()) {
return;
}
[&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
llvm::Value* total_offset =
- ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
+ ir_builder_->CreateMul(col, ir_builder_->getInt64(m()));
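+ // For the column-major [m, k] lhs, element (row, col) lives at offset
+ // col * m + row, so lhs_base_pointer below addresses the start of
+ // column `col`.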
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
ksl_.For(
- "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
+ "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
/*step=*/1, [&](llvm::Value* scalar_row) {
llvm::Value* product = vsl_.Mul(
vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
//
// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
// epilogue loop to deal with the C,D submatrix.
-class RowMajorMatrixVectorProductEmitter {
+class RowMajorMatrixVectorProductEmitter
+ : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
public:
- RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
- int64 tile_cols, int64 m, int64 k,
- llvm::Value* lhs, llvm::Value* rhs,
- llvm::Value* addend, llvm::Value* result,
+ class Config : public GemvConfig {
+ public:
+ explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+ int64 m, int64 k, bool has_addend)
+ : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
+ /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+ /*k=*/k, /*has_addend=*/has_addend) {}
+ };
+
+ RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* addend,
+ llvm::Value* result,
llvm::IRBuilder<>* ir_builder)
- : scalar_type_(scalar_type),
- tile_rows_(tile_rows),
- tile_cols_(tile_cols),
- m_(m),
- k_(k),
+ : config_(config),
lhs_(lhs),
rhs_(rhs),
addend_(addend),
result_(result),
ir_builder_(ir_builder),
ksl_(ir_builder_),
- vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
- CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
+ vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") {
+ CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
+ CHECK(!has_addend() || addend != nullptr);
}
void Emit();
+ const Config& config() const { return config_; }
+
private:
TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
- /*matrix_size_along_minor_dim=*/k_,
+ /*matrix_size_along_minor_dim=*/k(),
/*major_dim_offset=*/row_start,
/*tile_size_along_major_dim=*/row_count);
}
void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
std::vector<ScalarVariable>* scalar_accumulators);
- PrimitiveType scalar_type_;
- int64 tile_rows_;
- int64 tile_cols_;
- int64 m_;
- int64 k_;
+ Config config_;
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* addend_;
void RowMajorMatrixVectorProductEmitter::Emit() {
// See the comment on the class declaration for the algorithm used here.
- int64 row_remainder = m_ % tile_rows_;
- int64 row_limit = m_ - row_remainder;
+ int64 row_remainder = m() % tile_rows();
+ int64 row_limit = m() - row_remainder;
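+ // E.g. (illustrative values) with m = 70 and tile_rows() = 8, this gives
+ // row_remainder = 6 and row_limit = 64: the tiled loop below handles rows
+ // [0, 64) in steps of 8, and the last 6 rows are handled by the call that
+ // follows it.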
ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
- [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
+ /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
+ [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
if (row_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
TileLoader* lhs_tile_loader, int64 rows,
std::vector<VectorVariable>* vector_accumulators) {
- int64 column_limit = k_ - (k_ % tile_cols_);
+ int64 column_limit = k() - (k() % tile_cols());
ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
- /*step=*/tile_cols_, [&](llvm::Value* col) {
+ /*step=*/tile_cols(), [&](llvm::Value* col) {
std::vector<llvm::Value*> lhs_tile =
lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* current_tile_row, int64 rows,
std::vector<ScalarVariable>* scalar_accumulators) {
- int64 column_start = k_ - (k_ % tile_cols_);
- if (column_start == k_) {
+ int64 column_start = k() - (k() % tile_cols());
+ if (column_start == k()) {
return;
}
for (int r = 0; r < rows; r++) {
llvm::Value* total_offset = ir_builder_->CreateMul(
ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
- ir_builder_->getInt64(k_));
+ ir_builder_->getInt64(k()));
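+ // For the row-major [m, k] lhs, row (current_tile_row + r) starts at
+ // offset (current_tile_row + r) * k, so lhs_base_pointer below addresses
+ // the epilogue columns of that row.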
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
- ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
+ ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
/*step=*/1, [&](llvm::Value* scalar_col) {
llvm::Value* product =
vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
if (is_column_major_matrix_vector) {
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
<< " and k = " << k;
- int64 tile_rows = vector_register_element_size;
- int64 tile_cols = tiling_factor;
-
- string kernel_name = tensorflow::strings::StrCat(
- "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
- "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
+ ColumnMajorMatrixVectorProductEmitter::Config config(
+ /*scalar_type=*/primitive_type,
+ /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
+ /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr);
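+ // The outlined kernel is keyed on config.GetCacheKey(); because the emitted
+ // IR is a function of the config alone (see GemvConfig), dots with identical
+ // configurations reuse the cached kernel instead of re-emitting it.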
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
- lhs_op, rhs_op,
+ /*optimize_for_size=*/optimize_for_size, ir_builder_,
+ config.GetCacheKey(), lhs_op, rhs_op,
addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
- [this, tile_rows, tile_cols, m, k, primitive_type](
- llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
- llvm::Value* result_op) {
+ [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
+ llvm::Value* addend_op, llvm::Value* result_op) {
ColumnMajorMatrixVectorProductEmitter emitter(
- primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
- addend_op, result_op, ir_builder_);
+ config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
emitter.Emit();
});
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
- int64 tile_rows = tiling_factor;
- int64 tile_cols = vector_register_element_size;
-
- string kernel_name = tensorflow::strings::StrCat(
- "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
- "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
+ RowMajorMatrixVectorProductEmitter::Config config(
+ /*scalar_type=*/primitive_type,
+ /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size,
+ /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
- lhs_op, rhs_op,
+ /*optimize_for_size=*/optimize_for_size, ir_builder_,
+ config.GetCacheKey(), lhs_op, rhs_op,
addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
- [this, tile_rows, tile_cols, m, k, primitive_type](
- llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
- llvm::Value* result_op) {
+ [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
+ llvm::Value* addend_op, llvm::Value* result_op) {
RowMajorMatrixVectorProductEmitter emitter(
- primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
- addend_op, result_op, ir_builder_);
+ config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
emitter.Emit();
});
}