Cache generated LLVM IR for GEBP
authorSanjoy Das <sanjoy@google.com>
Thu, 24 May 2018 00:49:42 +0000 (17:49 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 00:52:23 +0000 (17:52 -0700)
After this change all generated GEBPs with the same shape will share a single
llvm::Function.

This is NFC for any actual workloads because the GEBP emitter isn't exercised by
normal code-paths yet.

PiperOrigin-RevId: 197820606

tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc

index 5158779..3aa436b 100644 (file)
@@ -610,16 +610,21 @@ class MatrixMatrixBlockPanelEmitter {
     int64 k() const { return k_; }
     int64 n() const { return n_; }
 
+    string ToString() const {
+      return tensorflow::strings::StrCat(m(), "x", k(), "x", n());
+    }
+
    private:
     const int64 m_;
     const int64 k_;
     const int64 n_;
   };
 
-  // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
-  // `lhs` with `rhs` and stores the result in `result`.
+  // Represents the configuration of the GEBP emitter.  The LLVM IR emitted by
+  // the emitter, modulo the LLVM values holding the input and output buffers,
+  // must be a function of the instance of `Config` passed to it.
   //
-  // `m`, `k` and `n` are the matrix multiplication dimensions.
+  // `dims` holds the matrix multiplication dimensions.
   //
   // `max_vectorization_width` is the maximum vector width (i.e. the width of
   // the largest vector register we will use).  This can be larger than the
@@ -630,27 +635,54 @@ class MatrixMatrixBlockPanelEmitter {
   //
   // `k_tiling_factor` is the number of elements along the reduction dimensions
   // that we will attempt to process at once.
-  explicit MatrixMatrixBlockPanelEmitter(
-      llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims,
-      int max_vectorization_width, int min_vectorization_width,
-      int k_tiling_factor, const TargetMachineFeatures& target_machine_features,
-      llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type)
+  class Config {
+   public:
+    explicit Config(PrimitiveType scalar_type, Dimensions dims,
+                    int64 max_vectorization_width,
+                    int64 min_vectorization_width, int64 k_tiling_factor)
+        : scalar_type_(scalar_type),
+          dims_(dims),
+          max_vectorization_width_(max_vectorization_width),
+          min_vectorization_width_(min_vectorization_width),
+          k_tiling_factor_(k_tiling_factor) {}
+
+    string GetCacheKey() const {
+      return tensorflow::strings::StrCat(
+          "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
+          "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
+          k_tiling_factor());
+    }
+
+    PrimitiveType scalar_type() const { return scalar_type_; }
+    Dimensions dims() const { return dims_; }
+    int64 max_vectorization_width() const { return max_vectorization_width_; }
+    int64 min_vectorization_width() const { return min_vectorization_width_; }
+    int64 k_tiling_factor() const { return k_tiling_factor_; }
+
+   private:
+    PrimitiveType scalar_type_;
+    Dimensions dims_;
+    int64 max_vectorization_width_;
+    int64 min_vectorization_width_;
+    int64 k_tiling_factor_;
+  };
+
+  // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
+  // `lhs` with `rhs` and stores the result in `result`.
+  explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
+                                         llvm::Value* rhs, llvm::Value* result,
+                                         llvm::IRBuilder<>* ir_builder)
       : lhs_(lhs),
         rhs_(rhs),
         result_(result),
-        dims_(dims),
-        max_vectorization_width_(max_vectorization_width),
-        min_vectorization_width_(min_vectorization_width),
-        k_tiling_factor_(k_tiling_factor),
-        target_machine_features_(target_machine_features),
+        config_(config),
         ir_builder_(ir_builder),
-        primitive_type_(primitive_type),
         ksl_(ir_builder_) {
-    CHECK(max_vectorization_width > 0 &&
-          IsPowerOfTwo(static_cast<uint64>(max_vectorization_width)));
-    CHECK(min_vectorization_width > 0 &&
-          IsPowerOfTwo(static_cast<uint64>(min_vectorization_width)));
-    CHECK_GT(k_tiling_factor, 0);
+    CHECK(max_vectorization_width() > 0 &&
+          IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
+    CHECK(min_vectorization_width() > 0 &&
+          IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
+    CHECK_GT(k_tiling_factor(), 0);
   }
 
   void Emit();
@@ -677,31 +709,37 @@ class MatrixMatrixBlockPanelEmitter {
 
   llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); }
 
+  Config config() const { return config_; }
+  Dimensions dims() const { return config().dims(); }
+
+  int64 max_vectorization_width() const {
+    return config().max_vectorization_width();
+  }
+  int64 min_vectorization_width() const {
+    return config().min_vectorization_width();
+  }
+  int64 k_tiling_factor() const { return config().k_tiling_factor(); }
+  PrimitiveType scalar_type() const { return config().scalar_type(); }
+
   llvm::Value* lhs_;
   llvm::Value* rhs_;
   llvm::Value* result_;
-  Dimensions dims_;
-
-  int64 max_vectorization_width_;
-  int64 min_vectorization_width_;
-  int64 k_tiling_factor_;
+  Config config_;
 
-  const TargetMachineFeatures& target_machine_features_;
   llvm::IRBuilder<>* ir_builder_;
-  PrimitiveType primitive_type_;
   KernelSupportLibrary ksl_;
 };
 
 void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
 
 void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
-  int64 current_vectorization_width = max_vectorization_width_;
+  int64 current_vectorization_width = max_vectorization_width();
   int64 n_start = 0;
-  while (n_start != dims_.n() &&
-         current_vectorization_width >= min_vectorization_width_) {
-    int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width);
+  while (n_start != dims().n() &&
+         current_vectorization_width >= min_vectorization_width()) {
+    int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
     if (n_start != n_end) {
-      VectorSupportLibrary vsl(primitive_type_, current_vectorization_width,
+      VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
                                ir_builder_, "gebp");
       EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end));
       n_start = n_end;
@@ -709,9 +747,9 @@ void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
     current_vectorization_width /= 2;
   }
 
