From: Justin Lebar Date: Fri, 2 Feb 2018 05:34:18 +0000 (-0800) Subject: Internal change X-Git-Tag: upstream/v1.7.0~31^2~1070 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3be90a490c31d5a8fad70713e059bbb3e723e664;p=platform%2Fupstream%2Ftensorflow.git Internal change PiperOrigin-RevId: 184239740 --- diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 80c2eed..7df01f7 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -131,6 +131,7 @@ cc_library( "ir_emitter_context.h", ], deps = [ + ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", ":gpu_executable", @@ -262,6 +263,7 @@ cc_library( ], deps = [ ":buffer_allocations", + ":cudnn_convolution_runner", ":infeed_manager", ":ir_emission_utils", ":partition_assignment", @@ -309,9 +311,41 @@ cc_library( ) cc_library( - name = "convolution_folding", - srcs = ["convolution_folding.cc"], - hdrs = ["convolution_folding.h"], + name = "cudnn_convolution_algorithm_picker", + srcs = ["cudnn_convolution_algorithm_picker.cc"], + hdrs = ["cudnn_convolution_algorithm_picker.h"], + deps = [ + ":cudnn_convolution_runner", + ":gpu_executable", + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_runner", + srcs = ["cudnn_convolution_runner.cc"], + hdrs = ["cudnn_convolution_runner.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_rewriter", + srcs = ["cudnn_convolution_rewriter.cc"], + hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -325,15 +359,18 @@ cc_library( ) tf_cc_test( - name = "convolution_folding_test", - srcs = ["convolution_folding_test.cc"], + name = "cudnn_convolution_rewriter_test", + srcs = ["cudnn_convolution_rewriter_test.cc"], deps = [ - ":convolution_folding", + ":cudnn_convolution_rewriter", + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], ) @@ -446,7 +483,8 @@ cc_library( srcs = ["gpu_compiler.cc"], hdrs = ["gpu_compiler.h"], deps = [ - ":convolution_folding", + ":cudnn_convolution_algorithm_picker", + ":cudnn_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 899cc5c..f76f159 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -36,366 +37,69 @@ using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -ConvolveScratchAllocator::ConvolveScratchAllocator( - int device_ordinal, DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - -ConvolveScratchAllocator::~ConvolveScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { - constexpr int64 kConvolveScratchSize = 1LL << 32; // 4GB by default. - return kConvolveScratchSize; -} - -se::port::StatusOr> -ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); - total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); -} - -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind) { - switch (convolution_kind) { - case ConvolutionThunk::ConvolutionKind::kForward: - return "forward"; - case ConvolutionThunk::ConvolutionKind::kBackwardFilter: - return "backward_filter"; - case ConvolutionThunk::ConvolutionKind::kBackwardInput: - return "backward_input"; - } - return "unknown convolution kind"; -} - ConvolutionThunk::ConvolutionThunk( - ConvolutionKind convolution_kind, - const BufferAllocation::Slice& input_buffer, + CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, const Shape& input_shape, + const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo) + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), filter_buffer_(filter_buffer), output_buffer_(output_buffer), + tuple_result_buffer_(tuple_result_buffer), + scratch_buffer_(scratch_buffer), input_shape_(input_shape), filter_shape_(filter_shape), output_shape_(output_shape), window_(window), - dim_nums_(dim_nums) {} + dim_nums_(dim_nums), + algorithm_(algorithm) {} -tensorflow::Status ConvolutionThunk::ExecuteOnStream( +Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { - VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_); - VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }"; - VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }"; - VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }"; - VLOG(3) << "Window: { " << window_.ShortDebugString() << " }"; - - const int num_dimensions = window_.dimensions_size(); - CHECK_LE(num_dimensions, 3); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). - const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window_.dimensions()) { - CHECK_EQ(dim.padding_low(), dim.padding_high()); - } - - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - input_shape_.dimensions(dim_nums_.input_feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) - .set_input_feature_map_count( - filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension())) - .set_output_feature_map_count(filter_shape_.dimensions( - dim_nums_.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).padding_low()) - .set_filter_stride( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - output_shape_.dimensions(dim_nums_.output_feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. - if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor - .set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); se::DeviceMemory filter_data( buffer_allocations.GetDeviceAddress(filter_buffer_)); se::DeviceMemory output_data( buffer_allocations.GetDeviceAddress(output_buffer_)); - return ConvolveWithTune(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, buffer_allocations, stream); -} - -tensorflow::Status ConvolutionThunk::Convolve( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result) { - bool launch_ok; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - launch_ok = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_data, output_descriptor, output_data, - convolution_descriptor, filter_descriptor, &filter_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - case ConvolutionKind::kBackwardInput: - launch_ok = stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_data, output_descriptor, - output_data, convolution_descriptor, input_descriptor, - &input_data, scratch_allocator, algorithm_config, - profile_result) - .ok(); - break; - case ConvolutionKind::kForward: - launch_ok = - stream - ->ThenConvolveWithAlgorithm( - input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, &output_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - } - if (launch_ok) { - return tensorflow::Status::OK(); - } - return InternalError( - "Unable to launch convolution for thunk %p with type %s and algorithm " - "(%lld, %lld)", - this, ConvolutionKindToString(convolution_kind_).c_str(), - algorithm_config.algorithm().algo_id(), - algorithm_config.algorithm_no_scratch().algo_id()); -} - -std::vector ConvolutionThunk::GetAlgorithms( - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { - std::vector algorithms; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); - break; - } - return algorithms; -} - -static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { - if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); - } - return tensorflow::strings::StrCat(algo.algo_id()); -} - -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output descriptors. This works around b/68264959, an -// integer overflow in cuDNNv5 and cuDNNv6. -static bool ShouldIncludeWinogradNonfusedAlgo( - const BatchDescriptor& input_descriptor, - const BatchDescriptor& output_descriptor) { - int64 batch = input_descriptor.count(); - int64 in_depths = input_descriptor.feature_map_count(); - int64 in_rows = input_descriptor.height(); - int64 in_cols = input_descriptor.width(); - int64 out_depths = output_descriptor.feature_map_count(); - - int64 total_size = 16 * std::ceil(batch / 16.0) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - int64 threshold = 1L << 31; - - return total_size < threshold; -} - -tensorflow::Status ConvolutionThunk::ConvolveWithTune( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const BufferAllocations& buffer_allocations, se::Stream* stream) { - // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (!best_algorithm_.has_value()) { - best_algorithm_.emplace(); - - // Auto-tuning either is disabled or only happens in the first run of this - // function. - VLOG(2) << "Profiling for best convolution algorithm used for " - "ConvolutionThunk: " - << this; - - bool with_winograd_nonfused = - ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); - - se::dnn::ProfileResult best_result; - se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = - GetAlgorithms(with_winograd_nonfused, stream->parent()); - for (auto algorithm : algorithms) { - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - se::dnn::ProfileResult profile_result; - VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk: " << this; - bool launch_ok = - Convolve(input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, - se::dnn::AlgorithmConfig(algorithm, algorithm), stream, - &scratch_allocator, &profile_result) - .ok(); - if (launch_ok && profile_result.is_valid()) { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " succeeded, taking " - << profile_result.elapsed_time_in_ms() - << "ms. (Best result: " << best_result.elapsed_time_in_ms() - << "ms)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalAllocatedBytes() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_without_scratch.elapsed_time_in_ms()) { - best_result_without_scratch = profile_result; - } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " failed."; - } - } - - if (best_result.is_valid()) { - best_algorithm_->set_algorithm(best_result.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " - "to the default algorithm."; - best_algorithm_->set_algorithm(AlgorithmDesc()); + se::DeviceMemoryBase scratch = + buffer_allocations.GetDeviceAddress(scratch_buffer_); + + se::dnn::AlgorithmConfig algorithm_config( + se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false)); + + TF_RETURN_IF_ERROR(RunCudnnConvolution( + convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, + filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, + stream)); + + // Figure out which of output/input/filter is the result produced by this op, + // and write the result tuple. + void* result_ptr = [&] { + switch (convolution_kind_) { + case CudnnConvKind::kForward: + return output_data.opaque(); + case CudnnConvKind::kBackwardInput: + return input_data.opaque(); + case CudnnConvKind::kBackwardFilter: + return filter_data.opaque(); } + }(); + void* ptrs[] = {result_ptr, scratch.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); - if (best_result_without_scratch.is_valid()) { - best_algorithm_->set_algorithm_no_scratch( - best_result_without_scratch.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm without scratch works with " - "profiling. Fall back " - "to the default algorithm."; - best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); - } - } - - { - VLOG(2) << "Using convolution algorithm (" - << AlgorithmToString(best_algorithm_->algorithm()) << ", " - << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) - << ") for ConvolutionThunk: " << this; - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - return Convolve(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, *best_algorithm_, stream, - &scratch_allocator, nullptr); + if (!stream->ok()) { + return InternalError("ConvolutionThunk::ExecuteOnStream failed."); } + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 46c94d0..ca9ef52 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,106 +31,47 @@ limitations under the License. namespace xla { namespace gpu { -// A one-time scratch allocator for forward and backward convolution. The -// scratch buffers allocated are released on destruction. -// -// Not thread-safe. -class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator { - public: - ConvolveScratchAllocator(int device_ordinal, - DeviceMemoryAllocator* memory_allocator); - - ~ConvolveScratchAllocator() override; - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; - - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - perftools::gputools::port::StatusOr> - AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; - - private: - const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - // This class stores everything that StreamExecutor needs to launch a BNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. class ConvolutionThunk : public Thunk { public: - // ConvolutionThunk performs one of the following types of convolution. - enum class ConvolutionKind { - kBackwardFilter, // Backward convolution for filter. - kBackwardInput, // Backward convolution for input. - kForward, // Forward convolution. - }; - - // Constructs a thunk for launching a DNN convolution. + // Constructs a thunk for launching a DNN convolution. When run, it will + // write a tuple (result, scratch_memory) into `tuple_result_buffer`. + // + // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that + // we should use the default (i.e. baseline) cudnn algorithm. + // + // Note that "output" here doesn't refer to the output from running this + // thunk, but rather to the "output" of a hypothetical forward convolution + // that corresponds to this input+filter+output triple. That is, the result + // generated by this thunk is "output" for forward convs, "input" for + // backward-input convs, and "filter" for backward-filter convs. + // // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(ConvolutionKind convolution_kind, + ConvolutionThunk(CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - // Does the convolution for the thunk on "stream". Auto-tuning happens on the - // first run of this function. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; - - // Returns true if the next run of ExecuteOnStream will do autotuning. If so, - // we want the GPU to be quiescent during autotuning, so as not to introduce - // noise in our results. - bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream*) override { - return !best_algorithm_.has_value(); - } - - // Return true if scratch memory is needed to execute the thunk, that is - // either the best algorithm hasn't been chosen or the best algorithm is not - // the same as the no-scratch algorithm. This is because that the execution - // of the thunk is asynchronous, and the scratch allocator goes out of - // scope before the thunk finishes execution. Returning true tells the stream - // executor to make future thunks wait for this thunk to avoid reusing the - // deallocated scratch memory until this thunk is done with it. - bool ShouldBlockFutureThunks() { - if (!best_algorithm_.has_value()) { - return true; - } - - const perftools::gputools::dnn::AlgorithmDesc& best_alg = - best_algorithm_->algorithm(); - const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg = - best_algorithm_->algorithm_no_scratch(); - return (!best_alg.is_default() || !no_scratch_best_alg.is_default() || - !(best_alg == no_scratch_best_alg)); - } + // Does the convolution for the thunk on "stream". + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: - tensorflow::Status ConvolveWithTune( - const perftools::gputools::dnn::BatchDescriptor& input_descriptor, - perftools::gputools::DeviceMemory input_data, - const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, - perftools::gputools::DeviceMemory filter_data, - const perftools::gputools::dnn::BatchDescriptor& output_descriptor, - perftools::gputools::DeviceMemory output_data, - const perftools::gputools::dnn::ConvolutionDescriptor& - convolution_descriptor, - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream); + class ScratchAllocator; - tensorflow::Status Convolve( + Status Convolve( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, perftools::gputools::DeviceMemory input_data, const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, @@ -139,40 +81,26 @@ class ConvolutionThunk : public Thunk { const perftools::gputools::dnn::ConvolutionDescriptor& convolution_descriptor, const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, - perftools::gputools::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, + perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator, perftools::gputools::dnn::ProfileResult* profile_result); - // Returns the convolve algorithms that can be used for this ConvolutionThunk. - std::vector GetAlgorithms( - bool with_winograd_nonfused, - perftools::gputools::StreamExecutor* stream_exec) const; - - // Fastest cuDNN convolution algorithm for this thunk learned from - // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value, indicating cuDNN's convolution will choose the best - // algorithm from some heuristics based on its parameters. - tensorflow::gtl::optional - best_algorithm_; - - const ConvolutionKind convolution_kind_; + const CudnnConvKind convolution_kind_; const BufferAllocation::Slice input_buffer_; const BufferAllocation::Slice filter_buffer_; const BufferAllocation::Slice output_buffer_; + const BufferAllocation::Slice tuple_result_buffer_; + const BufferAllocation::Slice scratch_buffer_; const Shape input_shape_; const Shape filter_shape_; const Shape output_shape_; const Window window_; - const ConvolutionDimensionNumbers dim_nums_; + int64 algorithm_; }; -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc new file mode 100644 index 0000000..621b2d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -0,0 +1,370 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = perftools::gputools; + +using se::DeviceMemoryBase; +using se::dnn::AlgorithmConfig; +using se::dnn::AlgorithmDesc; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +class ScratchAllocator : public se::ScratchAllocator { + public: + ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + + ~ScratchAllocator() override; + + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? + } + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +ScratchAllocator::~ScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +se::port::StatusOr> ScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Failed to allocate %lld bytes on device %d.", + byte_size, device_ordinal_)); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output shapes. This works around b/68264959, an integer +// overflow in cuDNNv5 and cuDNNv6. +// +// TODO(jlebar): We shouldn't need this check for cuDNNv7. +bool ShouldIncludeWinogradNonfusedAlgo( + const Shape& input_shape, const Shape& output_shape, + const ConvolutionDimensionNumbers& dnums) { + int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); + int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); + int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); + int64 in_cols = + dnums.input_spatial_dimensions_size() == 1 + ? 1 + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); + int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); + + int64 total_size = CeilOfRatio(batch, int64{16}) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + + const int64 threshold = 1L << 31; + return total_size < threshold; +} + +std::vector GetAlgorithms(CudnnConvKind kind, + bool with_winograd_nonfused, + se::StreamExecutor* stream_exec_) { + std::vector algorithms; + switch (kind) { + case CudnnConvKind::kBackwardFilter: + CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kBackwardInput: + CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kForward: + CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused, + &algorithms)); + break; + } + + // Remove any algorithms with tensor math enabled. These have lower precision + // than regular algorithms, and we don't yet have a way to turn this on/off in + // XLA. + algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(), + [&](const AlgorithmDesc& a) { + return a.tensor_ops_enabled(); + }), + algorithms.end()); + + return algorithms; +} + +string AlgorithmToString(const AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + +string NumBytesToString(int64 bytes) { + return tensorflow::strings::StrCat( + tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); +} + +} // anonymous namespace + +// We could have caching here so that we don't redo this work for two identical +// convolutions. Unfortunately our cache key would have to be a tuple +// containing the protos passed to this function, and we have no utility for +// hashing protos. We could write our own hash functions, but they'd silently +// break if we ever added a field to one of the protos. Perhaps we could hack +// using the binary-encoded proto as the hash key, on the assumption that two +// protos being binary-equal is a sufficient, if not necessary, condition for +// proper equality. But that would still leave us open to having unnecessary +// cache misses and doing extra work. Overall, caching doesn't seem worth the +// trouble, but we may want to revisit this if we ever find a model where +// caching would speed up compilation a lot. +optional> +CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Create a stream for us to do our work on. + se::Stream stream{stream_exec_}; + stream.Init(); + const auto device_ordinal = stream_exec_->device_ordinal(); + + // allocator either points to this->allocator_ or, if that's null, to a + // StreamExecutorMemoryAllocator for stream_exec_. + DeviceMemoryAllocator* allocator; + optional se_allocator; + if (allocator_ != nullptr) { + allocator = allocator_; + } else { + se_allocator.emplace( + stream_exec_->platform(), + tensorflow::gtl::ArraySlice({stream_exec_})); + allocator = &*se_allocator; + } + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + // + // We don't put any data in these buffers, because (in theory, anyway) the + // speed of a conv isn't affected by the data being convolved. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + se::port::StatusOr input_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(input_shape)); + se::port::StatusOr filter_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(filter_shape)); + se::port::StatusOr output_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(output_shape)); + if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + LOG(WARNING) + << "Couldn't allocate space for input/filter/output of convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + + const bool use_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); + se::dnn::ProfileResult best_result; + int64 best_result_bytes_used = 0; + for (const AlgorithmDesc& alg : + GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + ScratchAllocator scratch_allocator(device_ordinal, allocator); + se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " + << instr->ToString(); + + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf.ValueOrDie()), + se::DeviceMemory(filter_buf.ValueOrDie()), + se::DeviceMemory(output_buf.ValueOrDie()), + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); + + if (launch_ok && profile_result.is_valid()) { + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) + << " succeeded, taking " << profile_result.elapsed_time_in_ms() + << "ms and using " << NumBytesToString(scratch_bytes_used) + << " of scratch (Best result: " + << best_result.elapsed_time_in_ms() << "ms, " + << NumBytesToString(best_result_bytes_used) << " of scratch)"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + best_result_bytes_used = scratch_bytes_used; + } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; + } + } + if (best_result.is_valid()) { + VLOG(2) << "Best algorithm for " << instr->ToString() << ": " + << AlgorithmToString(best_result.algorithm()) << ", takes " + << best_result.elapsed_time_in_ms() << "ms, and uses " + << best_result_bytes_used << "B of scratch memory."; + return std::make_pair(best_result.algorithm().algo_id(), + best_result_bytes_used); + } + + LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() + << " failed. Falling back to default algorithm."; + return nullopt; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( + HloInstruction* instr) { + CHECK(IsCustomCallToDnnConvolution(*instr)); + + const auto& call_target = instr->custom_call_target(); + const auto& lhs_shape = instr->operand(0)->shape(); + const auto& rhs_shape = instr->operand(1)->shape(); + const auto& conv_result_shape = instr->shape().tuple_shapes(0); + optional> alg_and_scratch_bytes; + if (call_target == kCudnnConvForwardCallTarget) { + alg_and_scratch_bytes = PickBestAlgorithm( + CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardInputCallTarget) { + alg_and_scratch_bytes = PickBestAlgorithm( + CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardFilterCallTarget) { + alg_and_scratch_bytes = PickBestAlgorithm( + CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instr->ToString(); + } + + if (!alg_and_scratch_bytes.has_value()) { + return false; + } + + int64 algorithm; + int64 scratch_bytes; + std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes; + + VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " + << NumBytesToString(scratch_bytes) + << " of scratch memory: " << instr->ToString(); + + // Replace instr with a new CustomCall which has the correct algorithm, and + // whose output shape has the appropriate amount of scratch memory. + HloComputation* computation = instr->parent(); + Shape new_call_shape = + ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {scratch_bytes})}); + HloInstruction* algorithm_hlo = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); + HloInstruction* new_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + new_call_shape, + {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo}, + instr->custom_call_target())); + new_call->set_window(instr->window()); + new_call->set_convolution_dimension_numbers( + instr->convolution_dimension_numbers()); + + // Repackage new_call so it has the same shape as the original call, namely + // (conv_result, u8[0]). + HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple( + {computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_call_shape.tuple_shapes(0), new_call, 0)), + computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({})))})); + + TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); + return true; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( + HloComputation* computation) { + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + + bool changed = false; + for (auto* instr : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr)); + changed |= result; + } + return changed; +} + +StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h new file mode 100644 index 0000000..10e49da --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for +// each and adding explicit scratch space to the CustomCalls. +class CudnnConvolutionAlgorithmPicker : public HloPassInterface { + public: + // If the `allocator` parameter is not null, we will use it to allocate temp + // memory while timing the various convolution algorithms. If it's null, + // we'll use the default allocator on the StreamExecutor. + CudnnConvolutionAlgorithmPicker( + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) + : stream_exec_(stream_exec), allocator_(allocator) {} + + tensorflow::StringPiece name() const override { + return "cudnn-convolution-algorithm-picker"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation* computation); + StatusOr RunOnInstruction(HloInstruction* instr); + tensorflow::gtl::optional> PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + + perftools::gputools::StreamExecutor* stream_exec_; // never null + DeviceMemoryAllocator* allocator_; // may be null +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc similarity index 83% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index b0626ca..e0c73aa 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include #include @@ -33,14 +33,32 @@ namespace xla { namespace gpu { namespace { + +bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { + const ConvolutionDimensionNumbers& dnums = + conv->convolution_dimension_numbers(); + if (dnums.input_spatial_dimensions_size() > 3) { + return false; + } + + // CuDNN does not accept zero-element arguments + if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || + ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + return false; + } + + if (window_util::HasWindowReversal(conv->window())) { + return false; + } + return true; +} + // Try to match a backward filter pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardFilter(HloInstruction* conv) { +std::tuple MatchBackwardFilter( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -190,18 +208,15 @@ MatchBackwardFilter(HloInstruction* conv) { backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({conv}), - backward_conv_window, backward_conv_dnums); + return std::make_tuple(true, backward_conv_window, backward_conv_dnums); } // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardInput(HloInstruction* conv) { +std::tuple MatchBackwardInput( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); @@ -374,58 +389,82 @@ MatchBackwardInput(HloInstruction* conv) { dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, - std::vector({conv, reverse_filter}), - new_window, dnums); + return std::make_tuple(true, new_window, dnums); } -} // namespace -StatusOr ConvolutionFolding::Run(HloModule* module) { - HloComputation* entry_computation = module->entry_computation(); - std::vector convs; - for (auto* hlo : entry_computation->instructions()) { - if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo); - } - } +// Tries to rewrite a single convolution into a call to cudnn. +StatusOr RunOnInstruction(HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - bool changed = false; - for (HloInstruction* conv : convs) { + HloInstruction* custom_call = [&]() -> HloInstruction* { bool match; - std::vector hlos_to_fuse; Window window; ConvolutionDimensionNumbers dnums; - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv); + + std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter, - window, dnums); - VLOG(2) << "to backward filter convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + return CreateCudnnConvBackwardFilter( + conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums); } - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums) = MatchBackwardInput(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput, - window, dnums); - VLOG(2) << "to backward input convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + // Backward input conv subsumes the conv plus the reverse in operand 1. + HloInstruction* reverse = conv->mutable_operand(1); + CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); + HloInstruction* rhs = reverse->mutable_operand(0); + + return CreateCudnnConvBackwardInput( + conv->shape(), conv->mutable_operand(0), rhs, window, dnums); } + + // If all else fails, try a forward convolution. + if (CanImplementAsCudnnForwardConv(conv)) { + return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers()); + } + + return nullptr; + }(); + + if (custom_call == nullptr) { + return false; + } + + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out + // the conv result and replace `conv` with it. + TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( + conv, + HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); + return true; +} + +// Rewrites the convolutions in the given computation into calls to cudnn. +// Returns true if it made any changes. +StatusOr RunOnComputation(HloComputation* computation) { + std::vector convs; + for (auto* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kConvolution) { + convs.push_back(hlo); + } + } + + bool changed = false; + for (HloInstruction* conv : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv)); + changed |= result; + } + return changed; +} +} // namespace + +StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h similarity index 63% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.h rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index f9c8987..0c0578d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -22,10 +22,12 @@ limitations under the License. namespace xla { namespace gpu { -class ConvolutionFolding : public HloPassInterface { +// Rewrites plain convolutions, backwards-filter convolutions, and +// backwards-input convolutions into CustomCall HLOs that call into cuDNN. +class CudnnConvolutionRewriter : public HloPassInterface { public: tensorflow::StringPiece name() const override { - return "convolution-folding"; + return "cudnn-convolution-rewriter"; } StatusOr Run(HloModule* module) override; @@ -34,4 +36,4 @@ class ConvolutionFolding : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc similarity index 82% rename from tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 34e6bdb..65588b6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { +namespace { -class ConvolutionFoldingTest : public HloTestBase { +namespace op = xla::testing::opcode_matchers; + +class CudnnConvolutionRewriterTest : public HloTestBase { public: - ConvolutionFoldingTest() { + CudnnConvolutionRewriterTest() { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -44,7 +50,8 @@ class ConvolutionFoldingTest : public HloTestBase { // the batch and feature dimension in the activations, and treat the batch // dimension in gradients as the input feature dimension in the filter. // - // TODO(jingyue): Add more tests on NCHW input order which TF also supports. + // TODO(jingyue): Add more tests on NCHW input order, which TF also + // supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); @@ -74,9 +81,8 @@ class ConvolutionFoldingTest : public HloTestBase { } protected: - bool FoldConvolution(HloModule* module) { - ConvolutionFolding convolution_folding; - return convolution_folding.Run(module).ValueOrDie(); + bool RunPass(HloModule* module) { + return CudnnConvolutionRewriter().Run(module).ValueOrDie(); } // A convolution window with stride 1 and zero padding. The size fields are @@ -86,7 +92,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -108,14 +114,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -135,12 +140,17 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from block35 training. -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -162,15 +172,15 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from inception v3 training. -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -192,14 +202,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -221,14 +230,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -272,14 +280,15 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); // Low padding of the backward input convolution // = kernel_size - 1 - low padding on gradients. EXPECT_EQ(3, window_dim.padding_low()); @@ -291,7 +300,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -316,17 +325,16 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); } // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. @@ -347,8 +355,12 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_input_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from Inception V3 training. @@ -365,7 +377,8 @@ TEST_F(ConvolutionFoldingTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -397,14 +410,14 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); EXPECT_EQ(0, window_dim.padding_low()); EXPECT_EQ(0, window_dim.padding_high()); EXPECT_EQ(2, window_dim.stride()); @@ -413,7 +426,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -442,8 +455,12 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from //learning/brain/google/xla/benchmarks/resnet.py @@ -460,7 +477,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { // // We should fuse BC even though padding on activations is uneven, because // PadInsertion will canonicalize the fusion HLO. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -493,13 +510,12 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - const HloInstruction* backward_conv = entry_computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - backward_conv->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); const WindowDimension& backward_conv_col_dim = - backward_conv->window().dimensions(1); + entry_computation->root_instruction()->operand(0)->window().dimensions(1); EXPECT_EQ(0, backward_conv_col_dim.padding_low()); EXPECT_EQ(1, backward_conv_col_dim.padding_high()); } @@ -515,7 +531,7 @@ TEST_F(ConvolutionFoldingTest, // // We currently don't fuse BC because PadInsertion doesn't support negative // padding on the gradients of backward convolution (b/32744257). -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -544,9 +560,14 @@ TEST_F(ConvolutionFoldingTest, .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +} // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc new file mode 100644 index 0000000..f5f52cf --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -0,0 +1,221 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = ::perftools::gputools; + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::Stream; +using se::dnn::AlgorithmConfig; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::DimIndex; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; +using se::dnn::ProfileResult; + +// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, +// returning it (in its entirety) the first time Allocate() is called. +class ScratchBufAllocator : public se::ScratchAllocator { + public: + explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) + : scratch_(scratch) {} + + ~ScratchBufAllocator() override = default; + + int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { + return scratch_.size(); + } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override { + if (allocated_) { + return se::port::InternalError( + "Can't allocate twice from a ScratchBufAllocator."); + } + if (byte_size > scratch_.size()) { + return se::port::InternalError(tensorflow::strings::StrCat( + "Can't allocate ", byte_size, + " bytes from a ScratchBufAllocator of size ", scratch_.size())); + } + + allocated_ = true; + return se::DeviceMemory(scratch_); + } + + private: + se::DeviceMemoryBase scratch_; + bool allocated_ = false; +}; + +} // anonymous namespace + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + } +} + +Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape, + const Shape& filter_shape, const Shape& output_shape, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + AlgorithmConfig algorithm, Stream* stream, + ProfileResult* profile_result /*= nullptr*/) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, DeviceMemory input_buf, + DeviceMemory filter_buf, DeviceMemory output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, + Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); + VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; + VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; + VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + CHECK_EQ(F32, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dim.padding_low(), dim.padding_high()); + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + BatchDescriptor input_descriptor(effective_num_dimensions); + input_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. + input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor filter_descriptor(effective_num_dimensions); + filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + .set_input_feature_map_count( + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + for (int dim = 0; dim < num_dimensions; ++dim) { + convolution_descriptor + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()); + } + + BatchDescriptor output_descriptor(effective_num_dimensions); + output_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. + if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + convolution_descriptor.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + + switch (kind) { + case CudnnConvKind::kForward: + stream->ThenConvolveWithAlgorithm( + input_descriptor, input_buf, filter_descriptor, filter_buf, + convolution_descriptor, output_descriptor, &output_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardInput: + stream->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_buf, output_descriptor, output_buf, + convolution_descriptor, input_descriptor, &input_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardFilter: + stream->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_buf, output_descriptor, output_buf, + convolution_descriptor, filter_descriptor, &filter_buf, + scratch_allocator, algorithm, profile_result); + break; + } + + if (!stream->ok()) { + return InternalError( + "Unable to launch convolution with type %s and algorithm (%lld, %lld)", + CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + algorithm.algorithm_no_scratch().algo_id()); + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h new file mode 100644 index 0000000..b101f76 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This file contains low-level routines for running cudnn convolutions. + +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter +}; + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + +// Calls into cudnn to run the specified convolution. +// +// Note that depending on the value of CudnnConvKind, the result of this call +// may be written into input_buf, filter_buf, or output_buf! +// +// At the moment we only support cudnn convolutions over floats. +// +// We provide one overload which takes a scratch buffer, and another which takes +// an allocator which is responsible for allocating the scratch space. In +// theory the second one shouldn't be necessary -- users of this function could +// just ask cudnn how much scratch space it needs for a particular convolution. +// But in practice, StreamExecutor does not expose such an API, and in the name +// of parsimony, perhaps it's better not to add it. Instead, the first time you +// call a convolution, you should call the version that takes a scratch +// allocator and take note of how much memory is used. The next time you call +// the same conv, you can provide an explicitly preallocated scratch buffer of +// that size, if you like. +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, + perftools::gputools::DeviceMemory input_buf, + perftools::gputools::DeviceMemory filter_buf, + perftools::gputools::DeviceMemory output_buf, + perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, + perftools::gputools::DeviceMemory input_buf, + perftools::gputools::DeviceMemory filter_buf, + perftools::gputools::DeviceMemory output_buf, + perftools::gputools::ScratchAllocator* scratch_allocator, + const Window& window, const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 07543d4..12ec266 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -35,8 +35,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -127,7 +128,9 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { +tensorflow::Status OptimizeHloModule(HloModule* hlo_module, + se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -143,6 +146,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { // most ops. pipeline.AddPass(BF16, F32); pipeline.AddPass(); + { auto& pass = pipeline.AddPass>("simplification"); @@ -173,7 +177,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { pass.AddPass(); pass.AddPass(); } - pipeline.AddPass(); + pipeline.AddPass( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { @@ -185,6 +189,58 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + + { + // Convert convolutions into CustomCalls to cudnn, then canonicalize them + // (PadInsertion). + HloPassPipeline pipeline("conv_canonicalization"); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + pipeline.AddPass(); + + // Choose the fastest algorithm for each conv. + // + // In theory doing this here is way too early: It needs to happen after + // layout assignment, because the layout of the inputs/outputs affects the + // speed of the conv. But currently we only allow only one input/output + // layout when calling cudnn, so there's no ambiguity. + // + // We pick the algorithm at this early stage so we can generate better HLO. + // After CudnnConvolutionRewriter, our convolutions are CustomCalls which + // return a tuple (conv_result, scratch_memory), and the each conv uses 0 + // bytes of scratch: + // + // customcall = (f32[...], f32[0]) + // return gte(customcall, 0) + // + // The algorithm picker then chooses the best algorithm, and potentially + // increases the scratch space. It replaces customcall with new_tuple, + // giving us the following: + // + // new_customcall = (f32[...], f32[N]) + // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) + // return gte(new_tuple, 0) + // + // The new tuple and gte instructions then be simplified away, because + // nobody is expected to use the scratch value. + // + // However, if we were to run CudnnConvolutionAlgorithmPicker after layout + // assignment, fusion would already have run, and the gte(customcall, 0) + // would probably already be into a fusion node. We can't simplify across + // HloComputation boundaries, so in this case we wouldn't be able to + // simplify away the new_tuple bits. + // + // We'll need to revisit this if we ever allow multiple layouts for the + // inputs/outputs of a cudnn convolution. + pipeline.AddPass(stream_exec, + device_allocator); + // Clean up new_tuple described above. + pipeline.AddPass(); + pipeline.AddPass(); + + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { HloPassFix fusion("fusion"); fusion.AddInvariantChecker(); @@ -212,9 +268,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting( - HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* /*device_allocator*/) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output @@ -222,9 +276,10 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); - pipeline.AddPass(); + pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -417,7 +472,8 @@ StatusOr> GpuCompiler::RunHloPasses( XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); Tracing::TraceMe annotation("HLO Transforms", module->name(), /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get())); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); } @@ -428,8 +484,7 @@ StatusOr> GpuCompiler::RunBackend( TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get(), stream_exec, - device_allocator)); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -464,8 +519,9 @@ StatusOr> GpuCompiler::RunBackend( /*color_alignment=*/[](LogicalBuffer::Color) { return kCudaMallocAlignBytes; })); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. + // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() + // include headers, so no need for us to print them ourselves. + XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_optimized_hlo_proto_to = diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index e3b493c..88bf5a7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,6 +78,12 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } + } else if (IsCustomCallToDnnConvolution(*hlo)) { + // The last argument to a CUDNN convolution is its algorithm, which must + // be an HLO constant -- it shouldn't be copied. + for (int64 i = 0; i < hlo->operand_count() - 1; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } } else if (ImplementedAsLibraryCall(*hlo)) { // For all other library calls, materialize all the operands into memory. for (int64 i = 0; i < hlo->operand_count(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 58915f1..89f1e62 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -28,122 +28,114 @@ limitations under the License. namespace xla { namespace gpu { +// cuDNN convolutions are called with specific layouts on the input, output, +// and filter: +// +// input: DataLayout::kBatchDepthYX +// output: DataLayout::kBatchDepthYX +// filter: FilterLayout::kOutputInputYX +// +// The order dimensions in the constant name is major-to-minor (eg, the +// most-major dimension of the input is batch, most-minor is X). The +// specific dimension numbers these named dimensions correspond to is +// determined by the ConvolutionDimensionNumbers argument. Y is spatial +// dimension 0, and X is spatial dimension 1. +// +// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. +static Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints) { + CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); + Shape input_shape; + Shape filter_shape; + Shape output_shape; + const auto& target = instr->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->shape().tuple_shapes(0); + } else if (target == kCudnnConvBackwardInputCallTarget) { + input_shape = instr->shape().tuple_shapes(0); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->operand(0)->shape(); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->shape().tuple_shapes(0); + output_shape = instr->operand(1)->shape(); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + + // Construct minor-to-major dimension orders for operands and result. + // cuDNN's convolution APIs support the BDYX layout for activations/output + // and the OIYX layout for weights. + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN + // calls after we switch to cuDNN v5. + const ConvolutionDimensionNumbers& dimension_numbers = + instr->convolution_dimension_numbers(); + std::vector input_layout; + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; + --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); + } + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); + *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); + + std::vector filter_layout; + for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; + --i) { + filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); + } + filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); + filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); + *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); + + std::vector output_layout; + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; + --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + } + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); + *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); + + // The custom call returns a tuple of (actual_result, scratch_buffer); + // call_result_buf is the logical buffer for actual_result, the thing that + // contains the result of the conv call. + TF_ASSIGN_OR_RETURN(const LogicalBuffer* call_result_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instr, /*index=*/{0})); + + // Set layouts of the instructions' shapes. + if (target == kCudnnConvForwardCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardInputCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + return Status::OK(); +} + Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { for (auto* instruction : constraints->computation()->instructions()) { - // cuDNN is called with specific layouts on the input, output, and filter: - // - // input: DataLayout::kBatchDepthYX - // output: DataLayout::kBatchDepthYX - // filter: FilterLayout::kOutputInputYX - // - // The order dimensions in the constant name is major-to-minor (eg, the - // most-major dimension of the input is batch, most-minor is X). The - // specific dimension numbers these named dimensions correspond to is - // determined by the ConvolutionDimensionNumbers argument. Y is spatial - // dimension 0, and X is spatial dimension 1. - // - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. - if (ImplementedAsDnnConvolution(*instruction)) { - HloInstruction* input = nullptr; - HloInstruction* filter = nullptr; - HloInstruction* output = nullptr; - if (instruction->opcode() == HloOpcode::kConvolution) { - input = instruction->mutable_operand(0); - filter = instruction->mutable_operand(1); - output = instruction; - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - input = instruction->mutable_operand(0); - filter = instruction; - output = instruction->mutable_operand(1); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - input = instruction; - filter = instruction->mutable_operand(1); - output = instruction->mutable_operand(0); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } - - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instruction->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; - i >= 0; --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - Shape input_shape(input->shape()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; - i >= 0; --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back( - dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back( - dimension_numbers.kernel_output_feature_dimension()); - Shape filter_shape(filter->shape()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; - i >= 0; --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); - } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - Shape output_shape(output->shape()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); - - // Set layouts of the instructions' shapes. - if (instruction->opcode() == HloOpcode::kConvolution) { - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, output, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, output, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, output)); - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, filter, 0)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(filter_shape, filter)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, filter, 1)); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(input_shape, input)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, input, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, input, 1)); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } + if (IsCustomCallToDnnConvolution(*instruction)) { + TF_RETURN_IF_ERROR( + AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } } return Status::OK(); @@ -151,9 +143,12 @@ Status GpuLayoutAssignment::AddBackendConstraints( bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) { - // Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - return !IsCustomCallToDnnBatchNorm(*instruction); + // - Inputs to cudnn batchnorm custom calls don't need the major-first layout + // (i.e. {n, n-1, ...0}) -- we can handle any layout. + // - Inputs to cudnn convolution require custom layouts handled in + // AddBackendConstraints. + return !IsCustomCallToDnnBatchNorm(*instruction) && + !IsCustomCallToDnnConvolution(*instruction); } Status GpuLayoutAssignment::PropagateOperandConstraint( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1d47ffd..2d6dad2 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -137,49 +137,6 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { - HloComputation::Builder builder(TestName()); - auto input = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input")); - auto filter = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter")); - - Window conv_window; - WindowDimension* conv_window_row = conv_window.add_dimensions(); - conv_window_row->set_size(1); - WindowDimension* conv_window_col = conv_window.add_dimensions(); - conv_window_col->set_size(2); - conv_window_col->set_padding_high(1); - - ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_input_batch_dimension(0); - conv_dnums.set_output_batch_dimension(0); - conv_dnums.set_input_feature_dimension(1); - conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_input_spatial_dimensions(2); - conv_dnums.add_output_spatial_dimensions(2); - conv_dnums.add_input_spatial_dimensions(3); - conv_dnums.add_output_spatial_dimensions(3); - conv_dnums.set_kernel_output_feature_dimension(0); - conv_dnums.set_kernel_input_feature_dimension(1); - conv_dnums.add_kernel_spatial_dimensions(2); - conv_dnums.add_kernel_spatial_dimensions(3); - - auto conv = builder.AddInstruction( - HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), - input, filter, conv_window, conv_dnums)); - auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0})); - builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - TEST_F(InstructionFusionTest, GetTupleElementFused) { HloComputation::Builder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 76566a9..2f65edf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -90,43 +90,6 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return false; } -bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. - if (hlo.parent() != hlo.GetModule()->entry_computation()) { - return false; - } - - // Forward convolution. - if (hlo.opcode() == HloOpcode::kConvolution) { - const ConvolutionDimensionNumbers& dnums = - hlo.convolution_dimension_numbers(); - if (dnums.input_spatial_dimensions_size() > 3) { - return false; - } - - // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) || - ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) { - return false; - } - - if (window_util::HasWindowReversal(hlo.window())) { - return false; - } - - return true; - } - - // Backward convolution. - if (hlo.opcode() == HloOpcode::kFusion && - (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter || - hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) { - return true; - } - - return false; -} - const char* const kCudnnBatchNormForwardInferenceCallTarget = "__cudnn$batchNormalizationForwardInference"; const char* const kCudnnBatchNormForwardTrainingCallTarget = @@ -144,9 +107,76 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { target == kCudnnBatchNormBackwardCallTarget; } +const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward"; +const char* const kCudnnConvBackwardInputCallTarget = + "__cudnn$convBackwardInput"; +const char* const kCudnnConvBackwardFilterCallTarget = + "__cudnn$convBackwardFilter"; + +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const auto& target = hlo.custom_call_target(); + return target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget || + target == kCudnnConvBackwardFilterCallTarget; +} + bool ImplementedAsLibraryCall(const HloInstruction& hlo) { - return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo) || - IsCustomCallToDnnBatchNorm(hlo); + return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) || + IsCustomCallToDnnConvolution(hlo); +} + +static HloInstruction* CreateCudnnConv( + const char* call_target, const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dnums) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn + // algorithm to use. It's up to a later pass to choose the algorithm, so to + // indicate that we haven't yet made a choice, we speicfy -1 for that arg. + HloInstruction* negative_one = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction* custom_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + call_shape, {lhs, rhs, negative_one}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + return custom_call; +} + +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, + window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, + reverse_filter, window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, + output, window, dnums); } bool IsReductionToVector(const HloInstruction& reduce) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index d24ed98..7ad9680 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -22,6 +22,9 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they +// don't belong in "ir_emission_utils". + namespace xla { namespace gpu { @@ -30,9 +33,6 @@ constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. bool ImplementedAsGemm(const HloInstruction& hlo); -// Returns true if `hlo` will be implemented as a call to cuDNN convolution. -bool ImplementedAsDnnConvolution(const HloInstruction& hlo); - // A call to cuDNN for batch normalization is represented as CustomCall HLO with // a call target equal to one of these strings. // @@ -58,6 +58,60 @@ extern const char* const kCudnnBatchNormBackwardCallTarget; // sequence of generic HLOs or to a cuDNN CustomCall. bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); +// A call to cuDNN for convolution (forward, backward filter, or backward input) +// is represented as a CustomCall HLO with a call target equal to one of these +// strings. +// +// These CustomCalls have window() and convolution_dimension_numbers() set like +// regular convolution ops. They have the same LHS and RHS operands, plus one +// additional int64 operand, representing which cudnn algorithm to run. This +// operand must be an HLO constant. A value of -1 means that the implementation +// is free to choose the best algorithm it can. +// +// These calls output a tuple (conv_result, scratch_memory), where conv_result +// is the actual result of the convolution, and scratch_memory is temporary +// memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value +// is not well-defined. +// +// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later +// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// algorithm for each conv and sets the amount of scratch space needed. +// +// (Representing the scratch memory as an output may seem strange at first, but +// it's quite sensible, from a certain point of view. The scratch buffer is a +// location in memory that the conv can write into, but which it can't legally +// read from, at least until it's written something first. But that's exactly +// the definition of an output buffer.) +extern const char* const kCudnnConvForwardCallTarget; +extern const char* const kCudnnConvBackwardInputCallTarget; +extern const char* const kCudnnConvBackwardFilterCallTarget; + +// Returns true if `hlo` will be implemented as a call to a cuDNN convolution +// routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a +// kConvolution opcode. +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); + +// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. +// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If +// you want just the conv result, you'll need to get-tuple-element the value +// returned by this function. +// +// The created cudnn call will use the default cudnn algorithm and no scratch +// space. +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums); + // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 3aa1784..9031a83 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -336,9 +336,6 @@ class IrEmitterUnnested : public IrEmitter { // Thunk object. std::unique_ptr BuildKernelThunk(const HloInstruction* inst); - // Returns a ConvolutionThunk that calls DNN to implement `inst`. - std::unique_ptr BuildConvolutionThunk(const HloInstruction* inst); - // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bd428f8..4072573 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -278,10 +279,6 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - if (ImplementedAsDnnConvolution(*convolution)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); return IrEmitter::HandleConvolution(convolution); } @@ -380,6 +377,71 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return Status::OK(); } + if (IsCustomCallToDnnConvolution(*custom_call)) { + const auto& assn = ir_emitter_context_->buffer_assignment(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); + auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + auto tuple_result_slice = GetAllocationSlice(*custom_call); + auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + + const HloInstruction* algorithm_inst = custom_call->operand(2); + CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); + int64 algorithm = algorithm_inst->literal().Get({}); + + const auto& target = custom_call->custom_call_target(); + std::unique_ptr thunk; + if (target == kCudnnConvForwardCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kForward, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/conv_result_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, custom_call); + } else if (target == kCudnnConvBackwardInputCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardInput, + /*input_buffer=*/conv_result_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/lhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/lhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, custom_call); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardFilter, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/conv_result_slice, + /*output_buffer=*/rhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, + /*output_shape=*/rhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, custom_call); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + + thunk_sequence_->emplace_back(std::move(thunk)); + return Status::OK(); + } + return IrEmitter::HandleCustomCall(custom_call); } @@ -500,10 +562,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - if (ImplementedAsDnnConvolution(*fusion)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); return IrEmitter::HandleFusion(fusion); } @@ -2011,52 +2069,6 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); } -std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( - const HloInstruction* inst) { - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - if (inst->opcode() == HloOpcode::kConvolution) { - // Forward covolution. - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kForward, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/inst->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - } - - // Backward filter convolution, which takes the input (activations) and the - // gradients, and computes the filter. - CHECK_EQ(HloOpcode::kFusion, inst->opcode()); - switch (inst->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardFilter, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*inst), - /*output_buffer=*/GetAllocationSlice(*rhs), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/inst->shape(), - /*output_shape=*/rhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - case HloInstruction::FusionKind::kConvBackwardInput: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardInput, - /*input_buffer=*/GetAllocationSlice(*inst), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*lhs), - /*input_shape=*/inst->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/lhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - default: - LOG(FATAL) << "Not a convolution-fusion"; - } -} - std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 2923a79..25846dc 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -27,7 +27,7 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(HloOpcode::kConvolution, conv.opcode()); + CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -47,6 +47,12 @@ HloInstruction* MaybePaddedAndSlicedInput( window_util::HasBaseDilation(conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. + // + // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of + // moving all the padding into an explicit pad op, we should keep as much + // padding inside of cudnn as possible, on the assumption that padding + // within cudnn is basically free, whereas a kPad's cost increases as the + // amount of padding increases. PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { @@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { dim->set_window_dilation(1); } + // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract + // out the shape of conv_result. + Shape old_conv_shape = conv->shape().tuple_shapes(0); + VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = HloInstruction::CreateConvolve( - conv->shape(), new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers()); + auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, + new_conv_window, + conv->convolution_dimension_numbers()); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); - TF_CHECK_OK( - conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv))); + TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); return true; } @@ -190,6 +199,8 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { + CHECK_EQ(backward_conv->custom_call_target(), + kCudnnConvBackwardFilterCallTarget); if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } @@ -202,15 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { @@ -222,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // cuDNN convolution (which doesn't support negative padding) to fail. return false; } - // If the backward convolution has uneven padding on the activations, we - // move some padding on the larger end to "internal" padding, so that the - // backward convolution produces larger weight gradients which get sliced - // later. Therefore, the amount of new padding (low or high) is the minimum - // of the amount of old padding low and old padding high. + // Compute the new, even padding for the backward conv operation. int64 new_conv_padding = std::min(padding_low, padding_high); int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -237,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Since we move some padding from the backward convolution to the kPad, we // need to accordingly reduce the padding amount of the backward convolution // and its inner forward convolution. - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); + auto* new_dim = new_backward_conv_window.mutable_dimensions(i); + new_dim->set_padding_low(new_conv_padding); + new_dim->set_padding_high(new_conv_padding); } // Create a new backward convolution replacing the old one. @@ -260,19 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), input, padding, input_padding_config)); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - padded_input->shape(), output->shape(), new_forward_conv_window, - forward_conv_dnums) - .ConsumeValueOrDie(), - padded_input, output, new_forward_conv_window, forward_conv_dnums)); - - // Fuse the new forward convolution to the new backward convolution. - HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, - new_backward_conv_window, backward_conv_dnums); + // The shape of the backward_conv CustomCall is a tuple (conv_result, + // scratch_buffer). Extract out the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( + backward_conv_shape, padded_input, output, new_backward_conv_window, + backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -289,14 +280,15 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return false; } - HloInstruction* forward_conv = backward_conv->fused_expression_root(); - HloInstruction* reverse_filter = forward_conv->mutable_operand(1); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); + + // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory). + // Get the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + + Shape new_backward_conv_shape = backward_conv_shape; for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); @@ -315,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( // where the amount of padding low is larger, we can canonicalize it to // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) // [A] = Slice([B A]) - // For consistency, we need to increase the low padding of the inner - // convolution by 1 as well because the input is larger now. if (padding_low > padding_high) { IncreasePaddingLowBy(padding_high - padding_low, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(padding_low - padding_high, - new_forward_conv_window.mutable_dimensions(i)); } else if (padding_low < padding_high) { IncreasePaddingHighBy(padding_low - padding_high, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(padding_high - padding_low, - new_forward_conv_window.mutable_dimensions(i)); } + // Decreasing the padding by X *increases* the size of our output by X. + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + new_backward_conv_shape.set_dimensions( + dim, new_backward_conv_shape.dimensions(dim) + + std::abs(padding_low - padding_high)); } // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_reverse_filter = - computation->AddInstruction(HloInstruction::CreateReverse( - filter->shape(), filter, reverse_filter->dimensions())); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - output->shape(), new_reverse_filter->shape(), - new_forward_conv_window, forward_conv_dnums) - .ConsumeValueOrDie(), - output, new_reverse_filter, new_forward_conv_window, - forward_conv_dnums)); + + HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( + new_backward_conv_shape, output, filter, new_backward_conv_window, + backward_conv_dnums); + + // The CustomCall created above returns a tuple (conv_result, scratch_memory). + // Extract out the two elements. HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv, new_reverse_filter}, - HloInstruction::FusionKind::kConvBackwardInput, - new_backward_conv_window, backward_conv_dnums); + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_shape, new_backward_conv_call, 0)); + HloInstruction* new_backward_conv_scratch = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_call->shape().tuple_shapes(1), + new_backward_conv_call, 1)); // Slice the new backward convolution. // @@ -377,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( } // Replace the old backward convolution with the slice. - CHECK(ShapeUtil::Compatible( + Shape slice_shape = ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, limit_indices, strides) - .ConsumeValueOrDie(), - backward_conv->shape())); + .ConsumeValueOrDie(); + CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape)) + << ShapeUtil::HumanString(slice_shape) << " vs " + << ShapeUtil::HumanString(backward_conv_shape); - auto slice = - HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, - start_indices, limit_indices, strides); + HloInstruction* slice = computation->AddInstruction( + HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv, + start_indices, limit_indices, strides)); + HloInstruction* new_tuple = computation->AddInstruction( + HloInstruction::CreateTuple({slice, new_backward_conv_scratch})); VLOG(1) << "Canonicalizing backward input conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " - << slice->ToString(); + << new_tuple->ToString(); - TF_CHECK_OK( - computation->ReplaceWithNewInstruction(backward_conv, std::move(slice))); + TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple)); return true; } @@ -400,18 +392,17 @@ StatusOr PadInsertion::Run(HloModule* module) { bool changed = false; for (HloInstruction* instruction : module->entry_computation()->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kConvolution) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (instruction->opcode() == HloOpcode::kFusion) { - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - changed |= CanonicalizeBackwardFilterConvolution(instruction); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - changed |= CanonicalizeBackwardInputConvolution(instruction); - break; - default: - break; + if (IsCustomCallToDnnConvolution(*instruction)) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index a63affa..5432419 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -461,20 +461,6 @@ HloInstruction* HloComputation::CreateFusionInstruction( return fusion_instruction; } -HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums) { - CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind || - HloInstruction::FusionKind::kConvBackwardInput == fusion_kind); - HloInstruction* root = instructions_to_fuse.front(); - HloInstruction* fusion_instruction = - AddInstruction(HloInstruction::CreateFusionForBackwardConvolution( - root->shape(), fusion_kind, window, conv_dnums, root)); - FuseInstructionsInto(instructions_to_fuse, fusion_instruction); - return fusion_instruction; -} - StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { @@ -577,8 +563,11 @@ Status HloComputation::ReplaceWithNewInstruction( Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction) { - TF_RET_CHECK(ShapeUtil::Compatible(old_instruction->shape(), - new_instruction->shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) + << ShapeUtil::HumanString(old_instruction->shape()) << " vs " + << ShapeUtil::HumanString(new_instruction->shape()); + VLOG(10) << "transformed " << old_instruction->ToString() << " to " << new_instruction->ToString(); // Try to add metadata for HLO instructions that are created to replace diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 6436815..061c59a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -224,15 +224,6 @@ class HloComputation { tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction::FusionKind fusion_kind); - // Creates a fusion instruction that represents a backward convolution. This - // is similar to CreateFusionInstruction but takes window and conv_dnums which - // indicate the window and convolution dimension numbers of the backward - // convolution. - HloInstruction* CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums); - // Create a deep copy of the given instruction and return the instruction // producing the copied result. All instructions performing the copy are added // to the computation. For array-shaped values, this method trivially returns diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index cd54eb7..9cd5a1e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -469,7 +469,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - return Unimplemented("Custom-call is not implemented for HLO cost analysis."); + // We can't do anything sane with CustomCalls, since we don't know what they + // do, and returning an error status will stop iteration over this + // computation, which is probably also not what we want. So just punt and + // return OK. This will cause all of the properties to be reported as 0, + // which is fine. + current_should_compute_bottleneck_time_ = false; + return Status::OK(); } Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a889c35..fac6b43 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -763,16 +763,13 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot and -// backward-conv fusions, since those fusions are really just describing a type -// of dot/conv rather than generating a novel computation. +// We put the fusion kind into the instruction's name for transpose-dot fusions, +// since those fusions are really just describing a type of dot rather than +// generating a novel computation. static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { switch (fusion_kind) { case HloInstruction::FusionKind::kTransposeDot: return "dot_fusion"; - case HloInstruction::FusionKind::kConvBackwardInput: - case HloInstruction::FusionKind::kConvBackwardFilter: - return "conv_fusion"; default: return "fusion"; } @@ -804,18 +801,6 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { return instruction; } -/* static */ std::unique_ptr -HloInstruction::CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) { - std::unique_ptr fusion = - CreateFusion(shape, fusion_kind, fused_root); - fusion->window_ = MakeUnique(window); - fusion->convolution_dimension_numbers_ = - MakeUnique(conv_dnums); - return fusion; -} - void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -2318,7 +2303,7 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - auto conv_category = [&] { + if (opcode() == HloOpcode::kConvolution) { string category = "convolution"; if (window_util::HasBaseDilation(window())) { category += " base-dilated"; @@ -2327,10 +2312,6 @@ string HloInstruction::ToCategory() const { category += " window-dilated"; } return category; - }; - - if (opcode() == HloOpcode::kConvolution) { - return conv_category(); } // Give transpose-dot and backwards-conv fusions the categories "dot" and @@ -2348,9 +2329,6 @@ string HloInstruction::ToCategory() const { return "output fusion"; case FusionKind::kTransposeDot: return "dot"; - case FusionKind::kConvBackwardFilter: - case FusionKind::kConvBackwardInput: - return conv_category(); case FusionKind::kCustom: return "custom fusion"; } @@ -3125,10 +3103,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: return "kTransposeDot"; - case HloInstruction::FusionKind::kConvBackwardFilter: - return "kConvBackwardFilter"; - case HloInstruction::FusionKind::kConvBackwardInput: - return "kConvBackwardInput"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3148,12 +3122,6 @@ StatusOr StringToFusionKind( if (kind_name == "kTransposeDot") { return HloInstruction::FusionKind::kTransposeDot; } - if (kind_name == "kConvBackwardFilter") { - return HloInstruction::FusionKind::kConvBackwardFilter; - } - if (kind_name == "kConvBackwardInput") { - return HloInstruction::FusionKind::kConvBackwardInput; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } @@ -3261,7 +3229,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(output_dims, shape()); + + // A convolution can be represented as a kConvolution HLO or as a CustomCall + // that returns a tuple, the first element of which is the result of the + // convolution. + Shape this_shape = + ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); + append_dims(output_dims, this_shape); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5e89dc7..84b4696 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -162,17 +162,14 @@ class HloPrintOptions { class HloInstruction { public: enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kConvBackwardFilter, // Fused into a backward filter convolution. - kConvBackwardInput, // Fused into a backward input convolution. - - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, // Fused into a loop. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. + kTransposeDot, // Fused into a dot with transposed operands. + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -466,14 +463,6 @@ class HloInstruction { tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation); - // Creates a fusion instruction that represents backward convolution. This is - // similar to CreateFusion, but with extra arguments indicating the window and - // dimemsion mapping of the backward convolution. - static std::unique_ptr CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, - HloInstruction* fused_root); - // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. static std::unique_ptr CreateCall( @@ -1052,13 +1041,23 @@ class HloInstruction { return *padding_config_; } - // Returns data on the dimension numbers used for a convolution - // operation. + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { CHECK(convolution_dimension_numbers_ != nullptr); return *convolution_dimension_numbers_; } + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + FftType fft_type() const { CHECK_EQ(HloOpcode::kFft, opcode_); return fft_type_;