Automated g4 rollback of changelist 191605505
author     Tony Wang <tonywy@google.com>
           Fri, 6 Apr 2018 00:24:33 +0000 (17:24 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Fri, 6 Apr 2018 00:26:58 +0000 (17:26 -0700)
PiperOrigin-RevId: 191824447

tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
tensorflow/compiler/xla/service/cpu/cpu_runtime.h
tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h [new file with mode: 0644]
tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
tensorflow/compiler/xla/xla.proto

diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index c8ed3e3..f037663 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -40,6 +40,9 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
   flags->set_xla_cpu_multi_thread_eigen(true);
   flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
   flags->set_xla_eliminate_hlo_implicit_broadcast(true);
+#ifdef INTEL_MKL
+  flags->set_xla_cpu_use_mkl_dnn(true);
+#endif  // INTEL_MKL
 
   // Set cudnn batchnorm off by default; it does not provide a performance win
   // on average.
@@ -288,6 +291,10 @@ void AllocateFlags() {
           flag_values->xla_gpu_use_cudnn_batchnorm(),
           "Allows the GPU backend to implement batchnorm HLOs using cudnn, "
           "rather than expanding them to a soup of HLOs."),
+      tensorflow::Flag("xla_cpu_use_mkl_dnn",
+                       bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
+                       flag_values->xla_cpu_use_mkl_dnn(),
+                       "Generate calls to MKL-DNN in the CPU backend."),
   });
   ParseFlagsFromEnv(*flag_objects);
 }
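
Note: since AllocateFlags() ends by calling ParseFlagsFromEnv(), the new option is also reachable at run time through the environment variable that function reads (XLA_FLAGS); for example, XLA_FLAGS=--xla_cpu_use_mkl_dnn=false turns the MKL default above back off without a rebuild.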
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 966e2d0..246b802 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -18,6 +18,10 @@ load(":build_defs.bzl", "runtime_copts")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
+load(
+    "//third_party/mkl:build_defs.bzl",
+    "if_mkl",
+)
 
 # Filegroup used to collect source files for dependency checking.
 filegroup(
@@ -170,6 +174,7 @@ cc_library(
         ":runtime_fft",
         ":runtime_fork_join",
         ":runtime_matmul",
+        ":runtime_matmul_mkl",
         ":runtime_single_threaded_conv2d",
         ":runtime_single_threaded_matmul",
         "@llvm//:execution_engine",
@@ -539,6 +544,22 @@ cc_library(
 )
 
 cc_library(
+    name = "runtime_matmul_mkl",
+    srcs = ["runtime_matmul_mkl.cc"],
+    hdrs = ["runtime_matmul_mkl.h"],
+    copts = runtime_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/xla:executable_run_options",
+        "//tensorflow/core:framework_lite",
+        "//third_party/eigen3",
+    ] + if_mkl([
+        "//third_party/mkl:intel_binary_blob",
+        "@mkl_dnn",
+    ]),
+)
+
+cc_library(
     name = "runtime_single_threaded_conv2d",
     srcs = [
         "runtime_conv2d_impl.h",
@@ -584,10 +605,12 @@ cc_library(
 tf_cc_test(
     name = "cpu_runtime_test",
     srcs = ["cpu_runtime_test.cc"],
+    shard_count = 10,
     tags = ["optonly"],
     deps = [
         ":cpu_runtime",
         ":runtime_matmul",
+        ":runtime_matmul_mkl",
         ":runtime_single_threaded_matmul",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 9a3bd68..872b0be 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -37,6 +37,14 @@ extern const char* const kEigenMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenMatMulF32";
 extern const char* const kEigenMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenMatMulF64";
+extern const char* const kMKLMatMulF32SymbolName =
+    "__xla_cpu_runtime_MKLMatMulF32";
+extern const char* const kMKLMatMulF64SymbolName =
+    "__xla_cpu_runtime_MKLMatMulF64";
+extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
+    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
+extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
+    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
 extern const char* const kEigenConvF16SymbolName =
     "__xla_cpu_runtime_EigenConvF16";
 extern const char* const kEigenConvF32SymbolName =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index e61d6ea..e392e23 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -44,6 +44,10 @@ namespace runtime {
 extern const char* const kEigenMatMulF16SymbolName;
 extern const char* const kEigenMatMulF32SymbolName;
 extern const char* const kEigenMatMulF64SymbolName;
+extern const char* const kMKLMatMulF32SymbolName;
+extern const char* const kMKLMatMulF64SymbolName;
+extern const char* const kMKLSingleThreadedMatMulF32SymbolName;
+extern const char* const kMKLSingleThreadedMatMulF64SymbolName;
 extern const char* const kEigenConvF16SymbolName;
 extern const char* const kEigenConvF32SymbolName;
 extern const char* const kEigenFftSymbolName;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index f385829..2ac950e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
@@ -130,25 +131,23 @@ MatMulShape MatMulShapes[] = {
 // * transpose_lhs
 // * transpose_rhs
 // * single_threaded
-using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
+using MatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
 
-class EigenMatMulTest
-    : public CpuRuntimeTest,
-      public ::testing::WithParamInterface<EigenMatMulTestParam> {
+class EigenMatMulTest : public CpuRuntimeTest,
+                        public ::testing::WithParamInterface<MatMulTestParam> {
  public:
-  static string Name(
-      const ::testing::TestParamInfo<EigenMatMulTestParam>& info) {
+  static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) {
     MatMulShape shape = std::get<0>(info.param);
     bool transpose_lhs = std::get<1>(info.param);
     bool transpose_rhs = std::get<2>(info.param);
     bool single_threaded = std::get<3>(info.param);
 
     return tensorflow::strings::Printf(
-        "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
+        "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
         transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
         single_threaded ? "single" : "multi");
   }
-};  // namespace xla
+};
 
 TEST_P(EigenMatMulTest, DoIt) {
   MatMulShape shape = std::get<0>(GetParam());
@@ -169,5 +168,74 @@ INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest,
                                            ::testing::Bool()),
                         EigenMatMulTest::Name);
 
+#ifdef INTEL_MKL
+class MKLMatMulTest : public CpuRuntimeTest,
+                      public ::testing::WithParamInterface<MatMulTestParam> {
+ public:
+  static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) {
+    MatMulShape shape = std::get<0>(info.param);
+    bool transpose_lhs = std::get<1>(info.param);
+    bool transpose_rhs = std::get<2>(info.param);
+    bool single_threaded = std::get<3>(info.param);
+
+    return tensorflow::strings::Printf(
+        "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
+        transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
+        single_threaded ? "single" : "multi");
+  }
+};
+
+std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a,
+                                                  const Array2D<float>& b,
+                                                  bool transpose_lhs,
+                                                  bool transpose_rhs,
+                                                  bool single_threaded) {
+  CHECK_EQ(a.width(), b.height());
+  int64 m = a.height();
+  int64 n = b.width();
+  int64 k = a.width();
+
+  // The MKL matmul runtime functions expect their operands in column-major
+  // order, while Array2D is row-major. Create transposes of a and b: the
+  // 'data' buffer of each transposed array holds the original matrix in
+  // column-major order.
+  auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs);
+  auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs);
+
+  // Since we're going to transpose c before returning it, swap the order of the
+  // dimension sizes to ensure the returned array is properly dimensioned.
+  auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+  if (single_threaded) {
+    __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
+        nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
+        m, n, k, transpose_lhs, transpose_rhs);
+  } else {
+    __xla_cpu_runtime_MKLMatMulF32(nullptr, c_transpose->data(),
+                                   a_transpose->data(), b_transpose->data(), m,
+                                   n, k, transpose_lhs, transpose_rhs);
+  }
+  return MaybeTransposeArray2D(*c_transpose, true);
+}
+
+TEST_P(MKLMatMulTest, DoIt) {
+  MatMulShape shape = std::get<0>(GetParam());
+  bool transpose_lhs = std::get<1>(GetParam());
+  bool transpose_rhs = std::get<2>(GetParam());
+  bool single_threaded = std::get<3>(GetParam());
+
+  auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k);
+  auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n);
+  auto c =
+      MKLMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs, single_threaded);
+  CheckMatrixMultiply(*a, *b, *c);
+}
+
+INSTANTIATE_TEST_CASE_P(MKLMatMulTestInstantiation, MKLMatMulTest,
+                        ::testing::Combine(::testing::ValuesIn(MatMulShapes),
+                                           ::testing::Bool(), ::testing::Bool(),
+                                           ::testing::Bool()),
+                        MKLMatMulTest::Name);
+#endif  // INTEL_MKL
+
 }  // namespace
 }  // namespace xla
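
Note: the transpose juggling in MKLMatrixMultiply() above relies on the fact that a row-major buffer reinterpreted as column-major is exactly the transposed matrix. A minimal self-contained sketch of that identity (plain C++, independent of XLA; the names are illustrative):

#include <cassert>

int main() {
  // A 2x3 matrix stored row-major:
  //   [1 2 3]
  //   [4 5 6]
  const float rm[6] = {1, 2, 3, 4, 5, 6};
  // Reading the same buffer as a column-major 3x2 matrix M, where
  // M(r, c) = rm[c * 3 + r], gives
  //   [1 4]
  //   [2 5]
  //   [3 6]
  // i.e. the transpose of the original matrix.
  assert(rm[1 * 3 + 0] == 4.0f);  // M(0, 1) == original(1, 0)
  assert(rm[0 * 3 + 2] == 3.0f);  // M(2, 0) == original(0, 2)
  return 0;
}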
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 8b1e20d..29afd8e 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -918,28 +918,35 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
   // The two transpose_... parameters are actually booleans, but we use int32
   // to avoid target-dependent calling convention details.
 
-  bool multi_threaded_eigen =
+  bool multi_threaded =
       hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+  bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
   PrimitiveType type = target_array_.GetShape().element_type();
   llvm::Type* float_type;
   const char* fn_name;
   switch (type) {
     case F16:
-      fn_name = multi_threaded_eigen
+      fn_name = multi_threaded
                     ? runtime::kEigenMatMulF16SymbolName
                     : runtime::kEigenSingleThreadedMatMulF16SymbolName;
       float_type = ir_builder_->getHalfTy();
       break;
     case F32:
-      fn_name = multi_threaded_eigen
-                    ? runtime::kEigenMatMulF32SymbolName
-                    : runtime::kEigenSingleThreadedMatMulF32SymbolName;
+      fn_name = multi_threaded
+                    ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
+                                   : runtime::kEigenMatMulF32SymbolName)
+                    : (use_mkl_dnn
+                           ? runtime::kMKLSingleThreadedMatMulF32SymbolName
+                           : runtime::kEigenSingleThreadedMatMulF32SymbolName);
       float_type = ir_builder_->getFloatTy();
       break;
     case F64:
-      fn_name = multi_threaded_eigen
-                    ? runtime::kEigenMatMulF64SymbolName
-                    : runtime::kEigenSingleThreadedMatMulF64SymbolName;
+      fn_name = multi_threaded
+                    ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
+                                   : runtime::kEigenMatMulF64SymbolName)
+                    : (use_mkl_dnn
+                           ? runtime::kMKLSingleThreadedMatMulF64SymbolName
+                           : runtime::kEigenSingleThreadedMatMulF64SymbolName);
       float_type = ir_builder_->getDoubleTy();
       break;
     default:
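
Note: the nested conditionals above encode a two-axis dispatch (multi- vs. single-threaded, MKL vs. Eigen) for each element type; F16 keeps the Eigen-only selection because this change adds no MKL F16 kernel. For readability, the same F32 selection written as a standalone helper — a sketch only, using the symbol constants from cpu_runtime.h (F64 is analogous):

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

const char* SelectF32MatMulSymbol(bool multi_threaded, bool use_mkl_dnn) {
  namespace rt = xla::cpu::runtime;
  if (use_mkl_dnn) {
    return multi_threaded ? rt::kMKLMatMulF32SymbolName
                          : rt::kMKLSingleThreadedMatMulF32SymbolName;
  }
  return multi_threaded ? rt::kEigenMatMulF32SymbolName
                        : rt::kEigenSingleThreadedMatMulF32SymbolName;
}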
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
new file mode 100644
index 0000000..92da5f7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
@@ -0,0 +1,128 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
+#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
+#include "third_party/intel_mkl_ml/include/mkl_service.h"
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/types.h"
+
+#define EIGEN_USE_THREADS
+#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+// BLAS GEMM API for single-precision (F32) matrix multiplication.
+
+// The MatMul function computes c = alpha * op(a) * op(b) + beta * c. Since
+// XLA MatMul does not use alpha and beta, we set them to 1.0 and 0.0.
+// Matrices lhs, rhs, and out are all column-major.
+void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs,
+               int64 m, int64 n, int64 k, int32 transpose_lhs,
+               int32 transpose_rhs) {
+  const float alpha = 1.0f, beta = 0.0f;
+  // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
+  // respectively. For column-major matrices, the leading dimension is the
+  // stride between consecutive columns (which equals the number of rows). If
+  // the matrix is transposed, the leading dimension is the stride between
+  // consecutive rows (which equals the number of columns).
+  int lda = transpose_lhs ? k : m;
+  int ldb = transpose_rhs ? n : k;
+  int ldc = m;
+  cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
+              transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
+              lda, rhs, ldb, beta, out, ldc);
+}
+
+// BLAS GEMM API for double-precision (F64) matrix multiplication.
+
+// The MatMul function computes c = alpha * op(a) * op(b) + beta * c. Since
+// XLA MatMul does not use alpha and beta, we set them to 1.0 and 0.0.
+// Matrices lhs, rhs, and out are all column-major.
+void MatMulF64(const void* run_options_ptr, double* out, double* lhs,
+               double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
+               int32 transpose_rhs) {
+  const double alpha = 1.0, beta = 0.0;
+  // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
+  // respectively. For a column-major matrix, the leading dimension is the
+  // stride between consecutive columns (which equals the number of rows). If
+  // the matrix is transposed, the leading dimension is the stride between
+  // consecutive rows (which equals the number of columns).
+  int lda = transpose_lhs ? k : m;
+  int ldb = transpose_rhs ? n : k;
+  int ldc = m;
+  cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
+              transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
+              lda, rhs, ldb, beta, out, ldc);
+}
+
+}  // namespace
+
+void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
+                                    float* lhs, float* rhs, int64 m, int64 n,
+                                    int64 k, int32 transpose_lhs,
+                                    int32 transpose_rhs) {
+  const xla::ExecutableRunOptions* run_options =
+      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+  // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
+  // count specified in intra_op_thread_pool on to MKL.
+  int prev_num_threads = mkl_set_num_threads_local(
+      run_options->intra_op_thread_pool()->numThreads());
+  MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+  // Restore the previous thread count.
+  mkl_set_num_threads_local(prev_num_threads);
+}
+// BLAS GEMM API for double-precision (F64) matrix multiplication.
+void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
+                                    double* lhs, double* rhs, int64 m, int64 n,
+                                    int64 k, int32 transpose_lhs,
+                                    int32 transpose_rhs) {
+  const xla::ExecutableRunOptions* run_options =
+      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+  // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
+  // count specified in intra_op_thread_pool on to MKL.
+  int prev_num_threads = mkl_set_num_threads_local(
+      run_options->intra_op_thread_pool()->numThreads());
+  MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+  // Restore the previous thread count.
+  mkl_set_num_threads_local(prev_num_threads);
+}
+void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
+                                                  float* out, float* lhs,
+                                                  float* rhs, int64 m, int64 n,
+                                                  int64 k, int32 transpose_lhs,
+                                                  int32 transpose_rhs) {
+  // Set the thread count to 1 for single-threaded execution.
+  int prev_num_threads = mkl_set_num_threads_local(1);
+  MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+  // Restore the previous thread count.
+  mkl_set_num_threads_local(prev_num_threads);
+}
+void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
+                                                  double* out, double* lhs,
+                                                  double* rhs, int64 m, int64 n,
+                                                  int64 k, int32 transpose_lhs,
+                                                  int32 transpose_rhs) {
+  // Set the thread count to 1 for single-threaded execution.
+  int prev_num_threads = mkl_set_num_threads_local(1);
+  MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+  // Restore the previous thread count.
+  mkl_set_num_threads_local(prev_num_threads);
+}
+#endif  // INTEL_MKL
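
Note: to make the lda/ldb/ldc conventions above concrete, here is a small worked cblas_sgemm example with column-major operands and no transposes — a sketch assuming the same MKL headers, with the result values computed by hand:

#include "third_party/intel_mkl_ml/include/mkl_cblas.h"

int main() {
  // Column-major 2x2 matrices:
  //   a = [1 3]    b = [5 7]
  //       [2 4]        [6 8]
  const float a[4] = {1, 2, 3, 4};
  const float b[4] = {5, 6, 7, 8};
  float c[4] = {0, 0, 0, 0};
  // No transposes, so lda = ldb = ldc = m = n = k = 2.
  cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
              /*m=*/2, /*n=*/2, /*k=*/2, /*alpha=*/1.0f,
              a, /*lda=*/2, b, /*ldb=*/2, /*beta=*/0.0f, c, /*ldc=*/2);
  // c now holds {23, 34, 31, 46}, the column-major layout of
  //   [23 31]
  //   [34 46]
  return 0;
}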
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h
new file mode 100644
index 0000000..831b796
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
+
+#include <iostream>
+#include "tensorflow/core/platform/types.h"
+#ifdef INTEL_MKL
+#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
+
+extern void __xla_cpu_runtime_MKLMatMulF32(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+extern void __xla_cpu_runtime_MKLMatMulF64(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs);
+
+#else
+extern void __xla_cpu_runtime_MKLMatMulF32(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs) {
+  std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+               "INTEL_MKL. Add --config=mkl to build with MKL.";
+  exit(1);
+}
+extern void __xla_cpu_runtime_MKLMatMulF64(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs) {
+  std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+               "INTEL_MKL. Add --config=mkl to build with MKL.";
+  exit(1);
+}
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs) {
+  std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+               "INTEL_MKL. Add --config=mkl to build with MKL.";
+  exit(1);
+}
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+    tensorflow::int32 transpose_rhs) {
+  std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+               "INTEL_MKL. Add --config=mkl to build with MKL.";
+  exit(1);
+}
+
+#endif  // INTEL_MKL
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
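
Note: the #else branch deliberately defines the four symbols (with a loud run-time failure) rather than merely declaring them, so that simple_orc_jit.cc below can register MKLMatMulF32 and friends unconditionally; in a non-MKL build the JIT resolves those registrations to these stubs instead of failing at link time.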
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 4198260..b7ce5bb 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
@@ -183,6 +184,10 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
+  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
+  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 5cb1811..f9943f7 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -189,6 +189,9 @@ message DebugOptions {
   // directory.
   string xla_dump_per_pass_hlo_proto_to = 96;
 
+  // Generate calls to MKL-DNN in the CPU backend.
+  bool xla_cpu_use_mkl_dnn = 97;
+
   // Extra options to pass to the compilation backend; specific interpretation
   // of these values is left to the backend.
   map<string, string> xla_backend_extra_options = 500;
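
Note: besides the command-line flag, the new proto field can be set programmatically on DebugOptions before compilation. A hedged sketch, assuming the GetDebugOptionsFromFlags() accessor declared alongside SetDebugOptionsDefaults() in debug_options_flags.h:

#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/xla.pb.h"

xla::DebugOptions DebugOptionsWithMklDnn(bool use_mkl_dnn) {
  // Start from the flag/environment defaults, then override field 97.
  xla::DebugOptions opts = xla::legacy_flags::GetDebugOptionsFromFlags();
  opts.set_xla_cpu_use_mkl_dnn(use_mkl_dnn);
  return opts;
}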