[TF:XLA] Add INTEL MKL_DNN Conv2d method to XLA/CPU backend
authorTony Wang <tonywy@google.com>
Thu, 26 Apr 2018 20:30:15 +0000 (13:30 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 20:33:48 +0000 (13:33 -0700)
The INTEL MKL_DNN provides 32-bit Conv2d method. With INTEL_MKL flag set,
XLA backend emits runtime call to MKL_DNN Conv2d instead of Eigen.

PiperOrigin-RevId: 194445212

tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
tensorflow/compiler/xla/service/cpu/cpu_runtime.h
tensorflow/compiler/xla/service/cpu/ir_emitter.cc
tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h [new file with mode: 0644]
tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc

index 04fda3b..cef4eba 100644 (file)
@@ -169,6 +169,7 @@ cc_library(
         ":orc_jit_memory_mapper",
         ":runtime_fp16",
         ":runtime_conv2d",
+        ":runtime_conv2d_mkl",
         ":runtime_fft",
         ":runtime_fork_join",
         ":runtime_matmul",
@@ -471,6 +472,27 @@ cc_library(
 )
 
 cc_library(
+    name = "runtime_conv2d_mkl",
+    srcs = [
+        "runtime_conv2d_mkl.cc",
+    ],
+    hdrs = ["runtime_conv2d_mkl.h"],
+    copts = runtime_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":runtime_conv2d",
+        ":runtime_single_threaded_conv2d",
+        "//tensorflow/compiler/xla:executable_run_options",
+        "//tensorflow/core:framework_lite",
+        "//tensorflow/core/kernels:eigen_helpers",
+        "//third_party/eigen3",
+    ] + if_mkl([
+        "@mkl_dnn",
+        "//third_party/mkl:intel_binary_blob",
+    ]),
+)
+
+cc_library(
     name = "runtime_fft",
     srcs = [
         "runtime_fft.cc",
index 872b0be..215405f 100644 (file)
@@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenMatMulF32";
 extern const char* const kEigenMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenMatMulF64";
+extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
 extern const char* const kMKLMatMulF32SymbolName =
     "__xla_cpu_runtime_MKLMatMulF32";
 extern const char* const kMKLMatMulF64SymbolName =
index e392e23..1dce6ef 100644 (file)
@@ -44,6 +44,7 @@ namespace runtime {
 extern const char* const kEigenMatMulF16SymbolName;
 extern const char* const kEigenMatMulF32SymbolName;
 extern const char* const kEigenMatMulF64SymbolName;
+extern const char* const kMKLConvF32SymbolName;
 extern const char* const kMKLMatMulF32SymbolName;
 extern const char* const kMKLMatMulF64SymbolName;
 extern const char* const kMKLSingleThreadedMatMulF32SymbolName;
index 0b08ad8..d582b5a 100644 (file)
@@ -854,6 +854,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
   const ConvolutionDimensionNumbers& dnums =
       convolution->convolution_dimension_numbers();
 
+  // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support
+  // different data layouts.
   if (PotentiallyImplementedAsEigenConvolution(*convolution)) {
     const Shape& lhs_shape = lhs->shape();
     const Shape& rhs_shape = rhs->shape();
@@ -942,16 +944,26 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
            int64_type,    int64_type,  int64_type,  int64_type},
           /*isVarArg=*/false);
-      bool multi_threaded_eigen =
+      bool multi_threaded =
           hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+      bool use_mkl_dnn =
+          hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
+
+      // TODO(b/78639006) Singlethread MKL conv2d is not implemented due to the
+      // potential race condition by setting the omp_num_threads.
       const char* fn_name =
           primitive_type == F16
-              ? (multi_threaded_eigen
+              ? (multi_threaded
                      ? runtime::kEigenConvF16SymbolName
                      : runtime::kEigenSingleThreadedConvF16SymbolName)
-              : (multi_threaded_eigen
-                     ? runtime::kEigenConvF32SymbolName
+              : (multi_threaded
+                     ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
+                                    : runtime::kEigenConvF32SymbolName)
                      : runtime::kEigenSingleThreadedConvF32SymbolName);
+      if (!multi_threaded && use_mkl_dnn) {
+        LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
+                        "conv2d function.";
+      }
       llvm::Function* conv_func = llvm::cast<llvm::Function>(
           module_->getOrInsertFunction(fn_name, conv_type));
       conv_func->setCallingConv(llvm::CallingConv::C);
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc
new file mode 100644 (file)
index 0000000..c60580d
--- /dev/null
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
+#include <iostream>
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int64;
+
+#ifdef INTEL_MKL
+#include <omp.h>
+#include "mkldnn.hpp"
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+
+namespace {
+
+// Downcast an int64 to int and check if value is in range.
+int ToInt(int64 input) {
+  int output = static_cast<int>(input);
+  if (static_cast<int64>(output) != input) {
+    std::cerr << "Error occurred in downcasting int64 to int32: Value " << input
+              << " is out-of-range for type int32. \n";
+    exit(1);
+  }
+  return output;
+}
+
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+using mkldnn::engine;
+using mkldnn::memory;
+using mkldnn::padding_kind;
+using mkldnn::primitive;
+using mkldnn::prop_kind;
+using mkldnn::reorder;
+using mkldnn::stream;
+
+template <typename EigenDevice, typename ScalarType>
+void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
+                 ScalarType* rhs, int64 input_batch, int64 input_rows,
+                 int64 input_cols, int64 input_channels, int64 kernel_rows,
+                 int64 kernel_cols, int64 kernel_channels, int64 kernel_filters,
+                 int64 output_rows, int64 output_cols, int64 row_stride,
+                 int64 col_stride, int64 padding_top, int64 padding_bottom,
+                 int64 padding_left, int64 padding_right,
+                 int64 lhs_row_dilation, int64 lhs_col_dilation,
+                 int64 rhs_row_dilation, int64 rhs_col_dilation) {
+  auto cpu_engine = engine(engine::cpu, 0);
+
+  // Create a vector primitive to hold the network.
+  std::vector<primitive> net;
+
+  // Since memory::dims takes int for each dimension, we downcast the int64
+  // values to int using the ToInt function defined above.
+  memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels),
+                                ToInt(input_rows), ToInt(input_cols)};
+  memory::dims conv1_weights_dim = {ToInt(kernel_filters),
+                                    ToInt(kernel_channels), ToInt(kernel_rows),
+                                    ToInt(kernel_cols)};
+  memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters),
+                                ToInt(output_rows), ToInt(output_cols)};
+  memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)};
+  // Note: In MKL_DNN dilation starts from 0.
+  memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1),
+                                ToInt(rhs_col_dilation - 1)};
+  memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)};
+  memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)};
+
+  // Create memory for user data. Input and output data have format of NHWC and
+  // kernel data has format of HWIO.
+  // Note that as a convention in MKL-DNN, the dimensions of the data is always
+  // described in NCHW/IOHW, regardless of the actual layout of the data.
+  auto user_src_memory =
+      memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc},
+              cpu_engine},
+             lhs);
+  auto user_weights_memory = memory(
+      {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio},
+       cpu_engine},
+      rhs);
+  auto user_dst_memory =
+      memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc},
+              cpu_engine},
+             out);
+
+  // Create memory descriptors for convolution data with no specified format for
+  // best performance.
+  auto conv1_src_mem_desc = memory::desc(
+      {conv1_src_dim}, memory::data_type::f32, memory::format::any);
+  auto conv1_weights_mem_desc = memory::desc(
+      {conv1_weights_dim}, memory::data_type::f32, memory::format::any);
+  auto conv1_dst_mem_desc = memory::desc(
+      {conv1_dst_dim}, memory::data_type::f32, memory::format::any);
+
+  // Create a convolution.
+  auto conv1_desc = convolution_forward::desc(
+      prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc,
+      conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates,
+      conv1_padding_l, conv1_padding_r, padding_kind::zero);
+  auto conv1_prim_desc =
+      convolution_forward::primitive_desc(conv1_desc, cpu_engine);
+
+  // Create reorders for data and weights if layout requested by convolution is
+  // different from NCHW/OIHW.
+  auto conv1_src_memory = user_src_memory;
+  if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) !=
+      user_src_memory.get_primitive_desc()) {
+    conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc());
+    net.push_back(reorder(user_src_memory, conv1_src_memory));
+  }
+
+  auto conv1_weights_memory = user_weights_memory;
+  if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) !=
+      user_weights_memory.get_primitive_desc()) {
+    conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc());
+    net.push_back(reorder(user_weights_memory, conv1_weights_memory));
+  }
+
+  // Check if output need layout conversion. If yes, create memory for
+  // intermediate layer of conv1_dst_memory.
+  bool need_output_conversion =
+      (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) !=
+       user_dst_memory.get_primitive_desc());
+  auto conv1_dst_memory = need_output_conversion
+                              ? memory(conv1_prim_desc.dst_primitive_desc())
+                              : user_dst_memory;
+
+  // Create convolution primitive and add it to net.
+  net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory,
+                                    conv1_weights_memory, conv1_dst_memory));
+  if (need_output_conversion) {
+    net.push_back(reorder(conv1_dst_memory, user_dst_memory));
+  }
+  stream(stream::kind::eager).submit(net).wait();
+}
+}  // namespace
+#endif  // INTEL_MKL
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConvF32(
+    const void* run_options_ptr, float* out, float* lhs, float* rhs,
+    int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels,
+    int64 kernel_rows, int64 kernel_cols, int64 kernel_channels,
+    int64 kernel_filters, int64 output_rows, int64 output_cols,
+    int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
+    int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
+    int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
+#ifdef INTEL_MKL
+  // Since MKL_DNN cannot handle transposed convolution, this is handled by
+  // Eigen.
+  if (lhs_row_dilation > 1 || lhs_col_dilation > 1) {
+    __xla_cpu_runtime_EigenConvF32(
+        run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols,
+        input_channels, kernel_rows, kernel_cols, kernel_channels,
+        kernel_filters, output_rows, output_cols, row_stride, col_stride,
+        padding_top, padding_bottom, padding_left, padding_right,
+        lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
+  } else {
+    MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols,
+                input_channels, kernel_rows, kernel_cols, kernel_channels,
+                kernel_filters, output_rows, output_cols, row_stride,
+                col_stride, padding_top, padding_bottom, padding_left,
+                padding_right, lhs_row_dilation, lhs_col_dilation,
+                rhs_row_dilation, rhs_col_dilation);
+  }
+#else
+  std::cerr << "Attempt to call MKL Conv2D runtime library without defining "
+               "INTEL_MKL. Add --config=mkl to build with MKL.";
+  exit(1);
+#endif  // INTEL_MKL
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h
new file mode 100644 (file)
index 0000000..b239e71
--- /dev/null
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_
+
+#include <iostream>
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+extern void __xla_cpu_runtime_MKLConvF32(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+    float* lhs, float* rhs, tensorflow::int64 input_batch,
+    tensorflow::int64 input_rows, tensorflow::int64 input_cols,
+    tensorflow::int64 input_channels, tensorflow::int64 kernel_rows,
+    tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels,
+    tensorflow::int64 kernel_filters, tensorflow::int64 output_rows,
+    tensorflow::int64 output_cols, tensorflow::int64 row_stride,
+    tensorflow::int64 col_stride, tensorflow::int64 padding_top,
+    tensorflow::int64 padding_bottom, tensorflow::int64 padding_left,
+    tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation,
+    tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation,
+    tensorflow::int64 rhs_col_dilation);
+}
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_
index b7ce5bb..ff6f0a9 100644 (file)
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
@@ -178,6 +179,7 @@ bool RegisterKnownJITSymbols() {
 
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
+  REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);