[XLA:GPU] Extend the CustomCall for cudnn convolutions to represent
authorBixia Zheng <bixia@google.com>
Tue, 13 Feb 2018 00:56:28 +0000 (16:56 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 00:59:58 +0000 (16:59 -0800)
tensor_ops_enabled.

The convolution algorithms returned from the stream executor have a flag
for whether tensor_ops is enabled. This flag is used when running each
algorithm during auto-tunning. However, this flag is not currently represented
in the CustomCall representing the auto-tune result. As a result, the algorithm
may be run differently after auto-tune.

This change adds a constant to the CustomCall for cudnn convolution algorithm
selected by auto-tune, to represent whether tensor_ops is enabled during
auto-tune. This information is used by convolution thunk to ensure that the
algorithm is run with the same flag after auto-tune.

PiperOrigin-RevId: 185458497

tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
tensorflow/compiler/xla/service/gpu/convolution_thunk.h
tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

index f76f159..15bba49 100644 (file)
@@ -45,7 +45,7 @@ ConvolutionThunk::ConvolutionThunk(
     const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
     const Shape& filter_shape, const Shape& output_shape, const Window& window,
     const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
-    const HloInstruction* hlo)
+    bool tensor_ops_enabled, const HloInstruction* hlo)
     : Thunk(Kind::kConvolution, hlo),
       convolution_kind_(convolution_kind),
       input_buffer_(input_buffer),
