Use target machine features to determine max vectorization width for GEMM
authorSanjoy Das <sanjoy@google.com>
Tue, 29 May 2018 19:02:57 +0000 (12:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 19:06:26 +0000 (12:06 -0700)
PiperOrigin-RevId: 198434296

tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc

index c704105..d770765 100644 (file)
@@ -1006,8 +1006,6 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
     return false;
   }
 
-  VLOG(2) << "Emitting GEBP kernel in LLVM IR";
-
   llvm::Value* lhs = lhs_array_.GetBasePointer();
   llvm::Value* rhs = rhs_array_.GetBasePointer();
   llvm::Value* target = target_array_.GetBasePointer();
@@ -1025,11 +1023,19 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
       target, ir_builder_->getInt8(0), size_bytes,
       target_machine_features_.minimum_alignment_for_allocation(size_bytes));
 
+  int64 max_vector_width =
+      target_machine_features_.vector_register_num_elements(
+          *ir_builder_->GetInsertBlock()->getParent(), primitive_type);
+
   MatrixMatrixBlockPanelEmitter::Config config(
       /*scalar_type=*/primitive_type,
       MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
-      /*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
-      /*tile_size_m=*/3, /*tile_size_k=*/8);
+      /*max_vectorization_width=*/max_vector_width,
+      /*min_vectorization_width=*/std::min<int64>(4, max_vector_width),
+      /*tile_size_m=*/3, /*tile_size_k=*/5);
+
+  VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
+          << config.GetCacheKey();
 
   const bool enable_fast_math =
       hlo_module_config_.debug_options().xla_enable_fast_math();