Extract common code and make things safer; NFC
author    Sanjoy Das <sanjoy@google.com>
Tue, 22 May 2018 21:55:12 +0000 (14:55 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 21:58:09 +0000 (14:58 -0700)
RowMajorMatrixVectorProductEmitter and ColumnMajorMatrixVectorProductEmitter
both cache* the generated LLVM IR by keying off the dimensions of the operation,
the primitive type, etc.  Before this CL the code computing the cache key lived
separately from the GEMV emitters.  This pattern introduced the risk that a GEMV
emitter would end up with some state not modeled in the cache key, resulting in
a subtle bug.

This CL reduces that risk by encapsulating the cache key generation and the input
configuration of the GEMV emitters in a single class.

 * In the sense that two different dot operations with the same M, K, N will share
   a single LLVM IR function body.

PiperOrigin-RevId: 197628423
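
To make the shape of the change concrete, here is a standalone sketch of the
pattern (trimmed to two dimension fields; PrimitiveType, PrimitiveTypeName and
ColumnMajorEmitter below are simplified stand-ins for illustration, not the
real XLA declarations):

    #include <cstdint>
    #include <sstream>
    #include <string>
    #include <utility>

    // Hypothetical stand-ins for XLA's PrimitiveType enum and name helper.
    enum class PrimitiveType { F32, F64 };
    inline std::string PrimitiveTypeName(PrimitiveType t) {
      return t == PrimitiveType::F32 ? "F32" : "F64";
    }

    // One class owns every input that shapes the emitted IR and derives the
    // cache key from exactly those fields, so the key cannot silently drift
    // out of sync with the emitter's configuration.
    class GemvConfig {
     public:
      // CRTP mixin: an emitter deriving from GemvConfig::User<Emitter> gets
      // tile_rows() etc. as shorthands for config().tile_rows().
      template <typename T>
      struct User {
        std::int64_t tile_rows() const {
          return derived().config().tile_rows();
        }
        std::int64_t m() const { return derived().config().m(); }

       private:
        const T& derived() const { return *static_cast<const T*>(this); }
      };

      PrimitiveType scalar_type() const { return scalar_type_; }
      std::int64_t tile_rows() const { return tile_rows_; }
      std::int64_t m() const { return m_; }

      std::string GetCacheKey() const {
        std::ostringstream key;  // e.g. "col_major_gemv_F32_8_128"
        key << name_ << '_' << PrimitiveTypeName(scalar_type_) << '_'
            << tile_rows_ << '_' << m_;
        return key.str();
      }

     protected:
      GemvConfig(std::string name, PrimitiveType scalar_type,
                 std::int64_t tile_rows, std::int64_t m)
          : name_(std::move(name)),
            scalar_type_(scalar_type),
            tile_rows_(tile_rows),
            m_(m) {}

     private:
      std::string name_;
      PrimitiveType scalar_type_;
      std::int64_t tile_rows_;
      std::int64_t m_;
    };

    // An emitter stores its Config and exposes it via config(); one object
    // yields both the accessors used while emitting and the cache key.
    class ColumnMajorEmitter : public GemvConfig::User<ColumnMajorEmitter> {
     public:
      class Config : public GemvConfig {
       public:
        Config(PrimitiveType t, std::int64_t tile_rows, std::int64_t m)
            : GemvConfig("col_major_gemv", t, tile_rows, m) {}
      };

      explicit ColumnMajorEmitter(const Config& config) : config_(config) {}
      const Config& config() const { return config_; }

     private:
      Config config_;
    };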

tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc

index af69fc3..5158779 100644
@@ -79,6 +79,65 @@ class TileLoader {
   std::vector<llvm::Value*> pointers_;
 };
 
+// The base class for the classes representing the GEMV emitter configurations.
+//
+// The IR emitted (modulo the LLVM values representing the input and output
+// buffers) by the row major and column major GEMV emitters should be a function
+// of their configuration.  This is important because their configuration is
+// used as a key to cache the generated IR.
+class GemvConfig {
+ public:
+  // Mixin for convenience.
+  template <typename T>
+  struct User {
+   public:
+    PrimitiveType scalar_type() const {
+      return derived().config().scalar_type();
+    }
+    int64 tile_rows() const { return derived().config().tile_rows(); }
+    int64 tile_cols() const { return derived().config().tile_cols(); }
+    int64 m() const { return derived().config().m(); }
+    int64 k() const { return derived().config().k(); }
+    bool has_addend() const { return derived().config().has_addend(); }
+
+   private:
+    const T& derived() const { return *static_cast<const T*>(this); }
+  };
+
+  PrimitiveType scalar_type() const { return scalar_type_; }
+  int64 tile_rows() const { return tile_rows_; }
+  int64 tile_cols() const { return tile_cols_; }
+  int64 m() const { return m_; }
+  int64 k() const { return k_; }
+  bool has_addend() const { return has_addend_; }
+
+  string GetCacheKey() const {
+    return tensorflow::strings::StrCat(
+        name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_",
+        tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : "");
+  }
+
+ protected:
+  explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows,
+                      int64 tile_cols, int64 m, int64 k, bool has_addend)
+      : name_(std::move(name)),
+        scalar_type_(scalar_type),
+        tile_rows_(tile_rows),
+        tile_cols_(tile_cols),
+        m_(m),
+        k_(k),
+        has_addend_(has_addend) {}
+
+ private:
+  string name_;
+  PrimitiveType scalar_type_;
+  int64 tile_rows_;
+  int64 tile_cols_;
+  int64 m_;
+  int64 k_;
+  bool has_addend_;
+};
+
 // Computes a dot product between a "[M,K]{0,1} lhs" and a [K,1] vector (the
 // layout of the vector does not matter).  This implementation uses a tiling
 // scheme to improve performance.
@@ -140,38 +199,46 @@ class TileLoader {
 // TODO(sanjoy): We should investigate whether gather loads and scatter stores
 // can be used here to have the same inner loop for both column-major and
 // row-major matrix-vector products.
-class ColumnMajorMatrixVectorProductEmitter {
+class ColumnMajorMatrixVectorProductEmitter
+    : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
  public:
-  ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
-                                        int64 tile_rows, int64 tile_cols,
-                                        int64 m, int64 k, llvm::Value* lhs,
+  class Config : public GemvConfig {
+   public:
+    explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+                    int64 m, int64 k, bool has_addend)
+        : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
+                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+                     /*k=*/k, /*has_addend=*/has_addend) {}
+  };
+
+  ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
                                         llvm::Value* rhs, llvm::Value* addend,
                                         llvm::Value* result,
                                         llvm::IRBuilder<>* ir_builder)
-      : scalar_type_(scalar_type),
-        tile_rows_(tile_rows),
-        tile_cols_(tile_cols),
-        m_(m),
-        k_(k),
+      : config_(config),
         lhs_(lhs),
         rhs_(rhs),
         addend_(addend),
         result_(result),
         ir_builder_(ir_builder),
         ksl_(ir_builder_),
-        vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
-    CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
+        vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(),
+             ir_builder_, "") {
+    CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
+    CHECK(!has_addend() || addend != nullptr);
   }
 
   void Emit();
 
+  const Config& config() const { return config_; }
+
  private:
   void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
                          bool is_first_column);
 
   TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
     return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
-                      /*matrix_size_along_minor_dim=*/m_,
+                      /*matrix_size_along_minor_dim=*/m(),
                       /*major_dim_offset=*/column_start,
                       /*tile_size_along_major_dim=*/column_count);
   }
@@ -195,11 +262,7 @@ class ColumnMajorMatrixVectorProductEmitter {
   void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
                              bool is_first_tiled_column);
 
-  PrimitiveType scalar_type_;
-  int64 tile_rows_;
-  int64 tile_cols_;
-  int64 m_;
-  int64 k_;
+  Config config_;
   llvm::Value* lhs_;
   llvm::Value* rhs_;
   llvm::Value* addend_;
@@ -223,13 +286,13 @@ void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
 
 void ColumnMajorMatrixVectorProductEmitter::Emit() {
   // See the comment on the class declaration for the algorithm used here.
-  int64 column_remainder = k_ % tile_cols_;
-  int64 column_limit = k_ - column_remainder;
+  int64 column_remainder = k() % tile_cols();
+  int64 column_limit = k() - column_remainder;
 
   ksl_.For("dot.outer.tiled",
-           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
+           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
            [&](llvm::Value* column, bool is_first_column) {
-             EmitOuterLoopBody(column, tile_cols_, is_first_column);
+             EmitOuterLoopBody(column, tile_cols(), is_first_column);
            });
 
   if (column_remainder != 0) {
@@ -241,10 +304,10 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() {
 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
     TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
     int64 columns, bool is_first_column) {
-  int64 row_limit = m_ - (m_ % tile_rows_);
+  int64 row_limit = m() - (m() % tile_rows());
 
   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
-           /*step=*/tile_rows_, [&](llvm::Value* row) {
+           /*step=*/tile_rows(), [&](llvm::Value* row) {
              std::vector<llvm::Value*> lhs_tile =
                  lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
              llvm::Value* accumulator =
@@ -260,8 +323,8 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
 
 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
     llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
-  int64 row_start = m_ - (m_ % tile_rows_);
-  if (row_start == m_) {
+  int64 row_start = m() - (m() % tile_rows());
+  if (row_start == m()) {
     return;
   }
 
@@ -281,11 +344,11 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
       [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
         llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
         llvm::Value* total_offset =
-            ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
+            ir_builder_->CreateMul(col, ir_builder_->getInt64(m()));
         llvm::Value* lhs_base_pointer =
             vsl_.ComputeOffsetPointer(lhs_, total_offset);
         ksl_.For(
-            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
+            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
             /*step=*/1, [&](llvm::Value* scalar_row) {
               llvm::Value* product = vsl_.Mul(
                   vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
@@ -365,34 +428,42 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
 //
 // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
 // epilogue loop to deal with the C,D submatrix.
-class RowMajorMatrixVectorProductEmitter {
+class RowMajorMatrixVectorProductEmitter
+    : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
  public:
-  RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
-                                     int64 tile_cols, int64 m, int64 k,
-                                     llvm::Value* lhs, llvm::Value* rhs,
-                                     llvm::Value* addend, llvm::Value* result,
+  class Config : public GemvConfig {
+   public:
+    explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+                    int64 m, int64 k, bool has_addend)
+        : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
+                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+                     /*k=*/k, /*has_addend=*/has_addend) {}
+  };
+
+  RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
+                                     llvm::Value* rhs, llvm::Value* addend,
+                                     llvm::Value* result,
                                      llvm::IRBuilder<>* ir_builder)
-      : scalar_type_(scalar_type),
-        tile_rows_(tile_rows),
-        tile_cols_(tile_cols),
-        m_(m),
-        k_(k),
+      : config_(config),
         lhs_(lhs),
         rhs_(rhs),
         addend_(addend),
         result_(result),
         ir_builder_(ir_builder),
         ksl_(ir_builder_),
-        vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
-    CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
+        vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") {
+    CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
+    CHECK(!has_addend() || addend != nullptr);
   }
 
   void Emit();
 
+  const Config& config() const { return config_; }
+
  private:
   TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
     return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
-                      /*matrix_size_along_minor_dim=*/k_,
+                      /*matrix_size_along_minor_dim=*/k(),
                       /*major_dim_offset=*/row_start,
                       /*tile_size_along_major_dim=*/row_count);
   }
@@ -405,11 +476,7 @@ class RowMajorMatrixVectorProductEmitter {
   void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
                              std::vector<ScalarVariable>* scalar_accumulators);
 
-  PrimitiveType scalar_type_;
-  int64 tile_rows_;
-  int64 tile_cols_;
-  int64 m_;
-  int64 k_;
+  Config config_;
   llvm::Value* lhs_;
   llvm::Value* rhs_;
   llvm::Value* addend_;
@@ -466,12 +533,12 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
 
 void RowMajorMatrixVectorProductEmitter::Emit() {
   // See the comment on the class declaration for the algorithm used here.
-  int64 row_remainder = m_ % tile_rows_;
-  int64 row_limit = m_ - row_remainder;
+  int64 row_remainder = m() % tile_rows();
+  int64 row_limit = m() - row_remainder;
 
   ksl_.For("dot.outer.tiled",
-           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
-           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
+           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
+           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
 
   if (row_remainder != 0) {
     EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
@@ -481,10 +548,10 @@ void RowMajorMatrixVectorProductEmitter::Emit() {
 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
     TileLoader* lhs_tile_loader, int64 rows,
     std::vector<VectorVariable>* vector_accumulators) {
-  int64 column_limit = k_ - (k_ % tile_cols_);
+  int64 column_limit = k() - (k() % tile_cols());
 
   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
-           /*step=*/tile_cols_, [&](llvm::Value* col) {
+           /*step=*/tile_cols(), [&](llvm::Value* col) {
              std::vector<llvm::Value*> lhs_tile =
                  lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
              llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
@@ -499,18 +566,18 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
     llvm::Value* current_tile_row, int64 rows,
     std::vector<ScalarVariable>* scalar_accumulators) {
-  int64 column_start = k_ - (k_ % tile_cols_);
-  if (column_start == k_) {
+  int64 column_start = k() - (k() % tile_cols());
+  if (column_start == k()) {
     return;
   }
 
   for (int r = 0; r < rows; r++) {
     llvm::Value* total_offset = ir_builder_->CreateMul(
         ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
-        ir_builder_->getInt64(k_));
+        ir_builder_->getInt64(k()));
     llvm::Value* lhs_base_pointer =
         vsl_.ComputeOffsetPointer(lhs_, total_offset);
-    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
+    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
              /*step=*/1, [&](llvm::Value* scalar_col) {
                llvm::Value* product =
                    vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
@@ -942,47 +1009,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
   if (is_column_major_matrix_vector) {
     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
             << " and k = " << k;
-    int64 tile_rows = vector_register_element_size;
-    int64 tile_cols = tiling_factor;
-
-    string kernel_name = tensorflow::strings::StrCat(
-        "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
-        "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
+    ColumnMajorMatrixVectorProductEmitter::Config config(
+        /*scalar_type=*/primitive_type,
+        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
+        /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr);
 
     KernelSupportLibrary::EmitAndCallOutlinedKernel(
         /*enable_fast_math=*/enable_fast_math,
-        /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
-        lhs_op, rhs_op,
+        /*optimize_for_size=*/optimize_for_size, ir_builder_,
+        config.GetCacheKey(), lhs_op, rhs_op,
         addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
-        [this, tile_rows, tile_cols, m, k, primitive_type](
-            llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
-            llvm::Value* result_op) {
+        [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
+                       llvm::Value* addend_op, llvm::Value* result_op) {
           ColumnMajorMatrixVectorProductEmitter emitter(
-              primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
-              addend_op, result_op, ir_builder_);
+              config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
           emitter.Emit();
         });
   } else {
     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
             << " and k = " << k;
-    int64 tile_rows = tiling_factor;
-    int64 tile_cols = vector_register_element_size;
-
-    string kernel_name = tensorflow::strings::StrCat(
-        "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
-        "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
+    RowMajorMatrixVectorProductEmitter::Config config(
+        /*scalar_type=*/primitive_type,
+        /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size,
+        /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr);
 
     KernelSupportLibrary::EmitAndCallOutlinedKernel(
         /*enable_fast_math=*/enable_fast_math,
-        /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
-        lhs_op, rhs_op,
+        /*optimize_for_size=*/optimize_for_size, ir_builder_,
+        config.GetCacheKey(), lhs_op, rhs_op,
         addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
-        [this, tile_rows, tile_cols, m, k, primitive_type](
-            llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
-            llvm::Value* result_op) {
+        [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
+                       llvm::Value* addend_op, llvm::Value* result_op) {
           RowMajorMatrixVectorProductEmitter emitter(
-              primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
-              addend_op, result_op, ir_builder_);
+              config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
           emitter.Emit();
         });
   }
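
The reason the key must be exhaustive is visible in the call sites above:
EmitAndCallOutlinedKernel deduplicates kernels by name, so two dot operations
that produce the same key share one function body. A minimal sketch of that
keying discipline, with a hypothetical KernelCache standing in for the LLVM
module's function table:

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    // Hypothetical stand-in for the function table behind
    // EmitAndCallOutlinedKernel: the first request for a given key runs the
    // emitter; later requests with the same key reuse the cached body.
    class KernelCache {
     public:
      const std::string& GetOrEmit(const std::string& cache_key,
                                   const std::function<std::string()>& emit) {
        auto it = kernels_.find(cache_key);
        if (it == kernels_.end()) {
          it = kernels_.emplace(cache_key, emit()).first;  // emit exactly once
        }
        return it->second;
      }

     private:
      std::map<std::string, std::string> kernels_;
    };

    int main() {
      KernelCache cache;
      auto emit = [] {
        std::cout << "emitting kernel\n";
        return std::string("<llvm ir body>");
      };
      // Two dot ops with identical scalar type, tile sizes, m and k produce
      // the same key, so only the first call actually emits.
      cache.GetOrEmit("col_major_gemv_F32_8_4_128_256", emit);
      cache.GetOrEmit("col_major_gemv_F32_8_4_128_256", emit);  // reused
    }

If the emitter ever gained a knob that changed the generated IR but was not
folded into GetCacheKey, the second lookup would silently return the wrong
kernel; keeping the knobs and the key in one GemvConfig class is what closes
that gap.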