@@ -58,7 +58,8 @@ ConvolutionThunk::ConvolutionThunk(
       output_shape_(output_shape),
       window_(window),
       dim_nums_(dim_nums),
-      algorithm_(algorithm) {}
+      algorithm_(algorithm),
+      tensor_ops_enabled_(tensor_ops_enabled) {}
 
 Status ConvolutionThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream) {
@@ -72,7 +73,7 @@ Status ConvolutionThunk::ExecuteOnStream(
       buffer_allocations.GetDeviceAddress(scratch_buffer_);
 
   se::dnn::AlgorithmConfig algorithm_config(
-      se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false));
+      se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
 
   TF_RETURN_IF_ERROR(RunCudnnConvolution(
       convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
index ca9ef52..900d9cb 100644 (file)
@@ -59,7 +59,7 @@ class ConvolutionThunk : public Thunk {
                    const Shape& input_shape, const Shape& filter_shape,
                    const Shape& output_shape, const Window& window,
                    const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
-                   const HloInstruction* hlo);
+                   bool tensor_ops_enabled, const HloInstruction* hlo);
 
   ConvolutionThunk(const ConvolutionThunk&) = delete;
   ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -99,6 +99,7 @@ class ConvolutionThunk : public Thunk {
   const Window window_;
   const ConvolutionDimensionNumbers dim_nums_;
   int64 algorithm_;
+  bool tensor_ops_enabled_;
 };
 
 }  // namespace gpu
index 621b2d5..c29aa31 100644 (file)
@@ -172,7 +172,7 @@ string NumBytesToString(int64 bytes) {
 // cache misses and doing extra work.  Overall, caching doesn't seem worth the
 // trouble, but we may want to revisit this if we ever find a model where
 // caching would speed up compilation a lot.
-optional<std::pair<int64, int64>>
+optional<std::tuple<int64, bool, int64>>
 CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
     const Shape& output_shape, const Window& window,
@@ -260,8 +260,9 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
             << AlgorithmToString(best_result.algorithm()) << ", takes "
             << best_result.elapsed_time_in_ms() << "ms, and uses "
             << best_result_bytes_used << "B of scratch memory.";
-    return std::make_pair(best_result.algorithm().algo_id(),
-                          best_result_bytes_used);
+    return std::make_tuple(best_result.algorithm().algo_id(),
+                           best_result.algorithm().tensor_ops_enabled(),
+                           best_result_bytes_used);
   }
 
   LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
@@ -277,19 +278,19 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
   const auto& lhs_shape = instr->operand(0)->shape();
   const auto& rhs_shape = instr->operand(1)->shape();
   const auto& conv_result_shape = instr->shape().tuple_shapes(0);
-  optional<std::pair<int64, int64>> alg_and_scratch_bytes;
+  optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
   if (call_target == kCudnnConvForwardCallTarget) {
-    alg_and_scratch_bytes = PickBestAlgorithm(
+    alg_scratch_and_tc = PickBestAlgorithm(
         CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
         /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
         instr->window(), instr->convolution_dimension_numbers(), instr);
   } else if (call_target == kCudnnConvBackwardInputCallTarget) {
-    alg_and_scratch_bytes = PickBestAlgorithm(
+    alg_scratch_and_tc = PickBestAlgorithm(
         CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
         /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
         instr->convolution_dimension_numbers(), instr);
   } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
-    alg_and_scratch_bytes = PickBestAlgorithm(
+    alg_scratch_and_tc = PickBestAlgorithm(
         CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
         /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
         instr->window(), instr->convolution_dimension_numbers(), instr);
@@ -298,17 +299,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
                << instr->ToString();
   }
 
-  if (!alg_and_scratch_bytes.has_value()) {
+  if (!alg_scratch_and_tc.has_value()) {
     return false;
   }
 
   int64 algorithm;
+  bool tensor_ops_enabled;
   int64 scratch_bytes;
-  std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes;
+
+  std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
 
   VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
           << NumBytesToString(scratch_bytes)
-          << " of scratch memory: " << instr->ToString();
+          << " of scratch memory: " << instr->ToString()
+          << " tensor_ops_enabled: " << tensor_ops_enabled;
 
   // Replace instr with a new CustomCall which has the correct algorithm, and
   // whose output shape has the appropriate amount of scratch memory.
@@ -318,10 +322,15 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
                                  ShapeUtil::MakeShape(U8, {scratch_bytes})});
   HloInstruction* algorithm_hlo = computation->AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
+  HloInstruction* tensor_ops_enabled_hlo =
+      computation->AddInstruction(HloInstruction::CreateConstant(
+          Literal::CreateR0<bool>(tensor_ops_enabled)));
+
   HloInstruction* new_call =
       computation->AddInstruction(HloInstruction::CreateCustomCall(
           new_call_shape,
-          {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo},
+          {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo,
+           tensor_ops_enabled_hlo},
           instr->custom_call_target()));
   new_call->set_window(instr->window());
   new_call->set_convolution_dimension_numbers(
index 10e49da..516210e 100644 (file)
@@ -47,7 +47,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
  private:
   StatusOr<bool> RunOnComputation(HloComputation* computation);
   StatusOr<bool> RunOnInstruction(HloInstruction* instr);
-  tensorflow::gtl::optional<std::pair<int64, int64>> PickBestAlgorithm(
+  tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
       CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
       const Shape& output_shape, const Window& window,
       const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
index f5f52cf..81695a6 100644 (file)
@@ -106,6 +106,9 @@ Status RunCudnnConvolution(
     se::ScratchAllocator* scratch_allocator, const Window& window,
     const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
     Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+  VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
+  VLOG(3) << "tensor_ops_enabled: "
+          << algorithm.algorithm().tensor_ops_enabled();
   VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
   VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
   VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
index 88bf5a7..916b556 100644 (file)
@@ -79,9 +79,9 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
         TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
       }
     } else if (IsCustomCallToDnnConvolution(*hlo)) {
-      // The last argument to a CUDNN convolution is its algorithm, which must
-      // be an HLO constant -- it shouldn't be copied.
-      for (int64 i = 0; i < hlo->operand_count() - 1; ++i) {
+      // The last two arguments to a CUDNN convolution are two HLO constants for
+      // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied.
+      for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
         TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
       }
     } else if (ImplementedAsLibraryCall(*hlo)) {
index 7ad9680..59455f3 100644 (file)
@@ -63,10 +63,11 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
 // strings.
 //
 // These CustomCalls have window() and convolution_dimension_numbers() set like
-// regular convolution ops.  They have the same LHS and RHS operands, plus one
-// additional int64 operand, representing which cudnn algorithm to run.  This
-// operand must be an HLO constant.  A value of -1 means that the implementation
-// is free to choose the best algorithm it can.
+// regular convolution ops.  They have the same LHS and RHS operands, plus two
+// additional constant operands: an int64 operand for the cudnn algorithm and
+// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
+// algorithm means that the implementation is free to choose the best algorithm
+// it can.
 //
 // These calls output a tuple (conv_result, scratch_memory), where conv_result
 // is the actual result of the convolution, and scratch_memory is temporary
index c81dfbf..7e20af3 100644 (file)
@@ -393,6 +393,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
     CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
     int64 algorithm = algorithm_inst->literal().Get<int64>({});
 
+    const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
+    CHECK(tensor_ops_enabled_inst->IsConstant())
+        << tensor_ops_enabled_inst->ToString();
+    bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
+
     const auto& target = custom_call->custom_call_target();
     std::unique_ptr<ConvolutionThunk> thunk;
     if (target == kCudnnConvForwardCallTarget) {
@@ -407,7 +412,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
           /*filter_shape=*/rhs_shape,
           /*output_shape=*/conv_result_shape,  //
           custom_call->window(), custom_call->convolution_dimension_numbers(),
-          algorithm, custom_call);
+          algorithm, tensor_ops_enabled, custom_call);
     } else if (target == kCudnnConvBackwardInputCallTarget) {
       thunk = MakeUnique<ConvolutionThunk>(
           CudnnConvKind::kBackwardInput,
@@ -420,7 +425,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
           /*filter_shape=*/rhs_shape,
           /*output_shape=*/lhs_shape,  //
           custom_call->window(), custom_call->convolution_dimension_numbers(),
-          algorithm, custom_call);
+          algorithm, tensor_ops_enabled, custom_call);
     } else if (target == kCudnnConvBackwardFilterCallTarget) {
       thunk = MakeUnique<ConvolutionThunk>(
           CudnnConvKind::kBackwardFilter,
@@ -433,7 +438,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
           /*filter_shape=*/conv_result_shape,
           /*output_shape=*/rhs_shape,  //
           custom_call->window(), custom_call->convolution_dimension_numbers(),
-          algorithm, custom_call);
+          algorithm, tensor_ops_enabled, custom_call);
     } else {
       LOG(FATAL) << "Unexpected custom call target: "
                  << custom_call->custom_call_target();