Internal change
author Justin Lebar <jlebar@google.com>
Fri, 2 Feb 2018 05:34:18 +0000 (21:34 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 18:28:02 +0000 (10:28 -0800)
PiperOrigin-RevId: 184239740

24 files changed:
tensorflow/compiler/xla/service/gpu/BUILD
tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
tensorflow/compiler/xla/service/gpu/convolution_thunk.h
tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc [moved from tensorflow/compiler/xla/service/gpu/convolution_folding.cc with 83% similarity]
tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h [moved from tensorflow/compiler/xla/service/gpu/convolution_folding.h with 63% similarity]
tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc [moved from tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc with 82% similarity]
tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
tensorflow/compiler/xla/service/gpu/ir_emitter.h
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
tensorflow/compiler/xla/service/gpu/pad_insertion.cc
tensorflow/compiler/xla/service/hlo_computation.cc
tensorflow/compiler/xla/service/hlo_computation.h
tensorflow/compiler/xla/service/hlo_cost_analysis.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 80c2eed..7df01f7 100644
@@ -131,6 +131,7 @@ cc_library(
         "ir_emitter_context.h",
     ],
     deps = [
+        ":cudnn_convolution_runner",
         ":elemental_ir_emitter",
         ":gpu_constants",
         ":gpu_executable",
@@ -262,6 +263,7 @@ cc_library(
     ],
     deps = [
         ":buffer_allocations",
+        ":cudnn_convolution_runner",
         ":infeed_manager",
         ":ir_emission_utils",
         ":partition_assignment",
@@ -309,9 +311,41 @@ cc_library(
 )
 
 cc_library(
-    name = "convolution_folding",
-    srcs = ["convolution_folding.cc"],
-    hdrs = ["convolution_folding.h"],
+    name = "cudnn_convolution_algorithm_picker",
+    srcs = ["cudnn_convolution_algorithm_picker.cc"],
+    hdrs = ["cudnn_convolution_algorithm_picker.h"],
+    deps = [
+        ":cudnn_convolution_runner",
+        ":gpu_executable",
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla/service:device_memory_allocator",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+cc_library(
+    name = "cudnn_convolution_runner",
+    srcs = ["cudnn_convolution_runner.cc"],
+    hdrs = ["cudnn_convolution_runner.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+cc_library(
+    name = "cudnn_convolution_rewriter",
+    srcs = ["cudnn_convolution_rewriter.cc"],
+    hdrs = ["cudnn_convolution_rewriter.h"],
     deps = [
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:literal_util",
@@ -325,15 +359,18 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "convolution_folding_test",
-    srcs = ["convolution_folding_test.cc"],
+    name = "cudnn_convolution_rewriter_test",
+    srcs = ["cudnn_convolution_rewriter_test.cc"],
     deps = [
-        ":convolution_folding",
+        ":cudnn_convolution_rewriter",
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/service:shape_inference",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//tensorflow/core:test",
     ],
 )
@@ -446,7 +483,8 @@ cc_library(
     srcs = ["gpu_compiler.cc"],
     hdrs = ["gpu_compiler.h"],
     deps = [
-        ":convolution_folding",
+        ":cudnn_convolution_algorithm_picker",
+        ":cudnn_convolution_rewriter",
         ":fusion_merger",
         ":gpu_constants",
         ":gpu_copy_insertion",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 899cc5c..f76f159 100644
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <string>
 
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -36,366 +37,69 @@ using se::dnn::DataLayout;
 using se::dnn::FilterDescriptor;
 using se::dnn::FilterLayout;
 
-ConvolveScratchAllocator::ConvolveScratchAllocator(
-    int device_ordinal, DeviceMemoryAllocator* memory_allocator)
-    : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
-
-ConvolveScratchAllocator::~ConvolveScratchAllocator() {
-  for (auto& allocated_buffer : allocated_buffers_) {
-    if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
-             .ok()) {
-      // The program can still continue with failed deallocation.
-      LOG(ERROR) << "Failed to deallocate the allocated buffer: "
-                 << allocated_buffer.opaque();
-    }
-  }
-}
-
-int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
-  constexpr int64 kConvolveScratchSize = 1LL << 32;  // 4GB by default.
-  return kConvolveScratchSize;
-}
-
-se::port::StatusOr<se::DeviceMemory<uint8>>
-ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) {
-  CHECK_GE(byte_size, 0) << "byte_size must be positive.";
-  if (byte_size > GetMemoryLimitInBytes(stream)) {
-    return se::port::Status(
-        se::port::error::RESOURCE_EXHAUSTED,
-        tensorflow::strings::Printf(
-            "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
-            byte_size, GetMemoryLimitInBytes(stream)));
-  }
-
-  auto status_or_memory =
-      memory_allocator_->Allocate(device_ordinal_, byte_size,
-                                  /*retry_on_failure=*/false);
-  if (!status_or_memory.ok()) {
-    return se::port::Status(se::port::error::RESOURCE_EXHAUSTED,
-                            tensorflow::strings::Printf(
-                                "Failed to allocate %lld bytes on device %d.",
-                                byte_size, device_ordinal_));
-  }
-  se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
-  allocated_buffers_.push_back(allocated_buffer);
-  total_allocated_bytes_ += byte_size;
-  return se::DeviceMemory<uint8>(allocated_buffer);
-}
-
-string ConvolutionKindToString(
-    ConvolutionThunk::ConvolutionKind convolution_kind) {
-  switch (convolution_kind) {
-    case ConvolutionThunk::ConvolutionKind::kForward:
-      return "forward";
-    case ConvolutionThunk::ConvolutionKind::kBackwardFilter:
-      return "backward_filter";
-    case ConvolutionThunk::ConvolutionKind::kBackwardInput:
-      return "backward_input";
-  }
-  return "unknown convolution kind";
-}
-
 ConvolutionThunk::ConvolutionThunk(
-    ConvolutionKind convolution_kind,
-    const BufferAllocation::Slice& input_buffer,
+    CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
     const BufferAllocation::Slice& filter_buffer,
-    const BufferAllocation::Slice& output_buffer, const Shape& input_shape,
+    const BufferAllocation::Slice& output_buffer,
+    const BufferAllocation::Slice& tuple_result_buffer,
+    const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
     const Shape& filter_shape, const Shape& output_shape, const Window& window,
-    const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo)
+    const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
+    const HloInstruction* hlo)
     : Thunk(Kind::kConvolution, hlo),
       convolution_kind_(convolution_kind),
       input_buffer_(input_buffer),
       filter_buffer_(filter_buffer),
       output_buffer_(output_buffer),
+      tuple_result_buffer_(tuple_result_buffer),
+      scratch_buffer_(scratch_buffer),
       input_shape_(input_shape),
       filter_shape_(filter_shape),
       output_shape_(output_shape),
       window_(window),
-      dim_nums_(dim_nums) {}
+      dim_nums_(dim_nums),
+      algorithm_(algorithm) {}
 
-tensorflow::Status ConvolutionThunk::ExecuteOnStream(
+Status ConvolutionThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream) {
-  VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_);
-  VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }";
-  VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }";
-  VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }";
-  VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }";
-  VLOG(3) << "Window: { " << window_.ShortDebugString() << " }";
-
-  const int num_dimensions = window_.dimensions_size();
-  CHECK_LE(num_dimensions, 3);
-  // cuDNN does not support 1D convolutions. We therefore express 1D
-  // convolutions as 2D convolutions where the first spatial dimension is 1.
-  // This matches the behavior of TF (see definition of conv1d in
-  // tensorflow/python/ops/nn_ops.py).
-  const int effective_num_dimensions = std::max(2, num_dimensions);
-
-  CHECK_EQ(F32, output_shape_.element_type());
-  CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size());
-  CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size());
-  CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size());
-  for (const WindowDimension& dim : window_.dimensions()) {
-    CHECK_EQ(dim.padding_low(), dim.padding_high());
-  }
-
-  // cuDNN's convolution APIs support the BDYX layout for activations/output and
-  // the OIYX layout for weights.
-  BatchDescriptor input_descriptor(effective_num_dimensions);
-  input_descriptor.set_layout(DataLayout::kBatchDepthYX)
-      .set_feature_map_count(
-          input_shape_.dimensions(dim_nums_.input_feature_dimension()))
-      .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension()));
-  for (int dim = 0; dim < num_dimensions; ++dim) {
-    // Note that the dimensions are reversed. The same holds below.
-    input_descriptor.set_spatial_dim(
-        static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
-        input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim)));
-  }
-
-  FilterDescriptor filter_descriptor(effective_num_dimensions);
-  filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
-      .set_input_feature_map_count(
-          filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension()))
-      .set_output_feature_map_count(filter_shape_.dimensions(
-          dim_nums_.kernel_output_feature_dimension()));
-  for (int dim = 0; dim < num_dimensions; ++dim) {
-    filter_descriptor.set_spatial_dim(
-        static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
-        filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(dim)));
-  }
-
-  ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
-  for (int dim = 0; dim < num_dimensions; ++dim) {
-    convolution_descriptor
-        .set_zero_padding(
-            static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
-            window_.dimensions(dim).padding_low())
-        .set_filter_stride(
-            static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
-            window_.dimensions(dim).stride());
-  }
-
-  BatchDescriptor output_descriptor(effective_num_dimensions);
-  output_descriptor.set_layout(DataLayout::kBatchDepthYX)
-      .set_feature_map_count(
-          output_shape_.dimensions(dim_nums_.output_feature_dimension()))
-      .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension()));
-  for (int dim = 0; dim < num_dimensions; ++dim) {
-    output_descriptor.set_spatial_dim(
-        static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
-        output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim)));
-  }
-
-  // Add a singleton dimension in the 1D convolution case.
-  if (num_dimensions == 1) {
-    input_descriptor.set_spatial_dim(static_cast<se::dnn::DimIndex>(0), 1);
-    output_descriptor.set_spatial_dim(static_cast<se::dnn::DimIndex>(0), 1);
-    filter_descriptor.set_spatial_dim(static_cast<se::dnn::DimIndex>(0), 1);
-    convolution_descriptor
-        .set_zero_padding(static_cast<se::dnn::DimIndex>(0), 0)
-        .set_filter_stride(static_cast<se::dnn::DimIndex>(0), 1);
-  }
-
   se::DeviceMemory<float> input_data(
       buffer_allocations.GetDeviceAddress(input_buffer_));
   se::DeviceMemory<float> filter_data(
       buffer_allocations.GetDeviceAddress(filter_buffer_));
   se::DeviceMemory<float> output_data(
       buffer_allocations.GetDeviceAddress(output_buffer_));
-  return ConvolveWithTune(input_descriptor, input_data, filter_descriptor,
-                          filter_data, output_descriptor, output_data,
-                          convolution_descriptor, buffer_allocations, stream);
-}
-
-tensorflow::Status ConvolutionThunk::Convolve(
-    const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
-    const FilterDescriptor& filter_descriptor,
-    se::DeviceMemory<float> filter_data,
-    const BatchDescriptor& output_descriptor,
-    se::DeviceMemory<float> output_data,
-    const ConvolutionDescriptor& convolution_descriptor,
-    const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream,
-    ConvolveScratchAllocator* scratch_allocator,
-    se::dnn::ProfileResult* profile_result) {
-  bool launch_ok;
-  switch (convolution_kind_) {
-    case ConvolutionKind::kBackwardFilter:
-      launch_ok =
-          stream
-              ->ThenConvolveBackwardFilterWithAlgorithm(
-                  input_descriptor, input_data, output_descriptor, output_data,
-                  convolution_descriptor, filter_descriptor, &filter_data,
-                  scratch_allocator, algorithm_config, profile_result)
-              .ok();
-      break;
-    case ConvolutionKind::kBackwardInput:
-      launch_ok = stream
-                      ->ThenConvolveBackwardDataWithAlgorithm(
-                          filter_descriptor, filter_data, output_descriptor,
-                          output_data, convolution_descriptor, input_descriptor,
-                          &input_data, scratch_allocator, algorithm_config,
-                          profile_result)
-                      .ok();
-      break;
-    case ConvolutionKind::kForward:
-      launch_ok =
-          stream
-              ->ThenConvolveWithAlgorithm(
-                  input_descriptor, input_data, filter_descriptor, filter_data,
-                  convolution_descriptor, output_descriptor, &output_data,
-                  scratch_allocator, algorithm_config, profile_result)
-              .ok();
-      break;
-  }
-  if (launch_ok) {
-    return tensorflow::Status::OK();
-  }
-  return InternalError(
-      "Unable to launch convolution for thunk %p with type %s and algorithm "
-      "(%lld, %lld)",
-      this, ConvolutionKindToString(convolution_kind_).c_str(),
-      algorithm_config.algorithm().algo_id(),
-      algorithm_config.algorithm_no_scratch().algo_id());
-}
-
-std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
-    bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const {
-  std::vector<AlgorithmDesc> algorithms;
-  switch (convolution_kind_) {
-    case ConvolutionKind::kBackwardFilter:
-      CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
-          with_winograd_nonfused, &algorithms));
-      break;
-    case ConvolutionKind::kBackwardInput:
-      CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
-          with_winograd_nonfused, &algorithms));
-      break;
-    case ConvolutionKind::kForward:
-      CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
-                                               &algorithms));
-      break;
-  }
-  return algorithms;
-}
-
-static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) {
-  if (algo.tensor_ops_enabled()) {
-    return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
-  }
-  return tensorflow::strings::StrCat(algo.algo_id());
-}
-
-// Determines whether we can safely perform a winograd non-fused convolution for
-// the given input and output descriptors.  This works around b/68264959, an
-// integer overflow in cuDNNv5 and cuDNNv6.
-static bool ShouldIncludeWinogradNonfusedAlgo(
-    const BatchDescriptor& input_descriptor,
-    const BatchDescriptor& output_descriptor) {
-  int64 batch = input_descriptor.count();
-  int64 in_depths = input_descriptor.feature_map_count();
-  int64 in_rows = input_descriptor.height();
-  int64 in_cols = input_descriptor.width();
-  int64 out_depths = output_descriptor.feature_map_count();
-
-  int64 total_size = 16 * std::ceil(batch / 16.0) *
-                     std::max(in_depths, out_depths) * in_cols * in_rows *
-                     sizeof(float);
-  int64 threshold = 1L << 31;
-
-  return total_size < threshold;
-}
-
-tensorflow::Status ConvolutionThunk::ConvolveWithTune(
-    const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
-    const FilterDescriptor& filter_descriptor,
-    se::DeviceMemory<float> filter_data,
-    const BatchDescriptor& output_descriptor,
-    se::DeviceMemory<float> output_data,
-    const ConvolutionDescriptor& convolution_descriptor,
-    const BufferAllocations& buffer_allocations, se::Stream* stream) {
-  // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out.
-  if (!best_algorithm_.has_value()) {
-    best_algorithm_.emplace();
-
-    // Auto-tuning either is disabled or only happens in the first run of this
-    // function.
-    VLOG(2) << "Profiling for best convolution algorithm used for "
-               "ConvolutionThunk: "
-            << this;
-
-    bool with_winograd_nonfused =
-        ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor);
-
-    se::dnn::ProfileResult best_result;
-    se::dnn::ProfileResult best_result_without_scratch;
-    std::vector<AlgorithmDesc> algorithms =
-        GetAlgorithms(with_winograd_nonfused, stream->parent());
-    for (auto algorithm : algorithms) {
-      ConvolveScratchAllocator scratch_allocator(
-          buffer_allocations.device_ordinal(),
-          buffer_allocations.memory_allocator());
-      se::dnn::ProfileResult profile_result;
-      VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm)
-              << " for ConvolutionThunk: " << this;
-      bool launch_ok =
-          Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
-                   output_descriptor, output_data, convolution_descriptor,
-                   se::dnn::AlgorithmConfig(algorithm, algorithm), stream,
-                   &scratch_allocator, &profile_result)
-              .ok();
-      if (launch_ok && profile_result.is_valid()) {
-        VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
-                << " for ConvolutionThunk " << this << " succeeded, taking "
-                << profile_result.elapsed_time_in_ms()
-                << "ms. (Best result: " << best_result.elapsed_time_in_ms()
-                << "ms)";
-        if (profile_result.elapsed_time_in_ms() <
-            best_result.elapsed_time_in_ms()) {
-          best_result = profile_result;
-        }
-        if (scratch_allocator.TotalAllocatedBytes() == 0 &&
-            profile_result.elapsed_time_in_ms() <
-                best_result_without_scratch.elapsed_time_in_ms()) {
-          best_result_without_scratch = profile_result;
-        }
-      } else {
-        VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
-                << " for ConvolutionThunk " << this << " failed.";
-      }
-    }
-
-    if (best_result.is_valid()) {
-      best_algorithm_->set_algorithm(best_result.algorithm());
-    } else {
-      LOG(ERROR) << "No convolution algorithm works with profiling. Fall back "
-                    "to the default algorithm.";
-      best_algorithm_->set_algorithm(AlgorithmDesc());
+  se::DeviceMemoryBase scratch =
+      buffer_allocations.GetDeviceAddress(scratch_buffer_);
+
+  se::dnn::AlgorithmConfig algorithm_config(
+      se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false));
+
+  TF_RETURN_IF_ERROR(RunCudnnConvolution(
+      convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
+      filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
+      stream));
+
+  // Figure out which of output/input/filter is the result produced by this op,
+  // and write the result tuple.
+  void* result_ptr = [&] {
+    switch (convolution_kind_) {
+      case CudnnConvKind::kForward:
+        return output_data.opaque();
+      case CudnnConvKind::kBackwardInput:
+        return input_data.opaque();
+      case CudnnConvKind::kBackwardFilter:
+        return filter_data.opaque();
     }
+  }();
+  void* ptrs[] = {result_ptr, scratch.opaque()};
+  se::DeviceMemory<void*> tuple_addr(
+      buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
+  stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);
 
-    if (best_result_without_scratch.is_valid()) {
-      best_algorithm_->set_algorithm_no_scratch(
-          best_result_without_scratch.algorithm());
-    } else {
-      LOG(ERROR) << "No convolution algorithm without scratch works with "
-                    "profiling. Fall back "
-                    "to the default algorithm.";
-      best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc());
-    }
-  }
-
-  {
-    VLOG(2) << "Using convolution algorithm ("
-            << AlgorithmToString(best_algorithm_->algorithm()) << ", "
-            << AlgorithmToString(best_algorithm_->algorithm_no_scratch())
-            << ") for ConvolutionThunk: " << this;
-    ConvolveScratchAllocator scratch_allocator(
-        buffer_allocations.device_ordinal(),
-        buffer_allocations.memory_allocator());
-    return Convolve(input_descriptor, input_data, filter_descriptor,
-                    filter_data, output_descriptor, output_data,
-                    convolution_descriptor, *best_algorithm_, stream,
-                    &scratch_allocator, nullptr);
+  if (!stream->ok()) {
+    return InternalError("ConvolutionThunk::ExecuteOnStream failed.");
   }
+  return Status::OK();
 }
 
 }  // namespace gpu
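
For orientation, a minimal sketch of how an emitter would now construct this thunk; this is illustrative only, since the ir_emitter_unnested.cc changes are not reproduced on this page, and every variable name below is a placeholder:

    // Placeholder names throughout; the real construction lives in
    // ir_emitter_unnested.cc.  `algorithm` is the cudnn algorithm id chosen at
    // compile time by CudnnConvolutionAlgorithmPicker, or -1 for cudnn's
    // default heuristics.
    auto thunk = MakeUnique<ConvolutionThunk>(
        CudnnConvKind::kForward, input_slice, filter_slice, output_slice,
        tuple_result_slice, scratch_slice, input_shape, filter_shape,
        output_shape, window, dim_nums, /*algorithm=*/-1, custom_call_hlo);
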
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 46c94d0..ca9ef52 100644
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -30,106 +31,47 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-// A one-time scratch allocator for forward and backward convolution. The
-// scratch buffers allocated are released on destruction.
-//
-// Not thread-safe.
-class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator {
- public:
-  ConvolveScratchAllocator(int device_ordinal,
-                           DeviceMemoryAllocator* memory_allocator);
-
-  ~ConvolveScratchAllocator() override;
-
-  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override;
-
-  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
-
-  perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
-  AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override;
-
- private:
-  const int device_ordinal_;
-  DeviceMemoryAllocator* memory_allocator_;
-  std::vector<perftools::gputools::DeviceMemoryBase> allocated_buffers_;
-  int64 total_allocated_bytes_ = 0;
-};
-
 // This class stores everything that StreamExecutor needs to launch a DNN
 // convolution. It is generated by IrEmitter.
 //
 // This is thread-compatible.
 class ConvolutionThunk : public Thunk {
  public:
-  // ConvolutionThunk performs one of the following types of convolution.
-  enum class ConvolutionKind {
-    kBackwardFilter,  // Backward convolution for filter.
-    kBackwardInput,   // Backward convolution for input.
-    kForward,         // Forward convolution.
-  };
-
-  // Constructs a thunk for launching a DNN convolution.
+  // Constructs a thunk for launching a DNN convolution.  When run, it will
+  // write a tuple (result, scratch_memory) into `tuple_result_buffer`.
+  //
+  // `algorithm` is a cudnn algorithm number.  `algorithm == -1` indicates that
+  // we should use the default (i.e. baseline) cudnn algorithm.
+  //
+  // Note that "output" here doesn't refer to the output from running this
+  // thunk, but rather to the "output" of a hypothetical forward convolution
+  // that corresponds to this input+filter+output triple.  That is, the result
+  // generated by this thunk is "output" for forward convs, "input" for
+  // backward-input convs, and "filter" for backward-filter convs.
+  //
   // Semantics of null hlo_instruction argument are as in Thunk.
-  ConvolutionThunk(ConvolutionKind convolution_kind,
+  ConvolutionThunk(CudnnConvKind convolution_kind,
                    const BufferAllocation::Slice& input_buffer,
                    const BufferAllocation::Slice& filter_buffer,
                    const BufferAllocation::Slice& output_buffer,
+                   const BufferAllocation::Slice& tuple_result_buffer,
+                   const BufferAllocation::Slice& scratch_buffer,
                    const Shape& input_shape, const Shape& filter_shape,
                    const Shape& output_shape, const Window& window,
-                   const ConvolutionDimensionNumbers& dnums,
+                   const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
                    const HloInstruction* hlo);
 
   ConvolutionThunk(const ConvolutionThunk&) = delete;
   ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
 
-  // Does the convolution for the thunk on "stream". Auto-tuning happens on the
-  // first run of this function.
-  tensorflow::Status ExecuteOnStream(
-      const BufferAllocations& buffer_allocations,
-      perftools::gputools::Stream* stream) override;
-
-  // Returns true if the next run of ExecuteOnStream will do autotuning.  If so,
-  // we want the GPU to be quiescent during autotuning, so as not to introduce
-  // noise in our results.
-  bool ShouldHaltAllActivityBeforeRunning(
-      perftools::gputools::Stream*) override {
-    return !best_algorithm_.has_value();
-  }
-
-  // Return true if scratch memory is needed to execute the thunk, that is
-  // either the best algorithm hasn't been chosen or the best algorithm is not
-  // the same as the no-scratch algorithm. This is because that the execution
-  // of the thunk is asynchronous, and the scratch allocator goes out of
-  // scope before the thunk finishes execution. Returning true tells the stream
-  // executor to make future thunks wait for this thunk to avoid reusing the
-  // deallocated scratch memory until this thunk is done with it.
-  bool ShouldBlockFutureThunks() {
-    if (!best_algorithm_.has_value()) {
-      return true;
-    }
-
-    const perftools::gputools::dnn::AlgorithmDesc& best_alg =
-        best_algorithm_->algorithm();
-    const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg =
-        best_algorithm_->algorithm_no_scratch();
-    return (!best_alg.is_default() || !no_scratch_best_alg.is_default() ||
-            !(best_alg == no_scratch_best_alg));
-  }
+  // Does the convolution for the thunk on "stream".
+  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+                         perftools::gputools::Stream* stream) override;
 
  private:
-  tensorflow::Status ConvolveWithTune(
-      const perftools::gputools::dnn::BatchDescriptor& input_descriptor,
-      perftools::gputools::DeviceMemory<float> input_data,
-      const perftools::gputools::dnn::FilterDescriptor& filter_descriptor,
-      perftools::gputools::DeviceMemory<float> filter_data,
-      const perftools::gputools::dnn::BatchDescriptor& output_descriptor,
-      perftools::gputools::DeviceMemory<float> output_data,
-      const perftools::gputools::dnn::ConvolutionDescriptor&
-          convolution_descriptor,
-      const BufferAllocations& buffer_allocations,
-      perftools::gputools::Stream* stream);
+  class ScratchAllocator;
 
-  tensorflow::Status Convolve(
+  Status Convolve(
       const perftools::gputools::dnn::BatchDescriptor& input_descriptor,
       perftools::gputools::DeviceMemory<float> input_data,
       const perftools::gputools::dnn::FilterDescriptor& filter_descriptor,
@@ -139,40 +81,26 @@ class ConvolutionThunk : public Thunk {
       const perftools::gputools::dnn::ConvolutionDescriptor&
           convolution_descriptor,
       const perftools::gputools::dnn::AlgorithmConfig& algorithm_config,
-      perftools::gputools::Stream* stream,
-      ConvolveScratchAllocator* scratch_allocator,
+      perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator,
       perftools::gputools::dnn::ProfileResult* profile_result);
 
-  // Returns the convolve algorithms that can be used for this ConvolutionThunk.
-  std::vector<perftools::gputools::dnn::AlgorithmDesc> GetAlgorithms(
-      bool with_winograd_nonfused,
-      perftools::gputools::StreamExecutor* stream_exec) const;
-
-  // Fastest cuDNN convolution algorithm for this thunk learned from
-  // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set
-  // to the default value, indicating cuDNN's convolution will choose the best
-  // algorithm from some heuristics based on its parameters.
-  tensorflow::gtl::optional<perftools::gputools::dnn::AlgorithmConfig>
-      best_algorithm_;
-
-  const ConvolutionKind convolution_kind_;
+  const CudnnConvKind convolution_kind_;
 
   const BufferAllocation::Slice input_buffer_;
   const BufferAllocation::Slice filter_buffer_;
   const BufferAllocation::Slice output_buffer_;
+  const BufferAllocation::Slice tuple_result_buffer_;
+  const BufferAllocation::Slice scratch_buffer_;
 
   const Shape input_shape_;
   const Shape filter_shape_;
   const Shape output_shape_;
 
   const Window window_;
-
   const ConvolutionDimensionNumbers dim_nums_;
+  int64 algorithm_;
 };
 
-string ConvolutionKindToString(
-    ConvolutionThunk::ConvolutionKind convolution_kind);
-
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
new file mode 100644
index 0000000..621b2d5
--- /dev/null
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace se = perftools::gputools;
+
+using se::DeviceMemoryBase;
+using se::dnn::AlgorithmConfig;
+using se::dnn::AlgorithmDesc;
+using tensorflow::gtl::nullopt;
+using tensorflow::gtl::optional;
+
+class ScratchAllocator : public se::ScratchAllocator {
+ public:
+  ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
+      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+  ~ScratchAllocator() override;
+
+  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
+    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
+  }
+  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
+      se::Stream* stream, int64 byte_size) override;
+
+ private:
+  const int device_ordinal_;
+  DeviceMemoryAllocator* memory_allocator_;
+  std::vector<se::DeviceMemoryBase> allocated_buffers_;
+  int64 total_allocated_bytes_ = 0;
+};
+
+ScratchAllocator::~ScratchAllocator() {
+  for (auto& allocated_buffer : allocated_buffers_) {
+    if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
+             .ok()) {
+      // The program can still continue with failed deallocation.
+      LOG(ERROR) << "Failed to deallocate the allocated buffer: "
+                 << allocated_buffer.opaque();
+    }
+  }
+}
+
+se::port::StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
+    se::Stream* stream, int64 byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be nonnegative.";
+  if (byte_size > GetMemoryLimitInBytes(stream)) {
+    return se::port::Status(
+        se::port::error::RESOURCE_EXHAUSTED,
+        tensorflow::strings::Printf(
+            "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+            byte_size, GetMemoryLimitInBytes(stream)));
+  }
+
+  auto status_or_memory =
+      memory_allocator_->Allocate(device_ordinal_, byte_size,
+                                  /*retry_on_failure=*/false);
+  if (!status_or_memory.ok()) {
+    return se::port::Status(se::port::error::RESOURCE_EXHAUSTED,
+                            tensorflow::strings::Printf(
+                                "Failed to allocate %lld bytes on device %d.",
+                                byte_size, device_ordinal_));
+  }
+  se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
+  allocated_buffers_.push_back(allocated_buffer);
+  total_allocated_bytes_ += byte_size;
+  return se::DeviceMemory<uint8>(allocated_buffer);
+}
+
+// Determines whether we can safely perform a winograd non-fused convolution for
+// the given input and output shapes.  This works around b/68264959, an integer
+// overflow in cuDNNv5 and cuDNNv6.
+//
+// TODO(jlebar): We shouldn't need this check for cuDNNv7.
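+//
+// Worked example (illustrative numbers, not from this change): with batch=128,
+// in_depths=64, out_depths=256, and in_rows=in_cols=512, total_size is
+// CeilOfRatio(128, 16) * max(64, 256) * 512 * 512 * sizeof(float)
+// = 8 * 256 * 262144 * 4 = 2^31 bytes, which is not < 2^31, so the winograd
+// non-fused algorithm would be excluded.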
+bool ShouldIncludeWinogradNonfusedAlgo(
+    const Shape& input_shape, const Shape& output_shape,
+    const ConvolutionDimensionNumbers& dnums) {
+  int64 batch = input_shape.dimensions(dnums.input_batch_dimension());
+  int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension());
+  int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0));
+  int64 in_cols =
+      dnums.input_spatial_dimensions_size() == 1
+          ? 1
+          : input_shape.dimensions(dnums.input_spatial_dimensions(1));
+  int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension());
+
+  int64 total_size = CeilOfRatio(batch, int64{16}) *
+                     std::max(in_depths, out_depths) * in_cols * in_rows *
+                     sizeof(float);
+
+  const int64 threshold = 1L << 31;
+  return total_size < threshold;
+}
+
+std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
+                                         bool with_winograd_nonfused,
+                                         se::StreamExecutor* stream_exec_) {
+  std::vector<AlgorithmDesc> algorithms;
+  switch (kind) {
+    case CudnnConvKind::kBackwardFilter:
+      CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms(
+          with_winograd_nonfused, &algorithms));
+      break;
+    case CudnnConvKind::kBackwardInput:
+      CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms(
+          with_winograd_nonfused, &algorithms));
+      break;
+    case CudnnConvKind::kForward:
+      CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused,
+                                                &algorithms));
+      break;
+  }
+
+  // Remove any algorithms with tensor math enabled.  These have lower precision
+  // than regular algorithms, and we don't yet have a way to turn this on/off in
+  // XLA.
+  algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(),
+                                  [&](const AlgorithmDesc& a) {
+                                    return a.tensor_ops_enabled();
+                                  }),
+                   algorithms.end());
+
+  return algorithms;
+}
+
+string AlgorithmToString(const AlgorithmDesc& algo) {
+  if (algo.tensor_ops_enabled()) {
+    return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
+  }
+  return tensorflow::strings::StrCat(algo.algo_id());
+}
+
+string NumBytesToString(int64 bytes) {
+  return tensorflow::strings::StrCat(
+      tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
+}
+
+}  // anonymous namespace
+
+// We could have caching here so that we don't redo this work for two identical
+// convolutions.  Unfortunately our cache key would have to be a tuple
+// containing the protos passed to this function, and we have no utility for
+// hashing protos.  We could write our own hash functions, but they'd silently
+// break if we ever added a field to one of the protos.  Perhaps we could hack
+// using the binary-encoded proto as the hash key, on the assumption that two
+// protos being binary-equal is a sufficient, if not necessary, condition for
+// proper equality.  But that would still leave us open to having unnecessary
+// cache misses and doing extra work.  Overall, caching doesn't seem worth the
+// trouble, but we may want to revisit this if we ever find a model where
+// caching would speed up compilation a lot.
+optional<std::pair<int64, int64>>
+CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
+    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+    const Shape& output_shape, const Window& window,
+    const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
+  // Create a stream for us to do our work on.
+  se::Stream stream{stream_exec_};
+  stream.Init();
+  const auto device_ordinal = stream_exec_->device_ordinal();
+
+  // allocator either points to this->allocator_ or, if that's null, to a
+  // StreamExecutorMemoryAllocator for stream_exec_.
+  DeviceMemoryAllocator* allocator;
+  optional<StreamExecutorMemoryAllocator> se_allocator;
+  if (allocator_ != nullptr) {
+    allocator = allocator_;
+  } else {
+    se_allocator.emplace(
+        stream_exec_->platform(),
+        tensorflow::gtl::ArraySlice<se::StreamExecutor*>({stream_exec_}));
+    allocator = &*se_allocator;
+  }
+
+  // Allocate space for the input, filter, and output of the convolution.  We
+  // use a ScratchAllocator for this instead of calling allocator_ directly so
+  // that our allocations don't leak.
+  //
+  // We don't put any data in these buffers, because (in theory, anyway) the
+  // speed of a conv isn't affected by the data being convolved.
+  ScratchAllocator input_output_allocator(device_ordinal, allocator);
+  se::port::StatusOr<DeviceMemoryBase> input_buf =
+      input_output_allocator.AllocateBytes(&stream,
+                                           ShapeUtil::ByteSizeOf(input_shape));
+  se::port::StatusOr<DeviceMemoryBase> filter_buf =
+      input_output_allocator.AllocateBytes(&stream,
+                                           ShapeUtil::ByteSizeOf(filter_shape));
+  se::port::StatusOr<DeviceMemoryBase> output_buf =
+      input_output_allocator.AllocateBytes(&stream,
+                                           ShapeUtil::ByteSizeOf(output_shape));
+  if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) {
+    LOG(WARNING)
+        << "Couldn't allocate space for input/filter/output of convolution "
+        << instr->ToString() << ".  Falling back to default algorithm.";
+    return nullopt;
+  }
+
+  const bool use_winograd_nonfused =
+      ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums);
+  se::dnn::ProfileResult best_result;
+  int64 best_result_bytes_used = 0;
+  for (const AlgorithmDesc& alg :
+       GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
+    ScratchAllocator scratch_allocator(device_ordinal, allocator);
+    se::dnn::ProfileResult profile_result;
+    VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
+            << instr->ToString();
+
+    bool launch_ok =
+        RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+                            se::DeviceMemory<float>(input_buf.ValueOrDie()),
+                            se::DeviceMemory<float>(filter_buf.ValueOrDie()),
+                            se::DeviceMemory<float>(output_buf.ValueOrDie()),
+                            &scratch_allocator, window, dnums,
+                            AlgorithmConfig(alg), &stream, &profile_result)
+            .ok();
+
+    if (launch_ok && profile_result.is_valid()) {
+      int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
+      VLOG(3) << "Run of algorithm " << AlgorithmToString(alg)
+              << " succeeded, taking " << profile_result.elapsed_time_in_ms()
+              << "ms and using " << NumBytesToString(scratch_bytes_used)
+              << " of scratch (Best result: "
+              << best_result.elapsed_time_in_ms() << "ms, "
+              << NumBytesToString(best_result_bytes_used) << " of scratch)";
+      if (profile_result.elapsed_time_in_ms() <
+          best_result.elapsed_time_in_ms()) {
+        best_result = profile_result;
+        best_result_bytes_used = scratch_bytes_used;
+      }
+    } else {
+      VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed.";
+    }
+  }
+  if (best_result.is_valid()) {
+    VLOG(2) << "Best algorithm for " << instr->ToString() << ": "
+            << AlgorithmToString(best_result.algorithm()) << ", takes "
+            << best_result.elapsed_time_in_ms() << "ms, and uses "
+            << best_result_bytes_used << "B of scratch memory.";
+    return std::make_pair(best_result.algorithm().algo_id(),
+                          best_result_bytes_used);
+  }
+
+  LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
+               << " failed.  Falling back to default algorithm.";
+  return nullopt;
+}
+
+StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
+    HloInstruction* instr) {
+  CHECK(IsCustomCallToDnnConvolution(*instr));
+
+  const auto& call_target = instr->custom_call_target();
+  const auto& lhs_shape = instr->operand(0)->shape();
+  const auto& rhs_shape = instr->operand(1)->shape();
+  const auto& conv_result_shape = instr->shape().tuple_shapes(0);
+  optional<std::pair<int64, int64>> alg_and_scratch_bytes;
+  if (call_target == kCudnnConvForwardCallTarget) {
+    alg_and_scratch_bytes = PickBestAlgorithm(
+        CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
+        /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
+        instr->window(), instr->convolution_dimension_numbers(), instr);
+  } else if (call_target == kCudnnConvBackwardInputCallTarget) {
+    alg_and_scratch_bytes = PickBestAlgorithm(
+        CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
+        /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
+        instr->convolution_dimension_numbers(), instr);
+  } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
+    alg_and_scratch_bytes = PickBestAlgorithm(
+        CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
+        /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
+        instr->window(), instr->convolution_dimension_numbers(), instr);
+  } else {
+    LOG(FATAL) << "Unknown custom call target for cudnn conv: "
+               << instr->ToString();
+  }
+
+  if (!alg_and_scratch_bytes.has_value()) {
+    return false;
+  }
+
+  int64 algorithm;
+  int64 scratch_bytes;
+  std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes;
+
+  VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
+          << NumBytesToString(scratch_bytes)
+          << " of scratch memory: " << instr->ToString();
+
+  // Replace instr with a new CustomCall which has the correct algorithm, and
+  // whose output shape has the appropriate amount of scratch memory.
+  HloComputation* computation = instr->parent();
+  Shape new_call_shape =
+      ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
+                                 ShapeUtil::MakeShape(U8, {scratch_bytes})});
+  HloInstruction* algorithm_hlo = computation->AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
+  HloInstruction* new_call =
+      computation->AddInstruction(HloInstruction::CreateCustomCall(
+          new_call_shape,
+          {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo},
+          instr->custom_call_target()));
+  new_call->set_window(instr->window());
+  new_call->set_convolution_dimension_numbers(
+      instr->convolution_dimension_numbers());
+
+  // Repackage new_call so it has the same shape as the original call, namely
+  // (conv_result, u8[0]).
+  HloInstruction* new_tuple =
+      computation->AddInstruction(HloInstruction::CreateTuple(
+          {computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+               new_call_shape.tuple_shapes(0), new_call, 0)),
+           computation->AddInstruction(
+               HloInstruction::CreateConstant(Literal::CreateR1<uint8>({})))}));
+
+  TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
+  return true;
+}
+
+StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnComputation(
+    HloComputation* computation) {
+  std::vector<HloInstruction*> convs;
+  for (auto* instr : computation->instructions()) {
+    if (IsCustomCallToDnnConvolution(*instr)) {
+      convs.push_back(instr);
+    }
+  }
+
+  bool changed = false;
+  for (auto* instr : convs) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
+    changed |= result;
+  }
+  return changed;
+}
+
+StatusOr<bool> CudnnConvolutionAlgorithmPicker::Run(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
new file mode 100644
index 0000000..10e49da
--- /dev/null
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
+// each and adding explicit scratch space to the CustomCalls.
+class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+ public:
+  // If the `allocator` parameter is not null, we will use it to allocate temp
+  // memory while timing the various convolution algorithms.  If it's null,
+  // we'll use the default allocator on the StreamExecutor.
+  CudnnConvolutionAlgorithmPicker(
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* allocator)
+      : stream_exec_(stream_exec), allocator_(allocator) {}
+
+  tensorflow::StringPiece name() const override {
+    return "cudnn-convolution-algorithm-picker";
+  }
+
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  StatusOr<bool> RunOnComputation(HloComputation* computation);
+  StatusOr<bool> RunOnInstruction(HloInstruction* instr);
+  tensorflow::gtl::optional<std::pair<int64, int64>> PickBestAlgorithm(
+      CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+      const Shape& output_shape, const Window& window,
+      const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
+
+  perftools::gputools::StreamExecutor* stream_exec_;  // never null
+  DeviceMemoryAllocator* allocator_;                  // may be null
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
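
The gpu_compiler.cc changes themselves are not reproduced on this page, but the BUILD edits above (adding ":cudnn_convolution_rewriter" and ":cudnn_convolution_algorithm_picker" to its deps) imply wiring roughly like the following sketch; the pipeline variable and exact pass placement are assumptions:

    // Hypothetical excerpt from the GPU compiler's HLO pass pipeline.
    // First, rewrite eligible kConvolution ops into cudnn CustomCalls that
    // return a (result, scratch) tuple.
    pipeline.AddPass<CudnnConvolutionRewriter>();
    // Later, with a StreamExecutor and allocator in hand, pin each CustomCall
    // to a concrete cudnn algorithm and scratch size.
    pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec, allocator);
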
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
 
 #include <numeric>
 #include <vector>
@@ -33,14 +33,32 @@ namespace xla {
 namespace gpu {
 
 namespace {
+
+bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
+  const ConvolutionDimensionNumbers& dnums =
+      conv->convolution_dimension_numbers();
+  if (dnums.input_spatial_dimensions_size() > 3) {
+    return false;
+  }
+
+  // CuDNN does not accept zero-element arguments
+  if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) ||
+      ShapeUtil::HasZeroElements(conv->operand(1)->shape())) {
+    return false;
+  }
+
+  if (window_util::HasWindowReversal(conv->window())) {
+    return false;
+  }
+  return true;
+}
+
 // Try to match a backward filter pattern that contains "conv".
 // Precondition: "conv" is a kConvolution.
-std::tuple<bool, std::vector<HloInstruction*>, Window,
-           ConvolutionDimensionNumbers>
-MatchBackwardFilter(HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
+    HloInstruction* conv) {
   const auto no_match_result =
-      std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
-                      ConvolutionDimensionNumbers());
+      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
   // Step 1: match the instruction pattern without considering the paddings and
   // dimension numbers just yet. We may need some generic pattern matcher
   // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -190,18 +208,15 @@ MatchBackwardFilter(HloInstruction* conv) {
     backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
   }
 
-  return std::make_tuple(true, std::vector<HloInstruction*>({conv}),
-                         backward_conv_window, backward_conv_dnums);
+  return std::make_tuple(true, backward_conv_window, backward_conv_dnums);
 }
 
 // Try to match a backward input pattern that contains "conv".
 // Precondition: "conv" is a kConvolution.
-std::tuple<bool, std::vector<HloInstruction*>, Window,
-           ConvolutionDimensionNumbers>
-MatchBackwardInput(HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
+    HloInstruction* conv) {
   const auto no_match_result =
-      std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
-                      ConvolutionDimensionNumbers());
+      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
 
   // Match instruction pattern.
   CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
@@ -374,58 +389,82 @@ MatchBackwardInput(HloInstruction* conv) {
   dnums.set_kernel_output_feature_dimension(
       conv->convolution_dimension_numbers().kernel_input_feature_dimension());
 
-  return std::make_tuple(true,
-                         std::vector<HloInstruction*>({conv, reverse_filter}),
-                         new_window, dnums);
+  return std::make_tuple(true, new_window, dnums);
 }
-}  // namespace
 
-StatusOr<bool> ConvolutionFolding::Run(HloModule* module) {
-  HloComputation* entry_computation = module->entry_computation();
-  std::vector<HloInstruction*> convs;
-  for (auto* hlo : entry_computation->instructions()) {
-    if (hlo->opcode() == HloOpcode::kConvolution) {
-      convs.push_back(hlo);
-    }
-  }
+// Tries to rewrite a single convolution into a call to cudnn.
+StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
+  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
 
-  bool changed = false;
-  for (HloInstruction* conv : convs) {
+  HloInstruction* custom_call = [&]() -> HloInstruction* {
     bool match;
-    std::vector<HloInstruction*> hlos_to_fuse;
     Window window;
     ConvolutionDimensionNumbers dnums;
-    std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv);
+
+    std::tie(match, window, dnums) = MatchBackwardFilter(conv);
     if (match) {
-      VLOG(2) << "Fuse instructions";
-      for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
-        VLOG(2) << "  " << hlo_to_fuse->ToString();
-      }
-      HloInstruction* backward_convolution =
-          entry_computation->CreateFusionInstructionForBackwardConvolution(
-              hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter,
-              window, dnums);
-      VLOG(2) << "to backward filter convolution";
-      VLOG(2) << "  " << backward_convolution->ToString();
-      changed = true;
-      continue;
+      return CreateCudnnConvBackwardFilter(
+          conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
+          window, dnums);
     }
 
-    std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv);
+    std::tie(match, window, dnums) = MatchBackwardInput(conv);
     if (match) {
-      VLOG(2) << "Fuse instructions";
-      for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
-        VLOG(2) << "  " << hlo_to_fuse->ToString();
-      }
-      HloInstruction* backward_convolution =
-          entry_computation->CreateFusionInstructionForBackwardConvolution(
-              hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput,
-              window, dnums);
-      VLOG(2) << "to backward input convolution";
-      VLOG(2) << "  " << backward_convolution->ToString();
-      changed = true;
-      continue;
+      // Backward input conv subsumes the conv plus the reverse in operand 1.
+      HloInstruction* reverse = conv->mutable_operand(1);
+      CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
+      HloInstruction* rhs = reverse->mutable_operand(0);
+
+      return CreateCudnnConvBackwardInput(
+          conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
     }
+
+    // If all else fails, try a forward convolution.
+    if (CanImplementAsCudnnForwardConv(conv)) {
+      return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
+                                    conv->mutable_operand(1), conv->window(),
+                                    conv->convolution_dimension_numbers());
+    }
+
+    return nullptr;
+  }();
+
+  if (custom_call == nullptr) {
+    return false;
+  }
+
+  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract
+  // the conv result and replace `conv` with it.
+  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+      conv,
+      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
+  return true;
+}
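+
+// As a sketch of the rewrite performed above (names and shapes are
+// illustrative only), a forward convolution
+//
+//   %conv = f32[...] convolution(%lhs, %rhs), window=..., dim_labels=...
+//
+// becomes
+//
+//   %cc  = (f32[...], u8[0]) custom-call(%lhs, %rhs),
+//          custom_call_target="__cudnn$convForward"
+//   %gte = f32[...] get-tuple-element(%cc), index=0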
+
+// Rewrites the convolutions in the given computation into calls to cudnn.
+// Returns true if it made any changes.
+StatusOr<bool> RunOnComputation(HloComputation* computation) {
+  std::vector<HloInstruction*> convs;
+  for (auto* hlo : computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kConvolution) {
+      convs.push_back(hlo);
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* conv : convs) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
+    changed |= result;
+  }
+  return changed;
+}
+}  // namespace
+
+StatusOr<bool> CudnnConvolutionRewriter::Run(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
   }
   return changed;
 }
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -22,10 +22,12 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-class ConvolutionFolding : public HloPassInterface {
+// Rewrites plain convolutions, backwards-filter convolutions, and
+// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
+class CudnnConvolutionRewriter : public HloPassInterface {
  public:
   tensorflow::StringPiece name() const override {
-    return "convolution-folding";
+    return "cudnn-convolution-rewriter";
   }
 
   StatusOr<bool> Run(HloModule* module) override;
@@ -34,4 +36,4 @@ class ConvolutionFolding : public HloPassInterface {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,23 +13,29 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
 
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
 namespace gpu {
+namespace {
 
-class ConvolutionFoldingTest : public HloTestBase {
+namespace op = xla::testing::opcode_matchers;
+
+class CudnnConvolutionRewriterTest : public HloTestBase {
  public:
-  ConvolutionFoldingTest() {
+  CudnnConvolutionRewriterTest() {
     for (int i = 0; i < 2; ++i) {
       WindowDimension* window_dim = default_conv_window_.add_dimensions();
       window_dim->set_size(1);
@@ -44,7 +50,8 @@ class ConvolutionFoldingTest : public HloTestBase {
     // the batch and feature dimension in the activations, and treat the batch
     // dimension in gradients as the input feature dimension in the filter.
     //
-    // TODO(jingyue): Add more tests on NCHW input order which TF also supports.
+    // TODO(jingyue): Add more tests on NCHW input order, which TF also
+    // supports.
     tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
     tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
     tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
@@ -74,9 +81,8 @@ class ConvolutionFoldingTest : public HloTestBase {
   }
 
  protected:
-  bool FoldConvolution(HloModule* module) {
-    ConvolutionFolding convolution_folding;
-    return convolution_folding.Run(module).ValueOrDie();
+  bool RunPass(HloModule* module) {
+    return CudnnConvolutionRewriter().Run(module).ValueOrDie();
   }
 
   // A convolution window with stride 1 and zero padding. The size fields are
@@ -86,7 +92,7 @@ class ConvolutionFoldingTest : public HloTestBase {
   ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
 };
 
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) {
+TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
   HloComputation::Builder builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -108,14 +114,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
-TEST_F(ConvolutionFoldingTest,
+TEST_F(CudnnConvolutionRewriterTest,
        BackwardFilterConvolveEquivalentToForwardConvolution) {
   HloComputation::Builder builder(TestName());
   HloInstruction* activations =
@@ -135,12 +140,17 @@ TEST_F(ConvolutionFoldingTest,
       tf_default_dnums_for_backward_filter_));
 
   auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
 // Extracted from block35 training.
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) {
+TEST_F(CudnnConvolutionRewriterTest,
+       BackwardFilterConvolveWithPaddedActivations) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -162,15 +172,15 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
 // Extracted from inception v3 training.
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) {
+TEST_F(CudnnConvolutionRewriterTest,
+       BackwardFilterConvolveWithPaddedGradients) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -192,14 +202,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) {
+TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -221,14 +230,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
-TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) {
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* output =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -272,14 +280,15 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+
+  ASSERT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
+  const HloInstruction* custom_call =
+      entry_computation->root_instruction()->operand(0);
   for (int i = 0; i < 2; ++i) {
-    const WindowDimension& window_dim =
-        entry_computation->root_instruction()->window().dimensions(i);
+    const WindowDimension& window_dim = custom_call->window().dimensions(i);
     // Low padding of the backward input convolution
     //   = kernel_size - 1 - low padding on gradients.
     EXPECT_EQ(3, window_dim.padding_low());
@@ -291,7 +300,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) {
 // Convolve([abc], [x], base_dilation=2)
 //   = Convolve([abc], Reverse([x]), base_dilation=2)
 //   = BackwardInputConvolve([abc], [x], stride=2)
-TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) {
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
   auto builder = HloComputation::Builder(TestName());
   // NHWC dimension order.
   HloInstruction* output =
@@ -316,17 +325,16 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
 }
 
 // BackwardInputConvolve([abc], [x], stride=1) is equivalent to
 // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
 // convolution.
-TEST_F(ConvolutionFoldingTest,
+TEST_F(CudnnConvolutionRewriterTest,
        BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
   auto builder = HloComputation::Builder(TestName());
   // NHWC dimension order.
@@ -347,8 +355,12 @@ TEST_F(ConvolutionFoldingTest,
       tf_default_dnums_for_backward_input_));
 
   auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_FALSE(FoldConvolution(module.get()));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(
+      entry_computation->root_instruction(),
+      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
 }
 
 // Extracted from Inception V3 training.
@@ -365,7 +377,8 @@ TEST_F(ConvolutionFoldingTest,
 //                     20x10x10x192
 //
 // Gradients are padded unevenly.
-TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) {
+TEST_F(CudnnConvolutionRewriterTest,
+       BackwardInputConvolveUnevenPaddingOnGradients) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* output =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -397,14 +410,14 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) {
   auto module = CreateNewModule();
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  EXPECT_EQ(HloOpcode::kFusion,
-            entry_computation->root_instruction()->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
-              entry_computation->root_instruction()->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  ASSERT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
+  const HloInstruction* custom_call =
+      entry_computation->root_instruction()->operand(0);
   for (int i = 0; i < 2; ++i) {
-    const WindowDimension& window_dim =
-        entry_computation->root_instruction()->window().dimensions(i);
+    const WindowDimension& window_dim = custom_call->window().dimensions(i);
     EXPECT_EQ(0, window_dim.padding_low());
     EXPECT_EQ(0, window_dim.padding_high());
     EXPECT_EQ(2, window_dim.stride());
@@ -413,7 +426,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) {
 
 // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the
 // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused.
-TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) {
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* output =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -442,8 +455,12 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) {
                          .ValueOrDie()));
 
   auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_FALSE(FoldConvolution(module.get()));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(
+      entry_computation->root_instruction(),
+      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
 }
 
 // Extracted from //learning/brain/google/xla/benchmarks/resnet.py
@@ -460,7 +477,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) {
 //
 // We should fuse BC even though padding on activations is uneven, because
 // PadInsertion will canonicalize the fusion HLO.
-TEST_F(ConvolutionFoldingTest,
+TEST_F(CudnnConvolutionRewriterTest,
        BackwardInputConvolveUnevenPaddingOnActivations) {
   auto builder = HloComputation::Builder(TestName());
   // The gradients are in NCHW layout.
@@ -493,13 +510,12 @@ TEST_F(ConvolutionFoldingTest,
   auto module = CreateNewModule();
   const HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(module.get()));
-  const HloInstruction* backward_conv = entry_computation->root_instruction();
-  EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode());
-  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
-              backward_conv->fusion_kind());
+  EXPECT_TRUE(RunPass(module.get()));
+  ASSERT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
   const WindowDimension& backward_conv_col_dim =
-      backward_conv->window().dimensions(1);
+      entry_computation->root_instruction()->operand(0)->window().dimensions(1);
   EXPECT_EQ(0, backward_conv_col_dim.padding_low());
   EXPECT_EQ(1, backward_conv_col_dim.padding_high());
 }
@@ -515,7 +531,7 @@ TEST_F(ConvolutionFoldingTest,
 //
 // We currently don't fuse BC because PadInsertion doesn't support negative
 // padding on the gradients of backward convolution (b/32744257).
-TEST_F(ConvolutionFoldingTest,
+TEST_F(CudnnConvolutionRewriterTest,
        BackwardInputConvolveNegativePaddingHighOnActivations) {
   auto builder = HloComputation::Builder(TestName());
   // The gradients are in NCHW layout.
@@ -544,9 +560,14 @@ TEST_F(ConvolutionFoldingTest,
                          .ValueOrDie()));
 
   auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_FALSE(FoldConvolution(module.get()));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(
+      entry_computation->root_instruction(),
+      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
 }
 
+}  // anonymous namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
new file mode 100644 (file)
index 0000000..f5f52cf
--- /dev/null
@@ -0,0 +1,221 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace se = ::perftools::gputools;
+
+using se::DeviceMemory;
+using se::DeviceMemoryBase;
+using se::Stream;
+using se::dnn::AlgorithmConfig;
+using se::dnn::BatchDescriptor;
+using se::dnn::ConvolutionDescriptor;
+using se::dnn::DataLayout;
+using se::dnn::DimIndex;
+using se::dnn::FilterDescriptor;
+using se::dnn::FilterLayout;
+using se::dnn::ProfileResult;
+
+// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
+// returning it (in its entirety) the first time AllocateBytes() is called.
+class ScratchBufAllocator : public se::ScratchAllocator {
+ public:
+  explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
+      : scratch_(scratch) {}
+
+  ~ScratchBufAllocator() override = default;
+
+  int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override {
+    return scratch_.size();
+  }
+
+  se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
+      se::Stream* stream, int64 byte_size) override {
+    if (allocated_) {
+      return se::port::InternalError(
+          "Can't allocate twice from a ScratchBufAllocator.");
+    }
+    if (byte_size > scratch_.size()) {
+      return se::port::InternalError(tensorflow::strings::StrCat(
+          "Can't allocate ", byte_size,
+          " bytes from a ScratchBufAllocator of size ", scratch_.size()));
+    }
+
+    allocated_ = true;
+    return se::DeviceMemory<uint8>(scratch_);
+  }
+
+ private:
+  se::DeviceMemoryBase scratch_;
+  bool allocated_ = false;
+};
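+
+// ScratchBufAllocator is what lets the DeviceMemoryBase overload of
+// RunCudnnConvolution share the ScratchAllocator-based implementation: the
+// preallocated buffer is handed out, whole, on the first (and only)
+// AllocateBytes() call.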
+
+}  // anonymous namespace
+
+string CudnnConvKindToString(CudnnConvKind kind) {
+  switch (kind) {
+    case CudnnConvKind::kForward:
+      return "forward";
+    case CudnnConvKind::kBackwardFilter:
+      return "backward_filter";
+    case CudnnConvKind::kBackwardInput:
+      return "backward_input";
+  }
+}
+
+Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape,
+                           const Shape& filter_shape, const Shape& output_shape,
+                           DeviceMemory<float> input_buf,
+                           DeviceMemory<float> filter_buf,
+                           DeviceMemory<float> output_buf,
+                           DeviceMemoryBase scratch_buf, const Window& window,
+                           const ConvolutionDimensionNumbers& dnums,
+                           AlgorithmConfig algorithm, Stream* stream,
+                           ProfileResult* profile_result /*= nullptr*/) {
+  ScratchBufAllocator scratch_allocator(scratch_buf);
+  return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+                             input_buf, filter_buf, output_buf,
+                             &scratch_allocator, window, dnums, algorithm,
+                             stream, profile_result);
+}
+
+Status RunCudnnConvolution(
+    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+    const Shape& output_shape, DeviceMemory<float> input_buf,
+    DeviceMemory<float> filter_buf, DeviceMemory<float> output_buf,
+    se::ScratchAllocator* scratch_allocator, const Window& window,
+    const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
+    Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
+  VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
+  VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
+  VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }";
+  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
+  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
+
+  const int num_dimensions = window.dimensions_size();
+  CHECK_LE(num_dimensions, 3);
+  // cuDNN does not support 1D convolutions. We therefore express 1D
+  // convolutions as 2D convolutions where the first spatial dimension is 1.
+  // This matches the behavior of TF (see definition of conv1d in
+  // tensorflow/python/ops/nn_ops.py).
+  const int effective_num_dimensions = std::max(2, num_dimensions);
+
+  CHECK_EQ(F32, output_shape.element_type())
+      << ShapeUtil::HumanString(output_shape);
+  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
+  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
+  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size());
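+  // Only symmetric padding is supported here; we assume PadInsertion has
+  // already canonicalized convolutions with uneven padding.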
+  for (const WindowDimension& dim : window.dimensions()) {
+    CHECK_EQ(dim.padding_low(), dim.padding_high());
+  }
+
+  // cuDNN's convolution APIs support the BDYX layout for activations/output and
+  // the OIYX layout for weights.
+  BatchDescriptor input_descriptor(effective_num_dimensions);
+  input_descriptor.set_layout(DataLayout::kBatchDepthYX)
+      .set_feature_map_count(
+          input_shape.dimensions(dnums.input_feature_dimension()))
+      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    // Note that the dimensions are reversed. The same holds below.
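+    // E.g. for a 2D convolution, XLA spatial dimension 0 (Y) maps to
+    // DimIndex 1 and spatial dimension 1 (X) maps to DimIndex 0: XLA lists
+    // spatial dimensions starting from the most-major one, while DimIndex
+    // counts up from the most-minor (X).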
+    input_descriptor.set_spatial_dim(
+        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
+        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
+  }
+
+  FilterDescriptor filter_descriptor(effective_num_dimensions);
+  filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
+      .set_input_feature_map_count(
+          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
+      .set_output_feature_map_count(
+          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    filter_descriptor.set_spatial_dim(
+        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
+        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
+  }
+
+  ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    convolution_descriptor
+        .set_zero_padding(
+            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
+            window.dimensions(dim).padding_low())
+        .set_filter_stride(
+            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
+            window.dimensions(dim).stride());
+  }
+
+  BatchDescriptor output_descriptor(effective_num_dimensions);
+  output_descriptor.set_layout(DataLayout::kBatchDepthYX)
+      .set_feature_map_count(
+          output_shape.dimensions(dnums.output_feature_dimension()))
+      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    output_descriptor.set_spatial_dim(
+        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
+        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
+  }
+
+  // Add a singleton dimension in the 1D convolution case.
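+  // E.g. (illustratively) a 1D conv over a width-16 row with a size-3 window
+  // is run as a 2D conv over a 1x16 image with a 1x3 window.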
+  if (num_dimensions == 1) {
+    input_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
+    output_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
+    filter_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
+    convolution_descriptor.set_zero_padding(static_cast<DimIndex>(0), 0)
+        .set_filter_stride(static_cast<DimIndex>(0), 1);
+  }
+
+  switch (kind) {
+    case CudnnConvKind::kForward:
+      stream->ThenConvolveWithAlgorithm(
+          input_descriptor, input_buf, filter_descriptor, filter_buf,
+          convolution_descriptor, output_descriptor, &output_buf,
+          scratch_allocator, algorithm, profile_result);
+      break;
+    case CudnnConvKind::kBackwardInput:
+      stream->ThenConvolveBackwardDataWithAlgorithm(
+          filter_descriptor, filter_buf, output_descriptor, output_buf,
+          convolution_descriptor, input_descriptor, &input_buf,
+          scratch_allocator, algorithm, profile_result);
+      break;
+    case CudnnConvKind::kBackwardFilter:
+      stream->ThenConvolveBackwardFilterWithAlgorithm(
+          input_descriptor, input_buf, output_descriptor, output_buf,
+          convolution_descriptor, filter_descriptor, &filter_buf,
+          scratch_allocator, algorithm, profile_result);
+      break;
+  }
+
+  if (!stream->ok()) {
+    return InternalError(
+        "Unable to launch convolution with type %s and algorithm (%lld, %lld)",
+        CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(),
+        algorithm.algorithm_no_scratch().algo_id());
+  }
+  return Status::OK();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
new file mode 100644 (file)
index 0000000..b101f76
--- /dev/null
@@ -0,0 +1,97 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// This file contains low-level routines for running cudnn convolutions.
+
+// Different types of convolutions supported by cudnn.
+//
+// A way to think about these is that a convolution is defined by three arrays
+// -- the "input", the "filter", and the "output" -- and given any two of these,
+// we can compute the third.  For example, a backward-input convolution takes as
+// input a filter and an "output" and produces an "input" such that if one were
+// to do a forward convolution of "input" using filter, the result would be
+// something with the same shape as "output".
+//
+// This way of thinking is not correct if you look at the values produced. For
+// example, a backward-input convolution is not actually the mathematical
+// inverse of a forward convolution.  But it's right as far as the shapes and
+// "connectivity" (i.e. which elements of the input affect which elements of
+// the output) are concerned.
+enum class CudnnConvKind {
+  kForward,         // input  + filter => output
+  kBackwardInput,   // filter + output => input
+  kBackwardFilter,  // input  + output => filter
+};
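+
+// As a concrete example of the shape relationships (shapes are illustrative
+// only; NHWC activations, HWIO filter, stride 1, no padding):
+//
+//   kForward:        f32[8,32,32,16] + f32[3,3,16,64]  => f32[8,30,30,64]
+//   kBackwardInput:  f32[3,3,16,64]  + f32[8,30,30,64] => f32[8,32,32,16]
+//   kBackwardFilter: f32[8,32,32,16] + f32[8,30,30,64] => f32[3,3,16,64]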
+
+// Converts a CudnnConvKind value to a string.
+string CudnnConvKindToString(CudnnConvKind kind);
+
+// Calls into cudnn to run the specified convolution.
+//
+// Note that depending on the value of CudnnConvKind, the result of this call
+// may be written into input_buf, filter_buf, or output_buf!
+//
+// At the moment we only support cudnn convolutions over floats.
+//
+// We provide one overload which takes a scratch buffer, and another which takes
+// an allocator which is responsible for allocating the scratch space.  In
+// theory the second one shouldn't be necessary -- users of this function could
+// just ask cudnn how much scratch space it needs for a particular convolution.
+// But in practice, StreamExecutor does not expose such an API, and in the name
+// of parsimony, perhaps it's better not to add it.  Instead, the first time you
+// call a convolution, you should call the version that takes a scratch
+// allocator and take note of how much memory is used.  The next time you call
+// the same conv, you can provide an explicitly preallocated scratch buffer of
+// that size, if you like.
+Status RunCudnnConvolution(
+    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+    const Shape& output_shape,
+    perftools::gputools::DeviceMemory<float> input_buf,
+    perftools::gputools::DeviceMemory<float> filter_buf,
+    perftools::gputools::DeviceMemory<float> output_buf,
+    perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
+    const ConvolutionDimensionNumbers& dnums,
+    perftools::gputools::dnn::AlgorithmConfig algorithm,
+    perftools::gputools::Stream* stream,
+    perftools::gputools::dnn::ProfileResult* profile_result = nullptr);
+
+Status RunCudnnConvolution(
+    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+    const Shape& output_shape,
+    perftools::gputools::DeviceMemory<float> input_buf,
+    perftools::gputools::DeviceMemory<float> filter_buf,
+    perftools::gputools::DeviceMemory<float> output_buf,
+    perftools::gputools::ScratchAllocator* scratch_allocator,
+    const Window& window, const ConvolutionDimensionNumbers& dnums,
+    perftools::gputools::dnn::AlgorithmConfig algorithm,
+    perftools::gputools::Stream* stream,
+    perftools::gputools::dnn::ProfileResult* profile_result = nullptr);
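+
+// A minimal usage sketch for the allocator overload above (a hedged example:
+// `stream`, the shapes, the device buffers, `window`, `dnums`, and
+// `scratch_allocator` are assumed to be set up by the caller, and `algorithm`
+// would normally come from CudnnConvolutionAlgorithmPicker):
+//
+//   TF_RETURN_IF_ERROR(RunCudnnConvolution(
+//       CudnnConvKind::kForward, input_shape, filter_shape, output_shape,
+//       input_buf, filter_buf, output_buf, &scratch_allocator, window, dnums,
+//       algorithm, stream));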
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
index 07543d4..12ec266 100644 (file)
@@ -35,8 +35,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
-#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -127,7 +128,9 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
 }
 
 // Runs optimization passes on the given HLO module.
-tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
+tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
+                                     se::StreamExecutor* stream_exec,
+                                     DeviceMemoryAllocator* device_allocator) {
   {
     HloPassPipeline pipeline("optimization");
     pipeline.AddInvariantChecker<HloVerifier>();
@@ -143,6 +146,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
     // most ops.
     pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
     pipeline.AddPass<DotDecomposer>();
+
     {
       auto& pass =
           pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@@ -173,7 +177,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
       pass.AddPass<ReshapeMover>();
       pass.AddPass<HloConstantFolding>();
     }
-    pipeline.AddPass<ConvolutionFolding>();
+
     pipeline.AddPass<TransposeFolding>(
         [](const HloInstruction& dot,
            const TransposeFolding::OperandIndices& candidate_operands) {
@@ -185,6 +189,58 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
     pipeline.AddPass<HloDCE>();
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
+
+  {
+    // Convert convolutions into CustomCalls to cudnn, then canonicalize them
+    // (PadInsertion).
+    HloPassPipeline pipeline("conv_canonicalization");
+    pipeline.AddInvariantChecker<HloVerifier>();
+    pipeline.AddPass<CudnnConvolutionRewriter>();
+    pipeline.AddPass<PadInsertion>();
+
+    // Choose the fastest algorithm for each conv.
+    //
+    // In theory doing this here is way too early: it needs to happen after
+    // layout assignment, because the layout of the inputs/outputs affects the
+    // speed of the conv.  But currently we allow only one input/output
+    // layout when calling cudnn, so there's no ambiguity.
+    //
+    // We pick the algorithm at this early stage so we can generate better HLO.
+    // After CudnnConvolutionRewriter, our convolutions are CustomCalls which
+    // return a tuple (conv_result, scratch_memory), and each conv uses 0
+    // bytes of scratch:
+    //
+    //   customcall = (f32[...], f32[0])
+    //   return gte(customcall, 0)
+    //
+    // The algorithm picker then chooses the best algorithm, and potentially
+    // increases the scratch space.  It replaces customcall with new_tuple,
+    // giving us the following:
+    //
+    //   new_customcall = (f32[...], f32[N])
+    //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
+    //   return gte(new_tuple, 0)
+    //
+    // The new tuple and gte instructions can then be simplified away, because
+    // nobody is expected to use the scratch value.
+    //
+    // However, if we were to run CudnnConvolutionAlgorithmPicker after layout
+    // assignment, fusion would already have run, and the gte(customcall, 0)
+    // would probably already have been fused into a fusion node.  We can't
+    // simplify across HloComputation boundaries, so in this case we wouldn't
+    // be able to simplify away the new_tuple bits.
+    //
+    // We'll need to revisit this if we ever allow multiple layouts for the
+    // inputs/outputs of a cudnn convolution.
+    pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
+                                                      device_allocator);
+    // Clean up new_tuple described above.
+    pipeline.AddPass<TupleSimplifier>();
+    pipeline.AddPass<HloDCE>();
+
+    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+  }
+
   {
     HloPassFix<HloPassPipeline> fusion("fusion");
     fusion.AddInvariantChecker<HloVerifier>();
@@ -212,9 +268,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
 
 // Modifies the given HLO module so that it will be accepted by IrEmitter.
 // Unlike optimization passes, the passes are necessary for correctness.
-tensorflow::Status PrepareHloModuleForIrEmitting(
-    HloModule* hlo_module, se::StreamExecutor* stream_exec,
-    DeviceMemoryAllocator* /*device_allocator*/) {
+tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
   // In some cases, we have to place the result of an instruction in a temporary
   // buffer. For instance, the buffer that holds an external parameter is
   // assumed immutable at this point, and should not be reused for output
@@ -222,9 +276,10 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
   // the parameter.
   HloPassPipeline pipeline("GPU-ir-emit-prepare");
   pipeline.AddInvariantChecker<HloVerifier>();
-  pipeline.AddPass<PadInsertion>();
+
   pipeline.AddPass<GpuLayoutAssignment>(
       hlo_module->mutable_entry_computation_layout());
+
   // The LayoutAssignment pass may leave behind kCopy instructions which are
   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
@@ -417,7 +472,8 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
   Tracing::TraceMe annotation("HLO Transforms", module->name(),
                               /*is_expensive=*/true);
-  TF_RETURN_IF_ERROR(OptimizeHloModule(module.get()));
+  TF_RETURN_IF_ERROR(
+      OptimizeHloModule(module.get(), stream_exec, device_allocator));
   return std::move(module);
 }
 
@@ -428,8 +484,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
 
   TF_RET_CHECK(stream_exec != nullptr);
 
-  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get(), stream_exec,
-                                                   device_allocator));
+  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
 
   llvm::LLVMContext llvm_context;
   std::string buffer;
@@ -464,8 +519,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
                           /*color_alignment=*/[](LogicalBuffer::Color) {
                             return kCudaMallocAlignBytes;
                           }));
-  // BufferAssignment::ToString() includes a header, so no need for us to
-  // print one ourselves.
+  // BufferAssignment::Stats::ToString() and BufferAssignment::ToString()
+  // include headers, so no need for us to print them ourselves.
+  XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString());
   XLA_VLOG_LINES(2, buffer_assignment->ToString());
   XLA_VLOG_LINES(2, module->ToString());
   const string xla_dump_optimized_hlo_proto_to =
index e3b493c..88bf5a7 100644 (file)
@@ -78,6 +78,12 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
       for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
         TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
       }
+    } else if (IsCustomCallToDnnConvolution(*hlo)) {
+      // The last argument to a CUDNN convolution is its algorithm, which must
+      // be an HLO constant -- it shouldn't be copied.
+      for (int64 i = 0; i < hlo->operand_count() - 1; ++i) {
+        TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
+      }
     } else if (ImplementedAsLibraryCall(*hlo)) {
       // For all other library calls, materialize all the operands into memory.
       for (int64 i = 0; i < hlo->operand_count(); ++i) {
index 58915f1..89f1e62 100644 (file)
@@ -28,122 +28,114 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+// cuDNN convolutions are called with specific layouts on the input, output,
+// and filter:
+//
+//   input: DataLayout::kBatchDepthYX
+//   output: DataLayout::kBatchDepthYX
+//   filter: FilterLayout::kOutputInputYX
+//
+// The order of dimensions in the constant name is major-to-minor (e.g., the
+// most-major dimension of the input is batch, most-minor is X). The specific
+// dimension numbers these named dimensions correspond to are determined by
+// the ConvolutionDimensionNumbers argument. Y is spatial
+// dimension 0, and X is spatial dimension 1.
+//
+// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls.
+static Status AddBackendConstraintsToDnnConvCustomCall(
+    HloInstruction* instr, LayoutConstraints* constraints) {
+  CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
+  Shape input_shape;
+  Shape filter_shape;
+  Shape output_shape;
+  const auto& target = instr->custom_call_target();
+  if (target == kCudnnConvForwardCallTarget) {
+    input_shape = instr->operand(0)->shape();
+    filter_shape = instr->operand(1)->shape();
+    output_shape = instr->shape().tuple_shapes(0);
+  } else if (target == kCudnnConvBackwardInputCallTarget) {
+    input_shape = instr->shape().tuple_shapes(0);
+    filter_shape = instr->operand(1)->shape();
+    output_shape = instr->operand(0)->shape();
+  } else if (target == kCudnnConvBackwardFilterCallTarget) {
+    input_shape = instr->operand(0)->shape();
+    filter_shape = instr->shape().tuple_shapes(0);
+    output_shape = instr->operand(1)->shape();
+  } else {
+    LOG(FATAL) << "Unexpected custom call target: "
+               << instr->custom_call_target();
+  }
+
+  // Construct minor-to-major dimension orders for operands and result.
+  // cuDNN's convolution APIs support the BDYX layout for activations/output
+  // and the OIYX layout for weights.
+  // TODO(b/29399649): Be more flexible about handling layouts of cuDNN
+  // calls after we switch to cuDNN v5.
+  const ConvolutionDimensionNumbers& dimension_numbers =
+      instr->convolution_dimension_numbers();
+  std::vector<int64> input_layout;
+  for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0;
+       --i) {
+    input_layout.push_back(dimension_numbers.input_spatial_dimensions(i));
+  }
+  input_layout.push_back(dimension_numbers.input_feature_dimension());
+  input_layout.push_back(dimension_numbers.input_batch_dimension());
+  *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout);
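+  // E.g. with NHWC dimension numbers (batch=0, spatial={1,2}, feature=3),
+  // input_layout is {2, 1, 3, 0} minor-to-major, i.e. BDYX major-to-minor.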
+
+  std::vector<int64> filter_layout;
+  for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0;
+       --i) {
+    filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i));
+  }
+  filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension());
+  filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension());
+  *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout);
+
+  std::vector<int64> output_layout;
+  for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0;
+       --i) {
+    output_layout.push_back(dimension_numbers.output_spatial_dimensions(i));
+  }
+  output_layout.push_back(dimension_numbers.output_feature_dimension());
+  output_layout.push_back(dimension_numbers.output_batch_dimension());
+  *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout);
+
+  // The custom call returns a tuple of (actual_result, scratch_buffer);
+  // call_result_buf is the logical buffer for actual_result, the thing that
+  // contains the result of the conv call.
+  TF_ASSIGN_OR_RETURN(const LogicalBuffer* call_result_buf,
+                      constraints->points_to_analysis().GetBufferDefinedAt(
+                          instr, /*index=*/{0}));
+
+  // Set layouts of the instructions' shapes.
+  if (target == kCudnnConvForwardCallTarget) {
+    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
+    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
+    TF_RETURN_IF_ERROR(
+        constraints->SetBufferLayout(output_shape.layout(), *call_result_buf));
+  } else if (target == kCudnnConvBackwardInputCallTarget) {
+    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0));
+    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
+    TF_RETURN_IF_ERROR(
+        constraints->SetBufferLayout(input_shape.layout(), *call_result_buf));
+  } else if (target == kCudnnConvBackwardFilterCallTarget) {
+    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
+    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1));
+    TF_RETURN_IF_ERROR(
+        constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf));
+  } else {
+    LOG(FATAL) << "Unexpected custom call target: "
+               << instr->custom_call_target();
+  }
+  return Status::OK();
+}
+
 Status GpuLayoutAssignment::AddBackendConstraints(
     LayoutConstraints* constraints) {
   for (auto* instruction : constraints->computation()->instructions()) {
-    // cuDNN is called with specific layouts on the input, output, and filter:
-    //
-    //   input: DataLayout::kBatchDepthYX
-    //   output: DataLayout::kBatchDepthYX
-    //   filter: FilterLayout::kOutputInputYX
-    //
-    // The order dimensions in the constant name is major-to-minor (eg, the
-    // most-major dimension of the input is batch, most-minor is X). The
-    // specific dimension numbers these named dimensions correspond to is
-    // determined by the ConvolutionDimensionNumbers argument. Y is spatial
-    // dimension 0, and X is spatial dimension 1.
-    //
-    // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls.
-    if (ImplementedAsDnnConvolution(*instruction)) {
-      HloInstruction* input = nullptr;
-      HloInstruction* filter = nullptr;
-      HloInstruction* output = nullptr;
-      if (instruction->opcode() == HloOpcode::kConvolution) {
-        input = instruction->mutable_operand(0);
-        filter = instruction->mutable_operand(1);
-        output = instruction;
-      } else {
-        CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
-        switch (instruction->fusion_kind()) {
-          case HloInstruction::FusionKind::kConvBackwardFilter:
-            // filter = BackwardFilterConvolve(input, output)
-            input = instruction->mutable_operand(0);
-            filter = instruction;
-            output = instruction->mutable_operand(1);
-            break;
-          case HloInstruction::FusionKind::kConvBackwardInput:
-            // input = BackwardInputConvolve(output, filter)
-            input = instruction;
-            filter = instruction->mutable_operand(1);
-            output = instruction->mutable_operand(0);
-            break;
-          default:
-            LOG(FATAL) << "Not a convolution-fusion";
-        }
-      }
-
-      // Construct minor-to-major dimension orders for operands and result.
-      // cuDNN's convolution APIs support the BDYX layout for activations/output
-      // and the OIYX layout for weights.
-      // TODO(b/29399649): Be more flexible about handling layouts of cuDNN
-      // calls after we switch to cuDNN v5.
-      const ConvolutionDimensionNumbers& dimension_numbers =
-          instruction->convolution_dimension_numbers();
-      std::vector<int64> input_layout;
-      for (int i = dimension_numbers.input_spatial_dimensions_size() - 1;
-           i >= 0; --i) {
-        input_layout.push_back(dimension_numbers.input_spatial_dimensions(i));
-      }
-      input_layout.push_back(dimension_numbers.input_feature_dimension());
-      input_layout.push_back(dimension_numbers.input_batch_dimension());
-      Shape input_shape(input->shape());
-      *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout);
-
-      std::vector<int64> filter_layout;
-      for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1;
-           i >= 0; --i) {
-        filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i));
-      }
-      filter_layout.push_back(
-          dimension_numbers.kernel_input_feature_dimension());
-      filter_layout.push_back(
-          dimension_numbers.kernel_output_feature_dimension());
-      Shape filter_shape(filter->shape());
-      *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout);
-
-      std::vector<int64> output_layout;
-      for (int i = dimension_numbers.output_spatial_dimensions_size() - 1;
-           i >= 0; --i) {
-        output_layout.push_back(dimension_numbers.output_spatial_dimensions(i));
-      }
-      output_layout.push_back(dimension_numbers.output_feature_dimension());
-      output_layout.push_back(dimension_numbers.output_batch_dimension());
-      Shape output_shape(output->shape());
-      *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout);
-
-      // Set layouts of the instructions' shapes.
-      if (instruction->opcode() == HloOpcode::kConvolution) {
-        TF_RETURN_IF_ERROR(
-            constraints->SetOperandLayout(input_shape, output, 0));
-        TF_RETURN_IF_ERROR(
-            constraints->SetOperandLayout(filter_shape, output, 1));
-        TF_RETURN_IF_ERROR(
-            constraints->SetInstructionLayout(output_shape, output));
-      } else {
-        CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
-        switch (instruction->fusion_kind()) {
-          case HloInstruction::FusionKind::kConvBackwardFilter:
-            // filter = BackwardFilterConvolve(input, output)
-            TF_RETURN_IF_ERROR(
-                constraints->SetOperandLayout(input_shape, filter, 0));
-            TF_RETURN_IF_ERROR(
-                constraints->SetInstructionLayout(filter_shape, filter));
-            TF_RETURN_IF_ERROR(
-                constraints->SetOperandLayout(output_shape, filter, 1));
-            break;
-          case HloInstruction::FusionKind::kConvBackwardInput:
-            // input = BackwardInputConvolve(output, filter)
-            TF_RETURN_IF_ERROR(
-                constraints->SetInstructionLayout(input_shape, input));
-            TF_RETURN_IF_ERROR(
-                constraints->SetOperandLayout(output_shape, input, 0));
-            TF_RETURN_IF_ERROR(
-                constraints->SetOperandLayout(filter_shape, input, 1));
-            break;
-          default:
-            LOG(FATAL) << "Not a convolution-fusion";
-        }
-      }
+    if (IsCustomCallToDnnConvolution(*instruction)) {
+      TF_RETURN_IF_ERROR(
+          AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
     }
   }
   return Status::OK();
@@ -151,9 +143,12 @@ Status GpuLayoutAssignment::AddBackendConstraints(
 
 bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout(
     const HloInstruction* instruction) {
-  // Inputs to cudnn batchnorm custom calls don't need the major-first layout
-  // (i.e. {n, n-1, ...0}) -- we can handle any layout.
-  return !IsCustomCallToDnnBatchNorm(*instruction);
+  // - Inputs to cudnn batchnorm custom calls don't need the major-first layout
+  //   (i.e. {n, n-1, ...0}) -- we can handle any layout.
+  // - Inputs to cudnn convolution custom calls require the specific layouts
+  //   set in AddBackendConstraints.
+  return !IsCustomCallToDnnBatchNorm(*instruction) &&
+         !IsCustomCallToDnnConvolution(*instruction);
 }
 
 Status GpuLayoutAssignment::PropagateOperandConstraint(
index 1d47ffd..2d6dad2 100644 (file)
@@ -137,49 +137,6 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
                    .ValueOrDie());
 }
 
-TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) {
-  HloComputation::Builder builder(TestName());
-  auto input = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input"));
-  auto filter = builder.AddInstruction(HloInstruction::CreateParameter(
-      1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter"));
-
-  Window conv_window;
-  WindowDimension* conv_window_row = conv_window.add_dimensions();
-  conv_window_row->set_size(1);
-  WindowDimension* conv_window_col = conv_window.add_dimensions();
-  conv_window_col->set_size(2);
-  conv_window_col->set_padding_high(1);
-
-  ConvolutionDimensionNumbers conv_dnums;
-  conv_dnums.set_input_batch_dimension(0);
-  conv_dnums.set_output_batch_dimension(0);
-  conv_dnums.set_input_feature_dimension(1);
-  conv_dnums.set_output_feature_dimension(1);
-  conv_dnums.add_input_spatial_dimensions(2);
-  conv_dnums.add_output_spatial_dimensions(2);
-  conv_dnums.add_input_spatial_dimensions(3);
-  conv_dnums.add_output_spatial_dimensions(3);
-  conv_dnums.set_kernel_output_feature_dimension(0);
-  conv_dnums.set_kernel_input_feature_dimension(1);
-  conv_dnums.add_kernel_spatial_dimensions(2);
-  conv_dnums.add_kernel_spatial_dimensions(3);
-
-  auto conv = builder.AddInstruction(
-      HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}),
-                                     input, filter, conv_window, conv_dnums));
-  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0}));
-  builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose));
-
-  auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
-                   .Run(module.get())
-                   .ValueOrDie());
-}
-
 TEST_F(InstructionFusionTest, GetTupleElementFused) {
   HloComputation::Builder builder(TestName());
   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
index 76566a9..2f65edf 100644 (file)
@@ -90,43 +90,6 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
   return false;
 }
 
-bool ImplementedAsDnnConvolution(const HloInstruction& hlo) {
-  // We can only do this if the HLO is unnested.
-  if (hlo.parent() != hlo.GetModule()->entry_computation()) {
-    return false;
-  }
-
-  // Forward convolution.
-  if (hlo.opcode() == HloOpcode::kConvolution) {
-    const ConvolutionDimensionNumbers& dnums =
-        hlo.convolution_dimension_numbers();
-    if (dnums.input_spatial_dimensions_size() > 3) {
-      return false;
-    }
-
-    // CuDNN does not accept zero-element arguments
-    if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) ||
-        ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) {
-      return false;
-    }
-
-    if (window_util::HasWindowReversal(hlo.window())) {
-      return false;
-    }
-
-    return true;
-  }
-
-  // Backward convolution.
-  if (hlo.opcode() == HloOpcode::kFusion &&
-      (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter ||
-       hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) {
-    return true;
-  }
-
-  return false;
-}
-
 const char* const kCudnnBatchNormForwardInferenceCallTarget =
     "__cudnn$batchNormalizationForwardInference";
 const char* const kCudnnBatchNormForwardTrainingCallTarget =
@@ -144,9 +107,76 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
          target == kCudnnBatchNormBackwardCallTarget;
 }
 
+const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
+const char* const kCudnnConvBackwardInputCallTarget =
+    "__cudnn$convBackwardInput";
+const char* const kCudnnConvBackwardFilterCallTarget =
+    "__cudnn$convBackwardFilter";
+
+bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
+  if (hlo.opcode() != HloOpcode::kCustomCall) {
+    return false;
+  }
+  const auto& target = hlo.custom_call_target();
+  return target == kCudnnConvForwardCallTarget ||
+         target == kCudnnConvBackwardInputCallTarget ||
+         target == kCudnnConvBackwardFilterCallTarget;
+}
+
 bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
-  return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo) ||
-         IsCustomCallToDnnBatchNorm(hlo);
+  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
+         IsCustomCallToDnnConvolution(hlo);
+}
+
+static HloInstruction* CreateCudnnConv(
+    const char* call_target, const Shape& shape, HloInstruction* lhs,
+    HloInstruction* rhs, const Window& window,
+    const ConvolutionDimensionNumbers& dnums) {
+  HloComputation* computation = lhs->parent();
+
+  // This call returns a tuple of (conv_result, scratch_memory), where
+  // conv_result is the actual result of the convolution, and scratch_memory is
+  // temporary memory used by cudnn.
+  //
+  // At the moment, we don't know how much scratch memory this conv is going to
+  // use, so we put a u8[0] placeholder here.  Later on, another pass will
+  // choose which conv algorithm to use, and at that point we'll modify the
+  // shape of this second tuple element.
+  Shape call_shape =
+      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+  // Our CustomCall takes three arguments: the conv lhs and rhs, and the cudnn
+  // algorithm to use.  It's up to a later pass to choose the algorithm, so to
+  // indicate that we haven't yet made a choice, we specify -1 for that arg.
+  HloInstruction* negative_one = computation->AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<int64>(-1)));
+  HloInstruction* custom_call =
+      computation->AddInstruction(HloInstruction::CreateCustomCall(
+          call_shape, {lhs, rhs, negative_one}, call_target));
+  custom_call->set_window(window);
+  custom_call->set_convolution_dimension_numbers(dnums);
+  return custom_call;
+}
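
A usage sketch (hypothetical caller; `shape`, `input`, `kernel`, `window`, and `dnums` are assumed to be in scope): because the call returns a (conv_result, scratch_memory) tuple, a caller that wants the array result follows it with a get-tuple-element, as the pad-insertion changes later in this diff do:

    HloInstruction* conv_call =
        CreateCudnnConvForward(shape, input, kernel, window, dnums);
    // Element 0 of the returned tuple is the convolution result proper.
    HloInstruction* conv_result = conv_call->parent()->AddInstruction(
        HloInstruction::CreateGetTupleElement(
            conv_call->shape().tuple_shapes(0), conv_call, 0));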
+
+HloInstruction* CreateCudnnConvForward(
+    const Shape& shape, HloInstruction* input, HloInstruction* kernel,
+    const Window& window, const ConvolutionDimensionNumbers& dnums) {
+  return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
+                         window, dnums);
+}
+
+HloInstruction* CreateCudnnConvBackwardInput(
+    const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
+    const Window& window, const ConvolutionDimensionNumbers& dnums) {
+  return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
+                         reverse_filter, window, dnums);
+}
+
+HloInstruction* CreateCudnnConvBackwardFilter(
+    const Shape& shape, HloInstruction* input, HloInstruction* output,
+    const Window& window, const ConvolutionDimensionNumbers& dnums) {
+  return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
+                         output, window, dnums);
 }
 
 bool IsReductionToVector(const HloInstruction& reduce) {
index d24ed98..7ad9680 100644 (file)
@@ -22,6 +22,9 @@ limitations under the License.
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 
+// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
+// don't belong in "ir_emission_utils".
+
 namespace xla {
 namespace gpu {
 
@@ -30,9 +33,6 @@ constexpr int64 kWarpSize = 32;
 // Returns true if `hlo` will be implemented as a call to BLAS gemm.
 bool ImplementedAsGemm(const HloInstruction& hlo);
 
-// Returns true if `hlo` will be implemented as a call to cuDNN convolution.
-bool ImplementedAsDnnConvolution(const HloInstruction& hlo);
-
 // A call to cuDNN for batch normalization is represented as CustomCall HLO with
 // a call target equal to one of these strings.
 //
@@ -58,6 +58,60 @@ extern const char* const kCudnnBatchNormBackwardCallTarget;
 // sequence of generic HLOs or to a cuDNN CustomCall.
 bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
 
+// A call to cuDNN for convolution (forward, backward filter, or backward input)
+// is represented as a CustomCall HLO with a call target equal to one of these
+// strings.
+//
+// These CustomCalls have window() and convolution_dimension_numbers() set like
+// regular convolution ops.  They have the same LHS and RHS operands, plus one
+// additional int64 operand, representing which cudnn algorithm to run.  This
+// operand must be an HLO constant.  A value of -1 means that the implementation
+// is free to choose the best algorithm it can.
+//
+// These calls output a tuple (conv_result, scratch_memory), where conv_result
+// is the actual result of the convolution, and scratch_memory is temporary
+// memory used by cudnn.  Callers shouldn't inspect scratch_memory, as its value
+// is not well-defined.
+//
+// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls.
+// When it does so, it chooses algorithm -1 and 0 bytes of scratch space.  Later
+// on in the pipeline, CudnnConvolutionAlgorithmPicker chooses an explicit
+// algorithm for each conv and sets the amount of scratch space needed.
+//
+// (Representing the scratch memory as an output may seem strange at first, but
+// it's quite sensible, from a certain point of view.  The scratch buffer is a
+// location in memory that the conv can write into, but which it can't legally
+// read from, at least until it's written something first.  But that's exactly
+// the definition of an output buffer.)
+extern const char* const kCudnnConvForwardCallTarget;
+extern const char* const kCudnnConvBackwardInputCallTarget;
+extern const char* const kCudnnConvBackwardFilterCallTarget;
+
+// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
+// routine.
+//
+// This returns true if `hlo` is a CustomCall HLO with a call target equal to
+// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a
+// kConvolution opcode.
+bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
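
To make the lifecycle above concrete, a hedged sketch of what the algorithm-picking pass conceptually does to one of these calls: pick an algorithm, widen the scratch shape, and repackage the result so the replacement keeps the original (conv_result, u8[0]) shape. This illustrates the convention only; it is not the CudnnConvolutionAlgorithmPicker implementation, and the function name and arguments are assumptions:

    Status AssignAlgorithm(HloInstruction* conv, int64 algorithm,
                           int64 scratch_bytes) {
      HloComputation* computation = conv->parent();
      Shape new_call_shape = ShapeUtil::MakeTupleShape(
          {conv->shape().tuple_shapes(0),
           ShapeUtil::MakeShape(U8, {scratch_bytes})});
      HloInstruction* algorithm_hlo = computation->AddInstruction(
          HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
      HloInstruction* new_call =
          computation->AddInstruction(HloInstruction::CreateCustomCall(
              new_call_shape,
              {conv->mutable_operand(0), conv->mutable_operand(1),
               algorithm_hlo},
              conv->custom_call_target()));
      new_call->set_window(conv->window());
      new_call->set_convolution_dimension_numbers(
          conv->convolution_dimension_numbers());
      // ReplaceInstruction requires compatible shapes, so wrap new_call in a
      // tuple whose second element is again u8[0]; only new_call itself sees
      // the enlarged scratch buffer.
      HloInstruction* repackaged =
          computation->AddInstruction(HloInstruction::CreateTuple(
              {computation->AddInstruction(
                   HloInstruction::CreateGetTupleElement(
                       new_call_shape.tuple_shapes(0), new_call, 0)),
               computation->AddInstruction(HloInstruction::CreateConstant(
                   Literal::CreateR1<uint8>({})))}));
      return computation->ReplaceInstruction(conv, repackaged);
    }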
+
+// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv.
+// Note that these CustomCalls return a tuple (conv_result, scratch_memory).  If
+// you want just the conv result, you'll need to get-tuple-element the value
+// returned by this function.
+//
+// The created cudnn call will use algorithm -1 (i.e. no explicit algorithm
+// choice; see above) and no scratch space.
+HloInstruction* CreateCudnnConvForward(
+    const Shape& shape, HloInstruction* input, HloInstruction* kernel,
+    const Window& window, const ConvolutionDimensionNumbers& dnums);
+HloInstruction* CreateCudnnConvBackwardInput(
+    const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
+    const Window& window, const ConvolutionDimensionNumbers& dnums);
+HloInstruction* CreateCudnnConvBackwardFilter(
+    const Shape& shape, HloInstruction* input, HloInstruction* output,
+    const Window& window, const ConvolutionDimensionNumbers& dnums);
+
 // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
 // or cuDNN convolution.
 bool ImplementedAsLibraryCall(const HloInstruction& hlo);
index 3aa1784..9031a83 100644 (file)
@@ -336,9 +336,6 @@ class IrEmitterUnnested : public IrEmitter {
   // Thunk object.
   std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
 
-  // Returns a ConvolutionThunk that calls DNN to implement `inst`.
-  std::unique_ptr<Thunk> BuildConvolutionThunk(const HloInstruction* inst);
-
   // Returns a FftThunk that calls cuFFT to implement `inst`.
   std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
 
index bd428f8..4072573 100644 (file)
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
@@ -278,10 +279,6 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
 }
 
 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
-  if (ImplementedAsDnnConvolution(*convolution)) {
-    thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution));
-    return Status::OK();
-  }
   thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
   return IrEmitter::HandleConvolution(convolution);
 }
@@ -380,6 +377,71 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
     return Status::OK();
   }
 
+  if (IsCustomCallToDnnConvolution(*custom_call)) {
+    const auto& assn = ir_emitter_context_->buffer_assignment();
+    const auto& lhs_shape = custom_call->operand(0)->shape();
+    const auto& rhs_shape = custom_call->operand(1)->shape();
+    const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+    auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
+    auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+    auto tuple_result_slice = GetAllocationSlice(*custom_call);
+    auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
+    auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
+
+    const HloInstruction* algorithm_inst = custom_call->operand(2);
+    CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
+    int64 algorithm = algorithm_inst->literal().Get<int64>({});
+
+    const auto& target = custom_call->custom_call_target();
+    std::unique_ptr<ConvolutionThunk> thunk;
+    if (target == kCudnnConvForwardCallTarget) {
+      thunk = MakeUnique<ConvolutionThunk>(
+          CudnnConvKind::kForward,
+          /*input_buffer=*/lhs_slice,
+          /*filter_buffer=*/rhs_slice,
+          /*output_buffer=*/conv_result_slice,
+          /*tuple_result_buffer=*/tuple_result_slice,
+          /*scratch_buffer=*/scratch_slice,
+          /*input_shape=*/lhs_shape,
+          /*filter_shape=*/rhs_shape,
+          /*output_shape=*/conv_result_shape,  //
+          custom_call->window(), custom_call->convolution_dimension_numbers(),
+          algorithm, custom_call);
+    } else if (target == kCudnnConvBackwardInputCallTarget) {
+      thunk = MakeUnique<ConvolutionThunk>(
+          CudnnConvKind::kBackwardInput,
+          /*input_buffer=*/conv_result_slice,
+          /*filter_buffer=*/rhs_slice,
+          /*output_buffer=*/lhs_slice,
+          /*tuple_result_buffer=*/tuple_result_slice,
+          /*scratch_buffer=*/scratch_slice,
+          /*input_shape=*/conv_result_shape,
+          /*filter_shape=*/rhs_shape,
+          /*output_shape=*/lhs_shape,  //
+          custom_call->window(), custom_call->convolution_dimension_numbers(),
+          algorithm, custom_call);
+    } else if (target == kCudnnConvBackwardFilterCallTarget) {
+      thunk = MakeUnique<ConvolutionThunk>(
+          CudnnConvKind::kBackwardFilter,
+          /*input_buffer=*/lhs_slice,
+          /*filter_buffer=*/conv_result_slice,
+          /*output_buffer=*/rhs_slice,
+          /*tuple_result_buffer=*/tuple_result_slice,
+          /*scratch_buffer=*/scratch_slice,
+          /*input_shape=*/lhs_shape,
+          /*filter_shape=*/conv_result_shape,
+          /*output_shape=*/rhs_shape,  //
+          custom_call->window(), custom_call->convolution_dimension_numbers(),
+          algorithm, custom_call);
+    } else {
+      LOG(FATAL) << "Unexpected custom call target: "
+                 << custom_call->custom_call_target();
+    }
+
+    thunk_sequence_->emplace_back(std::move(thunk));
+    return Status::OK();
+  }
+
   return IrEmitter::HandleCustomCall(custom_call);
 }
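
In summary (restating the dispatch above, not new behavior), the three call targets map the custom call's lhs, rhs, and tuple element 0 onto ConvolutionThunk's logical cudnn buffers as follows:

    kForward:         input = lhs,     filter = rhs,     output = result
    kBackwardInput:   input = result,  filter = rhs,     output = lhs
    kBackwardFilter:  input = lhs,     filter = result,  output = rhs

In the backward cases the thunk's "output" is the gradient being computed, and the conv result slice plays the role of cudnn's input or filter.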
 
@@ -500,10 +562,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
     return Status::OK();
   }
-  if (ImplementedAsDnnConvolution(*fusion)) {
-    thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion));
-    return Status::OK();
-  }
   thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
   return IrEmitter::HandleFusion(fusion);
 }
@@ -2011,52 +2069,6 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
   LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
 }
 
-std::unique_ptr<Thunk> IrEmitterUnnested::BuildConvolutionThunk(
-    const HloInstruction* inst) {
-  const HloInstruction* lhs = inst->operand(0);
-  const HloInstruction* rhs = inst->operand(1);
-  if (inst->opcode() == HloOpcode::kConvolution) {
-    // Forward convolution.
-    return MakeUnique<ConvolutionThunk>(
-        ConvolutionThunk::ConvolutionKind::kForward,
-        /*input_buffer=*/GetAllocationSlice(*lhs),
-        /*filter_buffer=*/GetAllocationSlice(*rhs),
-        /*output_buffer=*/GetAllocationSlice(*inst),
-        /*input_shape=*/lhs->shape(),
-        /*filter_shape=*/rhs->shape(),
-        /*output_shape=*/inst->shape(), inst->window(),
-        inst->convolution_dimension_numbers(), inst);
-  }
-
-  // Backward filter convolution, which takes the input (activations) and the
-  // gradients, and computes the filter.
-  CHECK_EQ(HloOpcode::kFusion, inst->opcode());
-  switch (inst->fusion_kind()) {
-    case HloInstruction::FusionKind::kConvBackwardFilter:
-      return MakeUnique<ConvolutionThunk>(
-          ConvolutionThunk::ConvolutionKind::kBackwardFilter,
-          /*input_buffer=*/GetAllocationSlice(*lhs),
-          /*filter_buffer=*/GetAllocationSlice(*inst),
-          /*output_buffer=*/GetAllocationSlice(*rhs),
-          /*input_shape=*/lhs->shape(),
-          /*filter_shape=*/inst->shape(),
-          /*output_shape=*/rhs->shape(), inst->window(),
-          inst->convolution_dimension_numbers(), inst);
-    case HloInstruction::FusionKind::kConvBackwardInput:
-      return MakeUnique<ConvolutionThunk>(
-          ConvolutionThunk::ConvolutionKind::kBackwardInput,
-          /*input_buffer=*/GetAllocationSlice(*inst),
-          /*filter_buffer=*/GetAllocationSlice(*rhs),
-          /*output_buffer=*/GetAllocationSlice(*lhs),
-          /*input_shape=*/inst->shape(),
-          /*filter_shape=*/rhs->shape(),
-          /*output_shape=*/lhs->shape(), inst->window(),
-          inst->convolution_dimension_numbers(), inst);
-    default:
-      LOG(FATAL) << "Not a convolution-fusion";
-  }
-}
-
 std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
     const HloInstruction* inst) {
   const HloInstruction* operand = inst->operand(0);
index 2923a79..25846dc 100644 (file)
@@ -27,7 +27,7 @@ namespace gpu {
 
 namespace {
 bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
-  CHECK_EQ(HloOpcode::kConvolution, conv.opcode());
+  CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
   return window_util::HasSymmetricPadding(conv.window()) &&
          !window_util::HasNegativePadding(conv.window()) &&
          !window_util::HasDilation(conv.window());
@@ -47,6 +47,12 @@ HloInstruction* MaybePaddedAndSlicedInput(
       window_util::HasBaseDilation(conv_window)) {
     // If padding is uneven or has dilation, we insert a kPad instruction that
     // applies positive padding and dilation.
+    //
+    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
+    // moving all the padding into an explicit pad op, we should keep as much
+    // padding inside of cudnn as possible, on the assumption that padding
+    // within cudnn is basically free, whereas a kPad's cost increases as the
+    // amount of padding increases.
     PaddingConfig padding_config =
         MakeNoPaddingConfig(input->shape().dimensions_size());
     for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
@@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
     dim->set_window_dilation(1);
   }
 
+  // The conv CustomCall returns a tuple (conv_result, scratch_buffer).  Extract
+  // out the shape of conv_result.
+  Shape old_conv_shape = conv->shape().tuple_shapes(0);
+
   VLOG(1) << "Canonicalizing forward conv";
-  auto new_conv = HloInstruction::CreateConvolve(
-      conv->shape(), new_input, new_kernel, new_conv_window,
-      conv->convolution_dimension_numbers());
+  auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
+                                         new_conv_window,
+                                         conv->convolution_dimension_numbers());
   VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
           << new_conv->ToString();
-  TF_CHECK_OK(
-      conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv)));
+  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
   return true;
 }
 
@@ -190,6 +199,8 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
 
 bool PadInsertion::CanonicalizeBackwardFilterConvolution(
     HloInstruction* backward_conv) {
+  CHECK_EQ(backward_conv->custom_call_target(),
+           kCudnnConvBackwardFilterCallTarget);
   if (window_util::HasSymmetricPadding(backward_conv->window())) {
     return false;
   }
@@ -202,15 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
   //   ABCD0 = Pad(ABCD, padding_high=1)
 //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
   // We choose the lesser of padding_low and padding_high as the new padding.
-  HloInstruction* forward_conv = backward_conv->fused_expression_root();
   HloInstruction* input = backward_conv->mutable_operand(0);
-  Window new_forward_conv_window = forward_conv->window();
   Window new_backward_conv_window = backward_conv->window();
   // input_padding_config is the config of the kPad to be inserted.
   PaddingConfig input_padding_config =
       MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
-  ConvolutionDimensionNumbers forward_conv_dnums =
-      forward_conv->convolution_dimension_numbers();
   ConvolutionDimensionNumbers backward_conv_dnums =
       backward_conv->convolution_dimension_numbers();
   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
@@ -222,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
       // cuDNN convolution (which doesn't support negative padding) to fail.
       return false;
     }
-    // If the backward convolution has uneven padding on the activations, we
-    // move some padding on the larger end to "internal" padding, so that the
-    // backward convolution produces larger weight gradients which get sliced
-    // later. Therefore, the amount of new padding (low or high) is the minimum
-    // of the amount of old padding low and old padding high.
+    // Compute the new, even padding for the backward conv operation.
     int64 new_conv_padding = std::min(padding_low, padding_high);
     int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
     input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
@@ -237,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
     // Since we move some padding from the backward convolution to the kPad, we
     // need to accordingly reduce the padding amount of the backward convolution
     // and its inner forward convolution.
-    IncreasePaddingLowBy(-(padding_low - new_conv_padding),
-                         new_backward_conv_window.mutable_dimensions(i));
-    IncreasePaddingHighBy(-(padding_high - new_conv_padding),
-                          new_backward_conv_window.mutable_dimensions(i));
-    IncreasePaddingLowBy(-(padding_low - new_conv_padding),
-                         new_forward_conv_window.mutable_dimensions(i));
-    IncreasePaddingHighBy(-(padding_high - new_conv_padding),
-                          new_forward_conv_window.mutable_dimensions(i));
+    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
+    new_dim->set_padding_low(new_conv_padding);
+    new_dim->set_padding_high(new_conv_padding);
   }
 
   // Create a new backward convolution replacing the old one.
@@ -260,19 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
               .ConsumeValueOrDie(),
           input, padding, input_padding_config));
 
-  HloInstruction* new_forward_conv =
-      computation->AddInstruction(HloInstruction::CreateConvolve(
-          ShapeInference::InferConvolveShape(
-              padded_input->shape(), output->shape(), new_forward_conv_window,
-              forward_conv_dnums)
-              .ConsumeValueOrDie(),
-          padded_input, output, new_forward_conv_window, forward_conv_dnums));
-
-  // Fuse the new forward convolution to the new backward convolution.
-  HloInstruction* new_backward_conv =
-      computation->CreateFusionInstructionForBackwardConvolution(
-          {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter,
-          new_backward_conv_window, backward_conv_dnums);
+  // The shape of the backward_conv CustomCall is a tuple (conv_result,
+  // scratch_buffer).  Extract out the shape of conv_result.
+  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+  HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
+      backward_conv_shape, padded_input, output, new_backward_conv_window,
+      backward_conv_dnums);
 
   VLOG(1) << "Canonicalizing backward filter conv";
   VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
@@ -289,14 +280,15 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
     return false;
   }
 
-  HloInstruction* forward_conv = backward_conv->fused_expression_root();
-  HloInstruction* reverse_filter = forward_conv->mutable_operand(1);
-  Window new_forward_conv_window = forward_conv->window();
   Window new_backward_conv_window = backward_conv->window();
-  ConvolutionDimensionNumbers forward_conv_dnums =
-      forward_conv->convolution_dimension_numbers();
   ConvolutionDimensionNumbers backward_conv_dnums =
       backward_conv->convolution_dimension_numbers();
+
+  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
+  // Get the shape of conv_result.
+  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+
+  Shape new_backward_conv_shape = backward_conv_shape;
   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
     int64 padding_low = backward_conv->window().dimensions(i).padding_low();
     int64 padding_high = backward_conv->window().dimensions(i).padding_high();
@@ -315,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
     // where the amount of padding low is larger, we can canonicalize it to
     //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
     //   [A] = Slice([B A])
-    // For consistency, we need to increase the low padding of the inner
-    // convolution by 1 as well because the input is larger now.
     if (padding_low > padding_high) {
       IncreasePaddingLowBy(padding_high - padding_low,
                            new_backward_conv_window.mutable_dimensions(i));
-      IncreasePaddingLowBy(padding_low - padding_high,
-                           new_forward_conv_window.mutable_dimensions(i));
     } else if (padding_low < padding_high) {
       IncreasePaddingHighBy(padding_low - padding_high,
                             new_backward_conv_window.mutable_dimensions(i));
-      IncreasePaddingHighBy(padding_high - padding_low,
-                            new_forward_conv_window.mutable_dimensions(i));
     }
+    // Decreasing the padding by X *increases* the size of our output by X.
+    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
+    new_backward_conv_shape.set_dimensions(
+        dim, new_backward_conv_shape.dimensions(dim) +
+                 std::abs(padding_low - padding_high));
   }
 
   // Create a new backward convolution replacing the old one.
   HloComputation* computation = backward_conv->parent();
   HloInstruction* output = backward_conv->mutable_operand(0);
   HloInstruction* filter = backward_conv->mutable_operand(1);
-  HloInstruction* new_reverse_filter =
-      computation->AddInstruction(HloInstruction::CreateReverse(
-          filter->shape(), filter, reverse_filter->dimensions()));
-  HloInstruction* new_forward_conv =
-      computation->AddInstruction(HloInstruction::CreateConvolve(
-          ShapeInference::InferConvolveShape(
-              output->shape(), new_reverse_filter->shape(),
-              new_forward_conv_window, forward_conv_dnums)
-              .ConsumeValueOrDie(),
-          output, new_reverse_filter, new_forward_conv_window,
-          forward_conv_dnums));
+
+  HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
+      new_backward_conv_shape, output, filter, new_backward_conv_window,
+      backward_conv_dnums);
+
+  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
+  // Extract out the two elements.
   HloInstruction* new_backward_conv =
-      computation->CreateFusionInstructionForBackwardConvolution(
-          {new_forward_conv, new_reverse_filter},
-          HloInstruction::FusionKind::kConvBackwardInput,
-          new_backward_conv_window, backward_conv_dnums);
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_backward_conv_shape, new_backward_conv_call, 0));
+  HloInstruction* new_backward_conv_scratch =
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_backward_conv_call->shape().tuple_shapes(1),
+          new_backward_conv_call, 1));
 
   // Slice the new backward convolution.
   //
@@ -377,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
   }
 
   // Replace the old backward convolution with the slice.
-  CHECK(ShapeUtil::Compatible(
+  Shape slice_shape =
       ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                       limit_indices, strides)
-          .ConsumeValueOrDie(),
-      backward_conv->shape()));
+          .ConsumeValueOrDie();
+  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
+      << ShapeUtil::HumanString(slice_shape) << " vs "
+      << ShapeUtil::HumanString(backward_conv_shape);
 
-  auto slice =
-      HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
-                                  start_indices, limit_indices, strides);
+  HloInstruction* slice = computation->AddInstruction(
+      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
+                                  start_indices, limit_indices, strides));
+  HloInstruction* new_tuple = computation->AddInstruction(
+      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
 
   VLOG(1) << "Canonicalizing backward input conv";
   VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
-          << slice->ToString();
+          << new_tuple->ToString();
 
-  TF_CHECK_OK(
-      computation->ReplaceWithNewInstruction(backward_conv, std::move(slice)));
+  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
   return true;
 }
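
A small worked example of the window arithmetic above, with illustrative numbers: for one spatial dimension with padding_low=2 and padding_high=1, the canonicalized backward-input conv pads symmetrically and the trailing slice removes the extra output:

    #include <cstdint>
    #include <cstdlib>
    #include <iostream>

    int main() {
      int64_t padding_low = 2, padding_high = 1;
      // padding_low > padding_high, so IncreasePaddingLowBy(high - low) lowers
      // the low padding until it matches the high padding: (1, 1).
      int64_t new_low = padding_low + (padding_high - padding_low);  // == 1
      int64_t new_high = padding_high;                               // == 1
      // Decreasing the padding by X *increases* the conv output by X, so this
      // output spatial dim grows by |low - high| = 1, and the Slice created
      // above trims that element to recover the original result shape.
      int64_t output_growth = std::abs(padding_low - padding_high);  // == 1
      std::cout << new_low << " " << new_high << " " << output_growth << "\n";
      return 0;
    }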
 
@@ -400,18 +392,17 @@ StatusOr<bool> PadInsertion::Run(HloModule* module) {
   bool changed = false;
   for (HloInstruction* instruction :
        module->entry_computation()->MakeInstructionPostOrder()) {
-    if (instruction->opcode() == HloOpcode::kConvolution) {
-      changed |= CanonicalizeForwardConvolution(instruction);
-    } else if (instruction->opcode() == HloOpcode::kFusion) {
-      switch (instruction->fusion_kind()) {
-        case HloInstruction::FusionKind::kConvBackwardFilter:
-          changed |= CanonicalizeBackwardFilterConvolution(instruction);
-          break;
-        case HloInstruction::FusionKind::kConvBackwardInput:
-          changed |= CanonicalizeBackwardInputConvolution(instruction);
-          break;
-        default:
-          break;
+    if (IsCustomCallToDnnConvolution(*instruction)) {
+      const auto& target = instruction->custom_call_target();
+      if (target == kCudnnConvForwardCallTarget) {
+        changed |= CanonicalizeForwardConvolution(instruction);
+      } else if (target == kCudnnConvBackwardFilterCallTarget) {
+        changed |= CanonicalizeBackwardFilterConvolution(instruction);
+      } else if (target == kCudnnConvBackwardInputCallTarget) {
+        changed |= CanonicalizeBackwardInputConvolution(instruction);
+      } else {
+        LOG(FATAL) << "Unknown custom call target for cudnn conv: "
+                   << instruction->ToString();
       }
     }
   }
index a63affa..5432419 100644 (file)
@@ -461,20 +461,6 @@ HloInstruction* HloComputation::CreateFusionInstruction(
   return fusion_instruction;
 }
 
-HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution(
-    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
-    HloInstruction::FusionKind fusion_kind, const Window& window,
-    const ConvolutionDimensionNumbers& conv_dnums) {
-  CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind ||
-        HloInstruction::FusionKind::kConvBackwardInput == fusion_kind);
-  HloInstruction* root = instructions_to_fuse.front();
-  HloInstruction* fusion_instruction =
-      AddInstruction(HloInstruction::CreateFusionForBackwardConvolution(
-          root->shape(), fusion_kind, window, conv_dnums, root));
-  FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
-  return fusion_instruction;
-}
-
 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
     ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
@@ -577,8 +563,11 @@ Status HloComputation::ReplaceWithNewInstruction(
 
 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                           HloInstruction* new_instruction) {
-  TF_RET_CHECK(ShapeUtil::Compatible(old_instruction->shape(),
-                                     new_instruction->shape()));
+  TF_RET_CHECK(
+      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
+      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
+      << ShapeUtil::HumanString(new_instruction->shape());
+
   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
            << new_instruction->ToString();
   // Try to add metadata for HLO instructions that are created to replace
index 6436815..061c59a 100644 (file)
@@ -224,15 +224,6 @@ class HloComputation {
       tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
       HloInstruction::FusionKind fusion_kind);
 
-  // Creates a fusion instruction that represents a backward convolution. This
-  // is similar to CreateFusionInstruction but takes window and conv_dnums which
-  // indicate the window and convolution dimension numbers of the backward
-  // convolution.
-  HloInstruction* CreateFusionInstructionForBackwardConvolution(
-      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
-      HloInstruction::FusionKind fusion_kind, const Window& window,
-      const ConvolutionDimensionNumbers& conv_dnums);
-
   // Create a deep copy of the given instruction and return the instruction
   // producing the copied result. All instructions performing the copy are added
   // to the computation. For array-shaped values, this method trivially returns
index cd54eb7..9cd5a1e 100644 (file)
@@ -469,7 +469,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
 }
 
 Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
-  return Unimplemented("Custom-call is not implemented for HLO cost analysis.");
+  // We can't do anything sane with CustomCalls, since we don't know what they
+  // do, and returning an error status will stop iteration over this
+  // computation, which is probably also not what we want.  So just punt and
+  // return OK.  This will cause all of the properties to be reported as 0,
+  // which is fine.
+  current_should_compute_bottleneck_time_ = false;
+  return Status::OK();
 }
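
A hedged usage sketch of the new behavior (`module` and the shape-size lambda are assumptions, not code from this change): analyzing a computation that contains a CustomCall now completes, with the call contributing zero to every property, instead of aborting the traversal with Unimplemented:

    // Sketch only.  HloCostAnalysis takes a shape-size function; any custom
    // call visited below now counts as 0 flops / 0 bytes instead of an error.
    HloCostAnalysis analysis([](const Shape& shape) {
      return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
    });
    TF_CHECK_OK(module->entry_computation()->Accept(&analysis));
    VLOG(1) << "flops: " << analysis.flop_count()
            << ", bytes: " << analysis.bytes_accessed();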
 
 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
index a889c35..fac6b43 100644 (file)
@@ -763,16 +763,13 @@ HloInstruction::CreateBroadcastSequence(
   return instruction;
 }
 
-// We put the fusion kind into the instruction's name for transpose-dot and
-// backward-conv fusions, since those fusions are really just describing a type
-// of dot/conv rather than generating a novel computation.
+// We put the fusion kind into the instruction's name for transpose-dot fusions,
+// since those fusions are really just describing a type of dot rather than
+// generating a novel computation.
 static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
   switch (fusion_kind) {
     case HloInstruction::FusionKind::kTransposeDot:
       return "dot_fusion";
-    case HloInstruction::FusionKind::kConvBackwardInput:
-    case HloInstruction::FusionKind::kConvBackwardFilter:
-      return "conv_fusion";
     default:
       return "fusion";
   }
@@ -804,18 +801,6 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
   return instruction;
 }
 
-/* static */ std::unique_ptr<HloInstruction>
-HloInstruction::CreateFusionForBackwardConvolution(
-    const Shape& shape, FusionKind fusion_kind, const Window& window,
-    const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) {
-  std::unique_ptr<HloInstruction> fusion =
-      CreateFusion(shape, fusion_kind, fused_root);
-  fusion->window_ = MakeUnique<Window>(window);
-  fusion->convolution_dimension_numbers_ =
-      MakeUnique<ConvolutionDimensionNumbers>(conv_dnums);
-  return fusion;
-}
-
 void HloInstruction::MergeFusionInstruction(
     HloInstruction* instruction_to_merge) {
   CHECK_EQ(opcode_, HloOpcode::kFusion);
@@ -2318,7 +2303,7 @@ string HloInstruction::ToCategory() const {
     return "data formatting";
   }
 
-  auto conv_category = [&] {
+  if (opcode() == HloOpcode::kConvolution) {
     string category = "convolution";
     if (window_util::HasBaseDilation(window())) {
       category += " base-dilated";
@@ -2327,10 +2312,6 @@ string HloInstruction::ToCategory() const {
       category += " window-dilated";
     }
     return category;
-  };
-
-  if (opcode() == HloOpcode::kConvolution) {
-    return conv_category();
   }
 
   // Give transpose-dot and backwards-conv fusions the categories "dot" and
@@ -2348,9 +2329,6 @@ string HloInstruction::ToCategory() const {
         return "output fusion";
       case FusionKind::kTransposeDot:
         return "dot";
-      case FusionKind::kConvBackwardFilter:
-      case FusionKind::kConvBackwardInput:
-        return conv_category();
       case FusionKind::kCustom:
         return "custom fusion";
     }
@@ -3125,10 +3103,6 @@ string ToString(HloInstruction::FusionKind kind) {
       return "kOutput";
     case HloInstruction::FusionKind::kTransposeDot:
       return "kTransposeDot";
-    case HloInstruction::FusionKind::kConvBackwardFilter:
-      return "kConvBackwardFilter";
-    case HloInstruction::FusionKind::kConvBackwardInput:
-      return "kConvBackwardInput";
     case HloInstruction::FusionKind::kCustom:
       return "kCustom";
   }
@@ -3148,12 +3122,6 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   if (kind_name == "kTransposeDot") {
     return HloInstruction::FusionKind::kTransposeDot;
   }
-  if (kind_name == "kConvBackwardFilter") {
-    return HloInstruction::FusionKind::kConvBackwardFilter;
-  }
-  if (kind_name == "kConvBackwardInput") {
-    return HloInstruction::FusionKind::kConvBackwardInput;
-  }
   if (kind_name == "kCustom") {
     return HloInstruction::FusionKind::kCustom;
   }
@@ -3261,7 +3229,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
   result += "_";
   append_dims(rhs_dims, operand(1)->shape());
   result += "->";
-  append_dims(output_dims, shape());
+
+  // A convolution can be represented as a kConvolution HLO or as a CustomCall
+  // that returns a tuple, the first element of which is the result of the
+  // convolution.
+  Shape this_shape =
+      ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape();
+  append_dims(output_dims, this_shape);
   return result;
 }
 
index 5e89dc7..84b4696 100644 (file)
@@ -162,17 +162,14 @@ class HloPrintOptions {
 class HloInstruction {
  public:
   enum class FusionKind {
-    kLoop,                // Fused into a loop.
-    kInput,               // Op's input is fused into the op itself.
-    kOutput,              // Op's output is fused into the op itself.
-                          // REQUIRES: At least one operand buffer must be able
-                          // to alias the output buffer.
-    kTransposeDot,        // Fused into a dot with transposed operands.
-    kConvBackwardFilter,  // Fused into a backward filter convolution.
-    kConvBackwardInput,   // Fused into a backward input convolution.
-
-    kCustom,  // Custom category for backend-specific fusions that
-              // do not match any of the more specific ones.
+    kLoop,          // Fused into a loop.
+    kInput,         // Op's input is fused into the op itself.
+    kOutput,        // Op's output is fused into the op itself.
+                    // REQUIRES: At least one operand buffer must be able
+                    // to alias the output buffer.
+    kTransposeDot,  // Fused into a dot with transposed operands.
+    kCustom,        // Custom category for backend-specific fusions that
+                    // do not match any of the more specific ones.
   };
 
   ~HloInstruction();
@@ -466,14 +463,6 @@ class HloInstruction {
       tensorflow::gtl::ArraySlice<HloInstruction*> operands,
       HloComputation* fusion_computation);
 
-  // Creates a fusion instruction that represents backward convolution. This is
-  // similar to CreateFusion, but with extra arguments indicating the window and
-  // dimension mapping of the backward convolution.
-  static std::unique_ptr<HloInstruction> CreateFusionForBackwardConvolution(
-      const Shape& shape, FusionKind fusion_kind, const Window& window,
-      const ConvolutionDimensionNumbers& conv_dnums,
-      HloInstruction* fused_root);
-
   // Creates a call instruction that applies the given computation on the given
   // operands. "shape" is the resultant shape.
   static std::unique_ptr<HloInstruction> CreateCall(
@@ -1052,13 +1041,23 @@ class HloInstruction {
     return *padding_config_;
   }
 
-  // Returns data on the dimension numbers used for a convolution
-  // operation.
+  // Returns data on the dimension numbers used for a convolution operation,
+  // which may be a kConvolution instruction or a kCustomCall that implements a
+  // convolution.
   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
     CHECK(convolution_dimension_numbers_ != nullptr);
     return *convolution_dimension_numbers_;
   }
 
+  // Sets the convolution dimension numbers on this instruction.  In general you
+  // shouldn't need to call this; instead, specify the convolution dimension
+  // numbers when you create the instruction.
+  void set_convolution_dimension_numbers(
+      const ConvolutionDimensionNumbers& dnums) {
+    convolution_dimension_numbers_ =
+        MakeUnique<ConvolutionDimensionNumbers>(dnums);
+  }
+
   FftType fft_type() const {
     CHECK_EQ(HloOpcode::kFft, opcode_);
     return fft_type_;