-  if (n_start != dims_.n()) {
-    VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp");
-    ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) {
+  if (n_start != dims().n()) {
+    VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp");
+    ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
       llvm::Value* n_i_next =
           ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
       EmitLoopOverK(&vsl, n_i, n_i_next);
@@ -723,15 +761,15 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
                                                   llvm::Value* n_start,
                                                   llvm::Value* n_end) {
   int64 k_start = 0;
-  int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_);
+  int64 k_end = dims().k() - (dims().k() % k_tiling_factor());
   if (k_end != k_start) {
-    EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start,
-                  n_end, vsl);
+    EmitInnerLoop(k_tiling_factor(), getInt64(k_start), getInt64(k_end),
+                  n_start, n_end, vsl);
     k_start = k_end;
   }
 
-  if (k_start != dims_.k()) {
-    EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()),
+  if (k_start != dims().k()) {
+    EmitInnerLoop(dims().k() - k_start, getInt64(k_start), getInt64(dims().k()),
                   n_start, n_end, vsl);
   }
 }
@@ -789,12 +827,12 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
 void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
     int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
     llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
-  ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) {
+  ksl_.For("dot.m", 0, dims().m(), 1, [&](llvm::Value* m_i) {
     // This outer loop iterates over all of the M dimension
     llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
-        result_, /*offset_elements=*/m_i, /*scale=*/dims_.n());
+        result_, /*offset_elements=*/m_i, /*scale=*/dims().n());
     llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer(
-        lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k());
+        lhs_, /*offset_elements=*/m_i, /*scale=*/dims().k());
 
     ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
       // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
@@ -808,7 +846,7 @@ void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
 
       // rhs_loader will be used to load the tile off of the RHS, denoted as
       // <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
-      TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i,
+      TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims().n(), k_i,
                             k_tiling_factor);
       ksl_.For(
           "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
@@ -913,14 +951,28 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
       target, ir_builder_->getInt8(0), size_bytes,
       target_machine_features_.minimum_alignment_for_allocation(size_bytes));
 
-  MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k,
-                                                      /*n=*/n);
-  MatrixMatrixBlockPanelEmitter gebp_emitter(
-      /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims,
+  MatrixMatrixBlockPanelEmitter::Config config(
+      /*scalar_type=*/primitive_type,
+      MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
       /*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
-      /*k_tiling_factor=*/8, target_machine_features_, ir_builder_,
-      primitive_type);
-  gebp_emitter.Emit();
+      /*k_tiling_factor=*/8);
+
+  const bool enable_fast_math =
+      hlo_module_config_.debug_options().xla_enable_fast_math();
+  const bool optimize_for_size =
+      options::OptimizeForSizeRequested(hlo_module_config_);
+
+  KernelSupportLibrary::EmitAndCallOutlinedKernel(
+      /*enable_fast_math=*/enable_fast_math,
+      /*optimize_for_size=*/optimize_for_size, ir_builder_,
+      config.GetCacheKey(), lhs, rhs, target,
+      [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
+        MatrixMatrixBlockPanelEmitter gebp_emitter(
+            config, /*lhs=*/lhs, /*rhs=*/rhs,
+            /*result=*/target, ir_builder_);
+        gebp_emitter.Emit();
+      });
+
   return true;
 }