Add a heuristic for picking NHWC layouts for (V100, fp16) convolutions.
author	A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 May 2018 00:48:21 +0000 (17:48 -0700)
committer	TensorFlower Gardener <gardener@tensorflow.org>
Fri, 25 May 2018 00:50:57 +0000 (17:50 -0700)
Also move AlgorithmPicker after layout assignment, since
cudnn_convolution_runner now returns failures on invalid input layouts.

Also add a backend debug option to toggle the layout heuristic. By default
it keeps the old behavior (all NCHW).
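
The heuristic can be enabled by passing the backend extra option, as the new
convolution_test_gpu_alternative_layout test target does:

  --xla_backend_extra_options=xla_gpu_experimental_conv_use_layout_heuristic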

PiperOrigin-RevId: 197983747

13 files changed:
tensorflow/compiler/xla/layout_util.cc
tensorflow/compiler/xla/layout_util.h
tensorflow/compiler/xla/service/gpu/BUILD
tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
tensorflow/compiler/xla/service/gpu/gpu_options.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/gpu_options.h [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/stream_executor_util.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/stream_executor_util.h [new file with mode: 0644]
tensorflow/compiler/xla/tests/BUILD

index a76fdcd..89cafa1 100644 (file)
@@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer(
   return layout;
 }
 
+/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
+    tensorflow::gtl::ArraySlice<int64> major_to_minor) {
+  Layout layout;
+  layout.set_format(DENSE);
+  for (int i = major_to_minor.size() - 1; i >= 0; i--) {
+    layout.add_minor_to_major(major_to_minor[i]);
+  }
+  return layout;
+}
+
 /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
   Layout layout;
   layout.set_format(SPARSE);
index d3d6a2c..739bbe7 100644 (file)
@@ -36,6 +36,10 @@ class LayoutUtil {
   // convenience function for protobuf construction.)
   static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
 
+  // Similar to MakeLayout, but takes indices in major-to-minor order.
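+  // For example, MakeLayoutFromMajorToMinor({a, b, c}) is equivalent to
+  // MakeLayout({c, b, a}).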
+  static Layout MakeLayoutFromMajorToMinor(
+      tensorflow::gtl::ArraySlice<int64> major_to_minor);
+
   // Creates a sparse layout with the given maximum number of elements. (This is
   // a convenience function for protobuf construction.)
   static Layout MakeSparseLayout(int64 max_sparse_elements);
index aafb61b..ffb1af2 100644 (file)
@@ -338,6 +338,7 @@ cc_library(
     srcs = ["cudnn_convolution_runner.cc"],
     hdrs = ["cudnn_convolution_runner.h"],
     deps = [
+        ":stream_executor_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:status_macros",
@@ -590,14 +591,18 @@ cc_library(
     srcs = ["gpu_layout_assignment.cc"],
     hdrs = ["gpu_layout_assignment.h"],
     deps = [
+        ":gpu_options",
         ":ir_emission_utils",
+        ":stream_executor_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:computation_layout",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:layout_assignment",
         "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
     ],
 )
 
@@ -694,6 +699,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gpu_options",
+    srcs = ["gpu_options.cc"],
+    hdrs = ["gpu_options.h"],
+    deps = [
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+cc_library(
+    name = "stream_executor_util",
+    srcs = ["stream_executor_util.cc"],
+    hdrs = ["stream_executor_util.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
 tf_cc_test(
     name = "gpu_hlo_support_checker_test",
     srcs = ["gpu_hlo_support_checker_test.cc"],
index 10b4c3d..0645fbb 100644 (file)
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -113,8 +115,17 @@ Status RunCudnnConvolution(
 
   // cuDNN's convolution APIs support the BDYX layout for activations/output and
   // the OIYX layout for weights.
+  DataLayout input_dl;
+  FilterLayout filter_dl;
+  DataLayout output_dl;
+
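+  // Translate the XLA layouts into cuDNN layout enums. Layouts that cuDNN
+  // cannot handle make this call fail, which is how invalid input layouts
+  // are reported to the caller.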
+  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
+                      XlaConvLayoutsToStreamExecutorLayouts(
+                          dnums, input_shape.layout(), filter_shape.layout(),
+                          output_shape.layout()));
+
   BatchDescriptor input_descriptor(effective_num_dimensions);
-  input_descriptor.set_layout(DataLayout::kBatchDepthYX)
+  input_descriptor.set_layout(input_dl)
       .set_feature_map_count(
           input_shape.dimensions(dnums.input_feature_dimension()))
       .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
@@ -126,7 +137,7 @@ Status RunCudnnConvolution(
   }
 
   FilterDescriptor filter_descriptor(effective_num_dimensions);
-  filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
+  filter_descriptor.set_layout(filter_dl)
       .set_input_feature_map_count(
           filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
       .set_output_feature_map_count(
@@ -149,7 +160,7 @@ Status RunCudnnConvolution(
   }
 
   BatchDescriptor output_descriptor(effective_num_dimensions);
-  output_descriptor.set_layout(DataLayout::kBatchDepthYX)
+  output_descriptor.set_layout(output_dl)
       .set_feature_map_count(
           output_shape.dimensions(dnums.output_feature_dimension()))
       .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
index 1445684..5ef422c 100644 (file)
@@ -202,18 +202,28 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     pipeline.AddInvariantChecker<HloVerifier>();
     pipeline.AddPass<CudnnConvolutionRewriter>();
     pipeline.AddPass<PadInsertion>();
+    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+  }
+
+  {
+    HloPassPipeline pipeline("layout_assignment");
+    pipeline.AddPass<GpuLayoutAssignment>(
+        hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+
+    // The LayoutAssignment pass may leave behind kCopy instructions which are
+    // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+    pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+        /*is_layout_sensitive=*/true,
+        /*valid_bitcast_callback=*/[](const Shape&, const Shape&) {
+          return true;
+        });
 
     // Choose the fastest algorithm for each conv.
     //
-    // In theory doing this here is way too early: It needs to happen after
-    // layout assignment, because the layout of the inputs/outputs affects the
-    // speed of the conv.  But currently we only allow only one input/output
-    // layout when calling cudnn, so there's no ambiguity.
-    //
-    // We pick the algorithm at this early stage so we can generate better HLO.
-    // After CudnnConvolutionRewriter, our convolutions are CustomCalls which
-    // return a tuple (conv_result, scratch_memory), and the each conv uses 0
-    // bytes of scratch:
+    // We pick the algorithm before fusion so we can generate better HLO. After
+    // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a
+    // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
+    // scratch:
     //
     //   customcall = (f32[...], f32[0])
     //   return gte(customcall, 0)
@@ -229,35 +239,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
    // The new tuple and gte instructions can then be simplified away, because
     // nobody is expected to use the scratch value.
     //
-    // However, if we were to run CudnnConvolutionAlgorithmPicker after layout
-    // assignment, fusion would already have run, and the gte(customcall, 0)
-    // would probably already be into a fusion node.  We can't simplify across
-    // HloComputation boundaries, so in this case we wouldn't be able to
-    // simplify away the new_tuple bits.
-    //
-    // We'll need to revisit this if we ever allow multiple layouts for the
-    // inputs/outputs of a cudnn convolution.
+    // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion,
+    // the gte(customcall, 0) would probably already be inside a fusion node. We
+    // can't simplify across HloComputation boundaries, so in this case we
+    // wouldn't be able to simplify away the new_tuple bits.
     pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
                                                       device_allocator);
     // Clean up new_tuple described above.
     pipeline.AddPass<TupleSimplifier>();
-    pipeline.AddPass<HloDCE>();
-
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
-  {
-    HloPassPipeline pipeline("layout_assignment");
-    pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_device_entry_computation_layout());
 
-    // The LayoutAssignment pass may leave behind kCopy instructions which are
-    // duplicate or NOPs, so remove them with algebraic simplification and CSE.
-    pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
-        /*is_layout_sensitive=*/true,
-        /*valid_bitcast_callback=*/[](const Shape&, const Shape&) {
-          return true;
-        });
     pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
index 89f1e62..1784577 100644 (file)
@@ -18,31 +18,72 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
 namespace gpu {
 
-// cuDNN convolutions are called with specific layouts on the input, output,
-// and filter:
-//
-//   input: DataLayout::kBatchDepthYX
-//   output: DataLayout::kBatchDepthYX
-//   filter: FilterLayout::kOutputInputYX
-//
-// The order dimensions in the constant name is major-to-minor (eg, the
-// most-major dimension of the input is batch, most-minor is X). The
-// specific dimension numbers these named dimensions correspond to is
-// determined by the ConvolutionDimensionNumbers argument. Y is spatial
-// dimension 0, and X is spatial dimension 1.
-//
-// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls.
-static Status AddBackendConstraintsToDnnConvCustomCall(
+using stream_executor::dnn::DataLayout;
+using stream_executor::dnn::FilterLayout;
+
+static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
+  int major, minor;
+  CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
+                                                                       &minor));
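+  // Volta GPUs (e.g. V100) report compute capability 7.x.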
+  return major >= 7;
+}
+
+// Returns (input, filter, output) layouts.
+static std::tuple<DataLayout, FilterLayout, DataLayout>
+HeuristicLayoutAssignment(const HloInstruction* instr,
+                          stream_executor::StreamExecutor* stream_executor) {
+  // DataLayout and FilterLayout use weird enum names. Translations:
+  //   N <=> Batch or Output
+  //   C <=> Depth or Input
+  //   H <=> Y
+  //   W <=> X
+  //
+  // Therefore kOutputInputYX and kBatchDepthYX mean NCHW, while kOutputYXInput
+  // and kBatchYXDepth mean NHWC.
+
+  // As of today, our empirical evidence is that cudnn 7.0 is faster on V100
+  // with fp16 when using mostly-NHWC layouts. The heuristic may change as new
+  // cudnn versions and new hardware are released.
+  if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
+        IsVoltaOrLater(*stream_executor))) {
+    return std::make_tuple(DataLayout::kBatchDepthYX,
+                           FilterLayout::kOutputInputYX,
+                           DataLayout::kBatchDepthYX);
+  }
+  VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
+  // For BackwardInput convolutions with stride, full NHWC layouts run
+  // significantly slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
+  //
+  // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
+  // NHWC).
+  if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
+      window_util::HasStride(instr->window())) {
+    return std::make_tuple(DataLayout::kBatchYXDepth,
+                           FilterLayout::kOutputInputYX,
+                           DataLayout::kBatchDepthYX);
+  }
+  return std::make_tuple(DataLayout::kBatchYXDepth,
+                         FilterLayout::kOutputYXInput,
+                         DataLayout::kBatchYXDepth);
+}
+
+// Adds layout constraints to the cudnn custom-call instruction. The layout
+// constraints are represented in terms of minor_to_major fields of both
+// operands and the output shape. Depending on the underlying algorithm, one of
+// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
+Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
     HloInstruction* instr, LayoutConstraints* constraints) {
   CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
   Shape input_shape;
@@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall(
                << instr->custom_call_target();
   }
 
-  // Construct minor-to-major dimension orders for operands and result.
-  // cuDNN's convolution APIs support the BDYX layout for activations/output
-  // and the OIYX layout for weights.
-  // TODO(b/29399649): Be more flexible about handling layouts of cuDNN
-  // calls after we switch to cuDNN v5.
-  const ConvolutionDimensionNumbers& dimension_numbers =
-      instr->convolution_dimension_numbers();
-  std::vector<int64> input_layout;
-  for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0;
-       --i) {
-    input_layout.push_back(dimension_numbers.input_spatial_dimensions(i));
-  }
-  input_layout.push_back(dimension_numbers.input_feature_dimension());
-  input_layout.push_back(dimension_numbers.input_batch_dimension());
-  *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout);
-
-  std::vector<int64> filter_layout;
-  for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0;
-       --i) {
-    filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i));
-  }
-  filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension());
-  filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension());
-  *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout);
-
-  std::vector<int64> output_layout;
-  for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0;
-       --i) {
-    output_layout.push_back(dimension_numbers.output_spatial_dimensions(i));
+  {
+    DataLayout input;
+    FilterLayout filter;
+    DataLayout output;
+    if (ConvUseLayoutHeuristic(instr->GetModule()->config())) {
+      std::tie(input, filter, output) =
+          HeuristicLayoutAssignment(instr, stream_executor_);
+    } else {
+      input = DataLayout::kBatchDepthYX;
+      filter = FilterLayout::kOutputInputYX;
+      output = DataLayout::kBatchDepthYX;
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
+                 *output_shape.mutable_layout()),
+        StreamExecutorConvLayoutsToXlaLayouts(
+            instr->convolution_dimension_numbers(), input, filter, output));
   }
-  output_layout.push_back(dimension_numbers.output_feature_dimension());
-  output_layout.push_back(dimension_numbers.output_batch_dimension());
-  *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout);
 
   // The custom call returns a tuple of (actual_result, scratch_buffer);
   // call_result_buf is the logical buffer for actual_result, the thing that
index 86a3a71..ce24af1 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/layout_assignment.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
 
 namespace xla {
 namespace gpu {
@@ -27,8 +28,10 @@ namespace gpu {
 // layout constraints for operands and results of library calls.
 class GpuLayoutAssignment : public LayoutAssignment {
  public:
-  explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
-      : LayoutAssignment(entry_computation_layout) {}
+  explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+                               se::StreamExecutor* stream_executor)
+      : LayoutAssignment(entry_computation_layout),
+        stream_executor_(stream_executor) {}
   ~GpuLayoutAssignment() override {}
 
  protected:
@@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment {
       LayoutConstraints* constraints) override;
   bool CustomCallRequiresMajorFirstLayout(
       const HloInstruction* instruction) override;
+
+ private:
+  Status AddBackendConstraintsToDnnConvCustomCall(
+      HloInstruction* instr, LayoutConstraints* constraints);
+
+  se::StreamExecutor* stream_executor_;
 };
 
 }  // namespace gpu
index 4c45d2e..e48165c 100644 (file)
@@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
         *computation_layout.mutable_result_layout() =
             ShapeLayout(result_shape_with_layout);
 
-        GpuLayoutAssignment layout_assignment(&computation_layout);
+        GpuLayoutAssignment layout_assignment(
+            &computation_layout, backend().default_stream_executor());
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
         *computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
       }
 
-      GpuLayoutAssignment layout_assignment(&computation_layout);
+      GpuLayoutAssignment layout_assignment(
+          &computation_layout, backend().default_stream_executor());
       EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
       // The first operand to batchnorm should have the same layout as the
@@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
                 {result_shape, offset_scale_shape, offset_scale_shape}));
       }
 
-      GpuLayoutAssignment layout_assignment(&computation_layout);
+      GpuLayoutAssignment layout_assignment(
+          &computation_layout, backend().default_stream_executor());
       EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
       // The first operand to batchnorm should have the same layout as the
@@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
                   {result_shape, scale_shape, scale_shape}));
         }
 
-        GpuLayoutAssignment layout_assignment(&computation_layout);
+        GpuLayoutAssignment layout_assignment(
+            &computation_layout, backend().default_stream_executor());
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         // The first and fourth operands to the batchnorm call should have the
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc
new file mode 100644 (file)
index 0000000..174aaf1
--- /dev/null
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+namespace gpu {
+
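+// The option is passed in via --xla_backend_extra_options; see the
+// convolution_test_gpu_alternative_layout target in tests/BUILD for an
+// example.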
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config) {
+  return config.debug_options().xla_backend_extra_options().count(
+      "xla_gpu_experimental_conv_use_layout_heuristic");
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h
new file mode 100644 (file)
index 0000000..498d4a9
--- /dev/null
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+// Helper functions for querying options that are specific to the GPU backend.
+
+namespace xla {
+namespace gpu {
+
+// Returns true if we should use heuristics to assign convolution layouts, as
+// opposed to always assigning NCHW.
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config);
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
new file mode 100644 (file)
index 0000000..a50ddf6
--- /dev/null
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+
+namespace xla {
+namespace gpu {
+
+using stream_executor::dnn::DataLayout;
+using stream_executor::dnn::DataLayoutString;
+using stream_executor::dnn::FilterLayout;
+using stream_executor::dnn::FilterLayoutString;
+
+StatusOr<std::tuple<Layout, Layout, Layout>>
+StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
+                                      DataLayout input, FilterLayout filter,
+                                      DataLayout output) {
+  std::vector<int64> input_layout;
+  switch (input) {
+    case DataLayout::kBatchDepthYX:
+      input_layout.push_back(dnums.input_batch_dimension());
+      input_layout.push_back(dnums.input_feature_dimension());
+      input_layout.insert(input_layout.end(),
+                          dnums.input_spatial_dimensions().begin(),
+                          dnums.input_spatial_dimensions().end());
+      break;
+    case DataLayout::kBatchYXDepth:
+      input_layout.push_back(dnums.input_batch_dimension());
+      input_layout.insert(input_layout.end(),
+                          dnums.input_spatial_dimensions().begin(),
+                          dnums.input_spatial_dimensions().end());
+      input_layout.push_back(dnums.input_feature_dimension());
+      break;
+    default:
+      return tensorflow::errors::Internal("Invalid input layout: ",
+                                          DataLayoutString(input));
+  }
+
+  std::vector<int64> filter_layout;
+  switch (filter) {
+    case FilterLayout::kOutputInputYX:
+      filter_layout.push_back(dnums.kernel_output_feature_dimension());
+      filter_layout.push_back(dnums.kernel_input_feature_dimension());
+      filter_layout.insert(filter_layout.end(),
+                           dnums.kernel_spatial_dimensions().begin(),
+                           dnums.kernel_spatial_dimensions().end());
+      break;
+    case FilterLayout::kOutputYXInput:
+      filter_layout.push_back(dnums.kernel_output_feature_dimension());
+      filter_layout.insert(filter_layout.end(),
+                           dnums.kernel_spatial_dimensions().begin(),
+                           dnums.kernel_spatial_dimensions().end());
+      filter_layout.push_back(dnums.kernel_input_feature_dimension());
+      break;
+    default:
+      return tensorflow::errors::Internal("Invalid filter layout: ",
+                                          FilterLayoutString(filter));
+  }
+
+  std::vector<int64> output_layout;
+  switch (output) {
+    case DataLayout::kBatchDepthYX:
+      output_layout.push_back(dnums.output_batch_dimension());
+      output_layout.push_back(dnums.output_feature_dimension());
+      output_layout.insert(output_layout.end(),
+                           dnums.output_spatial_dimensions().begin(),
+                           dnums.output_spatial_dimensions().end());
+      break;
+    case DataLayout::kBatchYXDepth:
+      output_layout.push_back(dnums.output_batch_dimension());
+      output_layout.insert(output_layout.end(),
+                           dnums.output_spatial_dimensions().begin(),
+                           dnums.output_spatial_dimensions().end());
+      output_layout.push_back(dnums.output_feature_dimension());
+      break;
+    default:
+      return tensorflow::errors::Internal("Invalid output layout: ",
+                                          DataLayoutString(output));
+  }
+
+  return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
+                         LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
+                         LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
+}
+
+StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
+XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
+                                      const Layout& input, const Layout& filter,
+                                      const Layout& output) {
+  Layout nchw_input, nchw_filter, nchw_output;
+  std::tie(nchw_input, nchw_filter, nchw_output) =
+      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
+                                            FilterLayout::kOutputInputYX,
+                                            DataLayout::kBatchDepthYX)
+          .ConsumeValueOrDie();
+
+  Layout nhwc_input, nhwc_filter, nhwc_output;
+  std::tie(nhwc_input, nhwc_filter, nhwc_output) =
+      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
+                                            FilterLayout::kOutputYXInput,
+                                            DataLayout::kBatchYXDepth)
+          .ConsumeValueOrDie();
+
+  DataLayout input_layout;
+  if (LayoutUtil::Equal(input, nchw_input)) {
+    input_layout = DataLayout::kBatchDepthYX;
+  } else if (LayoutUtil::Equal(input, nhwc_input)) {
+    input_layout = DataLayout::kBatchYXDepth;
+  } else {
+    return tensorflow::errors::Internal("Invalid input layout: ",
+                                        input.ShortDebugString());
+  }
+
+  FilterLayout filter_layout;
+  if (LayoutUtil::Equal(filter, nchw_filter)) {
+    filter_layout = FilterLayout::kOutputInputYX;
+  } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
+    filter_layout = FilterLayout::kOutputYXInput;
+  } else {
+    return tensorflow::errors::Internal("Invalid filter layout: ",
+                                        filter.ShortDebugString());
+  }
+
+  DataLayout output_layout;
+  if (LayoutUtil::Equal(output, nchw_output)) {
+    output_layout = DataLayout::kBatchDepthYX;
+  } else if (LayoutUtil::Equal(output, nhwc_output)) {
+    output_layout = DataLayout::kBatchYXDepth;
+  } else {
+    return tensorflow::errors::Internal("Invalid output layout: ",
+                                        output.ShortDebugString());
+  }
+
+  return std::make_tuple(input_layout, filter_layout, output_layout);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
new file mode 100644 (file)
index 0000000..8218f4f
--- /dev/null
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+// Helper functions for interacting with StreamExecutor.
+
+namespace xla {
+namespace gpu {
+
+// Returns (input, filter, output) XLA Layout protos given the StreamExecutor
+// layouts.
+StatusOr<std::tuple<Layout, Layout, Layout>>
+StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
+                                      stream_executor::dnn::DataLayout input,
+                                      stream_executor::dnn::FilterLayout filter,
+                                      stream_executor::dnn::DataLayout output);
+
+// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts.
+StatusOr<std::tuple<stream_executor::dnn::DataLayout,
+                    stream_executor::dnn::FilterLayout,
+                    stream_executor::dnn::DataLayout>>
+XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
+                                      const Layout& input, const Layout& filter,
+                                      const Layout& output);
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
index fd54ac7..1a12fd0 100644 (file)
@@ -776,30 +776,42 @@ xla_test(
     ],
 )
 
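+# Dependencies shared by convolution_test and the GPU alternative-layout
+# variant below.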
+CONVOLUTION_TEST_DEPS = [
+    "//tensorflow/compiler/xla:array2d",
+    "//tensorflow/compiler/xla:array4d",
+    "//tensorflow/compiler/xla:literal_util",
+    "//tensorflow/compiler/xla:reference_util",
+    "//tensorflow/compiler/xla:shape_util",
+    "//tensorflow/compiler/xla:statusor",
+    "//tensorflow/compiler/xla:util",
+    "//tensorflow/compiler/xla:xla_data_proto",
+    "//tensorflow/compiler/xla/client:global_data",
+    "//tensorflow/compiler/xla/client:local_client",
+    "//tensorflow/compiler/xla/client:padding",
+    "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+    "//tensorflow/compiler/xla/tests:client_library_test_base",
+    "//tensorflow/compiler/xla/tests:literal_test_util",
+    "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+    "//tensorflow/core:lib",
+    "//tensorflow/core:test",
+]
+
 xla_test(
     name = "convolution_test",
     timeout = "long",
     srcs = ["convolution_test.cc"],
     shard_count = 25,
-    deps = [
-        "//tensorflow/compiler/xla:array2d",
-        "//tensorflow/compiler/xla:array4d",
-        "//tensorflow/compiler/xla:literal_util",
-        "//tensorflow/compiler/xla:reference_util",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:global_data",
-        "//tensorflow/compiler/xla/client:local_client",
-        "//tensorflow/compiler/xla/client:padding",
-        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
-        "//tensorflow/compiler/xla/tests:client_library_test_base",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:test",
-    ],
+    deps = CONVOLUTION_TEST_DEPS,
+)
+
+xla_test(
+    name = "convolution_test_gpu_alternative_layout",
+    timeout = "long",
+    srcs = ["convolution_test.cc"],
+    backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_use_layout_heuristic"]},
+    backends = ["gpu"],
+    shard_count = 25,
+    deps = CONVOLUTION_TEST_DEPS,
 )
 
 xla_test(