From 335d52c17644f417bc53abe4ef87ead9de01ad6d Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 23 May 2018 17:49:42 -0700 Subject: [PATCH] Cache generated LLVM IR for GEBP After this change all generated GEBPs with the same shape will share a single llvm::Function. This is NFC for any actual workloads because the GEBP emitter isn't exercised by normal code-paths yet. PiperOrigin-RevId: 197820606 --- .../compiler/xla/service/cpu/dot_op_emitter.cc | 152 ++++++++++++++------- 1 file changed, 102 insertions(+), 50 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 5158779..3aa436b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -610,16 +610,21 @@ class MatrixMatrixBlockPanelEmitter { int64 k() const { return k_; } int64 n() const { return n_; } + string ToString() const { + return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); + } + private: const int64 m_; const int64 k_; const int64 n_; }; - // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies - // `lhs` with `rhs` and stores the result in `result`. + // Represents the configuration of the GEBP emitter. The LLVM IR emitted by + // the emitter, modulo the LLVM values holding the input and output buffers, + // must be a function of the instance of `Config` passed to it. // - // `m`, `k` and `n` are the matrix multiplication dimensions. + // `dims` holds the matrix multiplication dimensions. // // `max_vectorization_width` is the maximum vector width (i.e. the width of // the largest vector register we will use). This can be larger than the @@ -630,27 +635,54 @@ class MatrixMatrixBlockPanelEmitter { // // `k_tiling_factor` is the number of elements along the reduction dimensions // that we will attempt to process at once. - explicit MatrixMatrixBlockPanelEmitter( - llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims, - int max_vectorization_width, int min_vectorization_width, - int k_tiling_factor, const TargetMachineFeatures& target_machine_features, - llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type) + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, + int64 min_vectorization_width, int64 k_tiling_factor) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + min_vectorization_width_(min_vectorization_width), + k_tiling_factor_(k_tiling_factor) {} + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), + "_", max_vectorization_width(), "_", min_vectorization_width(), "_", + k_tiling_factor()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + int64 k_tiling_factor() const { return k_tiling_factor_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 min_vectorization_width_; + int64 k_tiling_factor_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) : lhs_(lhs), rhs_(rhs), result_(result), - dims_(dims), - max_vectorization_width_(max_vectorization_width), - min_vectorization_width_(min_vectorization_width), - k_tiling_factor_(k_tiling_factor), - target_machine_features_(target_machine_features), + config_(config), ir_builder_(ir_builder), - primitive_type_(primitive_type), ksl_(ir_builder_) { - CHECK(max_vectorization_width > 0 && - IsPowerOfTwo(static_cast(max_vectorization_width))); - CHECK(min_vectorization_width > 0 && - IsPowerOfTwo(static_cast(min_vectorization_width))); - CHECK_GT(k_tiling_factor, 0); + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GT(k_tiling_factor(), 0); } void Emit(); @@ -677,31 +709,37 @@ class MatrixMatrixBlockPanelEmitter { llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); } + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 k_tiling_factor() const { return config().k_tiling_factor(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* result_; - Dimensions dims_; - - int64 max_vectorization_width_; - int64 min_vectorization_width_; - int64 k_tiling_factor_; + Config config_; - const TargetMachineFeatures& target_machine_features_; llvm::IRBuilder<>* ir_builder_; - PrimitiveType primitive_type_; KernelSupportLibrary ksl_; }; void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); } void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { - int64 current_vectorization_width = max_vectorization_width_; + int64 current_vectorization_width = max_vectorization_width(); int64 n_start = 0; - while (n_start != dims_.n() && - current_vectorization_width >= min_vectorization_width_) { - int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width); + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { - VectorSupportLibrary vsl(primitive_type_, current_vectorization_width, + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, ir_builder_, "gebp"); EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end)); n_start = n_end; @@ -709,9 +747,9 @@ void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { current_vectorization_width /= 2; } - if (n_start != dims_.n()) { - VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp"); - ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) { + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); EmitLoopOverK(&vsl, n_i, n_i_next); @@ -723,15 +761,15 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, llvm::Value* n_end) { int64 k_start = 0; - int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_); + int64 k_end = dims().k() - (dims().k() % k_tiling_factor()); if (k_end != k_start) { - EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start, - n_end, vsl); + EmitInnerLoop(k_tiling_factor(), getInt64(k_start), getInt64(k_end), + n_start, n_end, vsl); k_start = k_end; } - if (k_start != dims_.k()) { - EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()), + if (k_start != dims().k()) { + EmitInnerLoop(dims().k() - k_start, getInt64(k_start), getInt64(dims().k()), n_start, n_end, vsl); } } @@ -789,12 +827,12 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) { - ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) { + ksl_.For("dot.m", 0, dims().m(), 1, [&](llvm::Value* m_i) { // This outer loop iterates over all of the M dimension llvm::Value* result_row_begin = vsl->ComputeOffsetPointer( - result_, /*offset_elements=*/m_i, /*scale=*/dims_.n()); + result_, /*offset_elements=*/m_i, /*scale=*/dims().n()); llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer( - lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k()); + lhs_, /*offset_elements=*/m_i, /*scale=*/dims().k()); ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) { // broadcasted_a is the broadcasted set of vectors denoted as , @@ -808,7 +846,7 @@ void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( // rhs_loader will be used to load the tile off of the RHS, denoted as // <, ...> in the diagram. - TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i, + TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims().n(), k_i, k_tiling_factor); ksl_.For( "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { @@ -913,14 +951,28 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( target, ir_builder_->getInt8(0), size_bytes, target_machine_features_.minimum_alignment_for_allocation(size_bytes)); - MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k, - /*n=*/n); - MatrixMatrixBlockPanelEmitter gebp_emitter( - /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims, + MatrixMatrixBlockPanelEmitter::Config config( + /*scalar_type=*/primitive_type, + MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, /*max_vectorization_width=*/8, /*min_vectorization_width=*/4, - /*k_tiling_factor=*/8, target_machine_features_, ir_builder_, - primitive_type); - gebp_emitter.Emit(); + /*k_tiling_factor=*/8); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs, rhs, target, + [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { + MatrixMatrixBlockPanelEmitter gebp_emitter( + config, /*lhs=*/lhs, /*rhs=*/rhs, + /*result=*/target, ir_builder_); + gebp_emitter.Emit(); + }); + return true; } -- 2.7.4