Automated g4 rollback of changelist 191527251
authorTony Wang <tonywy@google.com>
Wed, 4 Apr 2018 17:00:39 +0000 (10:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 17:02:52 +0000 (10:02 -0700)
PiperOrigin-RevId: 191605505

tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
tensorflow/compiler/xla/service/cpu/cpu_runtime.h
tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc [deleted file]
tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h [deleted file]
tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
tensorflow/compiler/xla/xla.proto

index f037663..c8ed3e3 100644 (file)
@@ -40,9 +40,6 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
   flags->set_xla_cpu_multi_thread_eigen(true);
   flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
   flags->set_xla_eliminate_hlo_implicit_broadcast(true);
-#ifdef INTEL_MKL
-  flags->set_xla_cpu_use_mkl_dnn(true);
-#endif  // INTEL_MKL
 
   // Set cudnn batchnorm off by default; it does not provide a performance win
   // on average.
@@ -291,10 +288,6 @@ void AllocateFlags() {
           flag_values->xla_gpu_use_cudnn_batchnorm(),
           "Allows the GPU backend to implement batchnorm HLOs using cudnn, "
           "rather than expanding them to a soup of HLOs."),
-      tensorflow::Flag("xla_cpu_use_mkl_dnn",
-                       bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
-                       flag_values->xla_cpu_use_mkl_dnn(),
-                       "Generate calls to MKL-DNN in the CPU backend."),
   });
   ParseFlagsFromEnv(*flag_objects);
 }
index d22c135..966e2d0 100644 (file)
@@ -18,10 +18,6 @@ load(":build_defs.bzl", "runtime_copts")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
-load(
-    "//third_party/mkl:build_defs.bzl",
-    "if_mkl",
-)
 
 # Filegroup used to collect source files for dependency checking.
 filegroup(
@@ -174,7 +170,6 @@ cc_library(
         ":runtime_fft",
         ":runtime_fork_join",
         ":runtime_matmul",
-        ":runtime_matmul_mkl",
         ":runtime_single_threaded_conv2d",
         ":runtime_single_threaded_matmul",
         "@llvm//:execution_engine",
@@ -539,29 +534,11 @@ cc_library(
         ":runtime_matvec",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/core:framework_lite",
-        "//tensorflow/core:lib",
         "//third_party/eigen3",
     ],
 )
 
 cc_library(
-    name = "runtime_matmul_mkl",
-    srcs = ["runtime_matmul_mkl.cc"],
-    hdrs = ["runtime_matmul_mkl.h"],
-    copts = runtime_copts(),
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/compiler/xla:executable_run_options",
-        "//tensorflow/core:framework_lite",
-        "//tensorflow/core:lib",
-        "//third_party/eigen3",
-    ] + if_mkl([
-        "//third_party/mkl:intel_binary_blob",
-        "@mkl_dnn",
-    ]),
-)
-
-cc_library(
     name = "runtime_single_threaded_conv2d",
     srcs = [
         "runtime_conv2d_impl.h",
@@ -607,12 +584,10 @@ cc_library(
 tf_cc_test(
     name = "cpu_runtime_test",
     srcs = ["cpu_runtime_test.cc"],
-    shard_count = 10,
     tags = ["optonly"],
     deps = [
         ":cpu_runtime",
         ":runtime_matmul",
-        ":runtime_matmul_mkl",
         ":runtime_single_threaded_matmul",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:types",
index 872b0be..9a3bd68 100644 (file)
@@ -37,14 +37,6 @@ extern const char* const kEigenMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenMatMulF32";
 extern const char* const kEigenMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenMatMulF64";
-extern const char* const kMKLMatMulF32SymbolName =
-    "__xla_cpu_runtime_MKLMatMulF32";
-extern const char* const kMKLMatMulF64SymbolName =
-    "__xla_cpu_runtime_MKLMatMulF64";
-extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
-    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
-extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
-    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
 extern const char* const kEigenConvF16SymbolName =
     "__xla_cpu_runtime_EigenConvF16";
 extern const char* const kEigenConvF32SymbolName =
index e392e23..e61d6ea 100644 (file)
@@ -44,10 +44,6 @@ namespace runtime {
 extern const char* const kEigenMatMulF16SymbolName;
 extern const char* const kEigenMatMulF32SymbolName;
 extern const char* const kEigenMatMulF64SymbolName;
-extern const char* const kMKLMatMulF32SymbolName;
-extern const char* const kMKLMatMulF64SymbolName;
-extern const char* const kMKLSingleThreadedMatMulF32SymbolName;
-extern const char* const kMKLSingleThreadedMatMulF64SymbolName;
 extern const char* const kEigenConvF16SymbolName;
 extern const char* const kEigenConvF32SymbolName;
 extern const char* const kEigenFftSymbolName;
index 9e04307..f385829 100644 (file)
@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
-#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
@@ -131,19 +130,21 @@ MatMulShape MatMulShapes[] = {
 // * transpose_lhs
 // * transpose_rhs
 // * single_threaded
-using MatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
+using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
 
-class EigenMatMulTest : public CpuRuntimeTest,
-                        public ::testing::WithParamInterface<MatMulTestParam> {
+class EigenMatMulTest
+    : public CpuRuntimeTest,
+      public ::testing::WithParamInterface<EigenMatMulTestParam> {
  public:
-  static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) {
+  static string Name(
+      const ::testing::TestParamInfo<EigenMatMulTestParam>& info) {
     MatMulShape shape = std::get<0>(info.param);
     bool transpose_lhs = std::get<1>(info.param);
     bool transpose_rhs = std::get<2>(info.param);
     bool single_threaded = std::get<3>(info.param);
 
     return tensorflow::strings::Printf(
-        "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
+        "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
         transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
         single_threaded ? "single" : "multi");
   }
@@ -168,74 +169,5 @@ INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest,
                                            ::testing::Bool()),
                         EigenMatMulTest::Name);
 
-#ifdef INTEL_MKL
-class MKLMatMulTest : public CpuRuntimeTest,
-                      public ::testing::WithParamInterface<MatMulTestParam> {
- public:
-  static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) {
-    MatMulShape shape = std::get<0>(info.param);
-    bool transpose_lhs = std::get<1>(info.param);
-    bool transpose_rhs = std::get<2>(info.param);
-    bool single_threaded = std::get<3>(info.param);
-
-    return tensorflow::strings::Printf(
-        "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
-        transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
-        single_threaded ? "single" : "multi");
-  }
-};
-
-std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a,
-                                                  const Array2D<float>& b,
-                                                  bool transpose_lhs,
-                                                  bool transpose_rhs,
-                                                  bool single_threaded) {
-  CHECK_EQ(a.width(), b.height());
-  int64 m = a.height();
-  int64 n = b.width();
-  int64 k = a.width();
-
-  // The MKL matmul runtime function expects the matrix to be in column major
-  // order and array2d is in row-major order. Create transposes of a and b. The
-  // 'data' buffer in the transposed array is the original array in column major
-  // order.
-  auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs);
-  auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs);
-
-  // Since we're going to transpose c before returning it, swap the order of the
-  // dimension sizes to ensure the returned array is properly dimensioned.
-  auto c_transpose = MakeUnique<Array2D<float>>(n, m);
-  if (single_threaded) {
-    __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
-        nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
-        m, n, k, transpose_lhs, transpose_rhs);
-  } else {
-    __xla_cpu_runtime_MKLMatMulF32(nullptr, c_transpose->data(),
-                                   a_transpose->data(), b_transpose->data(), m,
-                                   n, k, transpose_lhs, transpose_rhs);
-  }
-  return MaybeTransposeArray2D(*c_transpose, true);
-}
-
-TEST_P(MKLMatMulTest, DoIt) {
-  MatMulShape shape = std::get<0>(GetParam());
-  bool transpose_lhs = std::get<1>(GetParam());
-  bool transpose_rhs = std::get<2>(GetParam());
-  bool single_threaded = std::get<3>(GetParam());
-
-  auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k);
-  auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n);
-  auto c =
-      MKLMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs, single_threaded);
-  CheckMatrixMultiply(*a, *b, *c);
-}
-
-INSTANTIATE_TEST_CASE_P(MKLMatMulTestInstantiaion, MKLMatMulTest,
-                        ::testing::Combine(::testing::ValuesIn(MatMulShapes),
-                                           ::testing::Bool(), ::testing::Bool(),
-                                           ::testing::Bool()),
-                        MKLMatMulTest::Name);
-#endif  // INTEL_MKL
-
 }  // namespace
 }  // namespace xla
index 29afd8e..8b1e20d 100644 (file)
@@ -918,35 +918,28 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
   // The two transpose_... parameters are actually booleans, but we use int32
   // to avoid target-dependent calling convention details.
 
-  bool multi_threaded =
+  bool multi_threaded_eigen =
       hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
-  bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
   PrimitiveType type = target_array_.GetShape().element_type();
   llvm::Type* float_type;
   const char* fn_name;
   switch (type) {
     case F16:
-      fn_name = multi_threaded
+      fn_name = multi_threaded_eigen
                     ? runtime::kEigenMatMulF16SymbolName
                     : runtime::kEigenSingleThreadedMatMulF16SymbolName;
       float_type = ir_builder_->getHalfTy();
       break;
     case F32:
-      fn_name = multi_threaded
-                    ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
-                                   : runtime::kEigenMatMulF32SymbolName)
-                    : (use_mkl_dnn
-                           ? runtime::kMKLSingleThreadedMatMulF32SymbolName
-                           : runtime::kEigenSingleThreadedMatMulF32SymbolName);
+      fn_name = multi_threaded_eigen
+                    ? runtime::kEigenMatMulF32SymbolName
+                    : runtime::kEigenSingleThreadedMatMulF32SymbolName;
       float_type = ir_builder_->getFloatTy();
       break;
     case F64:
-      fn_name = multi_threaded
-                    ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
-                                   : runtime::kEigenMatMulF64SymbolName)
-                    : (use_mkl_dnn
-                           ? runtime::kMKLSingleThreadedMatMulF64SymbolName
-                           : runtime::kEigenSingleThreadedMatMulF64SymbolName);
+      fn_name = multi_threaded_eigen
+                    ? runtime::kEigenMatMulF64SymbolName
+                    : runtime::kEigenSingleThreadedMatMulF64SymbolName;
       float_type = ir_builder_->getDoubleTy();
       break;
     default:
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
deleted file mode 100644 (file)
index 729a4e7..0000000
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef INTEL_MKL
-#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
-#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
-#include "third_party/intel_mkl_ml/include/mkl_service.h"
-
-#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
-
-#define EIGEN_USE_THREADS
-#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
-
-using tensorflow::int32;
-using tensorflow::int64;
-
-namespace {
-// BLAS GEMM API for 32-bit Matrix Multiplication.
-
-// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
-// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0.
-// Matrix lhs, rhs and out are all colum-major.
-void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs,
-               int64 m, int64 n, int64 k, int32 transpose_lhs,
-               int32 transpose_rhs) {
-  const float alpha = 1.0f, beta = 0.0f;
-  // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
-  // respectively. For column-major matrices, the leading dimension is the
-  // stride between consecutive columns (which equals the number of rows). If
-  // the matrix is transposed, the leading dimension is the stride between
-  // consecutive rows (which equals the number of columns).
-  int lda = transpose_lhs ? k : m;
-  int ldb = transpose_rhs ? n : k;
-  int ldc = m;
-  cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
-              transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
-              lda, rhs, ldb, beta, out, ldc);
-}
-
-// BLAS GEMM API for 64-bit Matrix Multiplication.
-
-// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
-// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0.
-// Matrix lhs, rhs and out are all colum-major.
-void MatMulF64(const void* run_options_ptr, double* out, double* lhs,
-               double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
-               int32 transpose_rhs) {
-  const float alpha = 1.0f, beta = 0.0f;
-  // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
-  // respectively. For a column-major matrix, the leading dimension is the
-  // stride between consecutive columns (which equals the number of rows). If
-  // the matrix is transposed, the leading dimension is the stride between
-  // consecutive rows (which equals the number of columns).
-  int lda = transpose_lhs ? k : m;
-  int ldb = transpose_rhs ? n : k;
-  int ldc = m;
-  cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
-              transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
-              lda, rhs, ldb, beta, out, ldc);
-}
-
-}  // namespace
-
-void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
-                                    float* lhs, float* rhs, int64 m, int64 n,
-                                    int64 k, int32 transpose_lhs,
-                                    int32 transpose_rhs) {
-  const xla::ExecutableRunOptions* run_options =
-      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
-  // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
-  // number specified in intra_op_thread_pool to MKL.
-  int prev_num_threads = mkl_set_num_threads_local(
-      run_options->intra_op_thread_pool()->numThreads());
-  MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
-  // Set thread number back to the previous number.
-  mkl_set_num_threads_local(prev_num_threads);
-}
-// BLAS GEMM API for 64-bit Matrix Multiplication
-void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
-                                    double* lhs, double* rhs, int64 m, int64 n,
-                                    int64 k, int32 transpose_lhs,
-                                    int32 transpose_rhs) {
-  const xla::ExecutableRunOptions* run_options =
-      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
-  // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
-  // number specified in intra_op_thread_pool to MKL.
-  int prev_num_threads = mkl_set_num_threads_local(
-      run_options->intra_op_thread_pool()->numThreads());
-  MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
-  // Set thread number back to the previous number.
-  mkl_set_num_threads_local(prev_num_threads);
-}
-void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
-                                                  float* out, float* lhs,
-                                                  float* rhs, int64 m, int64 n,
-                                                  int64 k, int32 transpose_lhs,
-                                                  int32 transpose_rhs) {
-  // Set the thread number to 1 for single threaded excution.
-  int prev_num_threads = mkl_set_num_threads_local(1);
-  MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
-  // Set thread number back to the previous number.
-  mkl_set_num_threads_local(prev_num_threads);
-}
-void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
-                                                  double* out, double* lhs,
-                                                  double* rhs, int64 m, int64 n,
-                                                  int64 k, int32 transpose_lhs,
-                                                  int32 transpose_rhs) {
-  // Set the thread number to 1 for single threaded excution.
-  int prev_num_threads = mkl_set_num_threads_local(1);
-  MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
-  // Set thread number back to the previous number.
-  mkl_set_num_threads_local(prev_num_threads);
-}
-#endif  // INTEL_MKL
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h
deleted file mode 100644 (file)
index 9dbc506..0000000
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/types.h"
-#ifdef INTEL_MKL
-#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
-
-extern void __xla_cpu_runtime_MKLMatMulF32(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
-    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs);
-extern void __xla_cpu_runtime_MKLMatMulF64(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
-    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs);
-extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
-    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs);
-extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
-    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs);
-
-#else
-extern void __xla_cpu_runtime_MKLMatMulF32(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
-    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs) {
-  LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining "
-                "INTEL_MKL. Add --config=mkl to build with MKL.";
-}
-extern void __xla_cpu_runtime_MKLMatMulF64(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
-    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs) {
-  LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining "
-                "INTEL_MKL. Add --config=mkl to build with MKL.";
-}
-extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
-    float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs) {
-  LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining "
-                "INTEL_MKL. Add --config=mkl to build with MKL.";
-}
-extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(
-    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
-    double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
-    tensorflow::int64 k, tensorflow::int32 transpose_lhs,
-    tensorflow::int32 transpose_rhs) {
-  LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining "
-                "INTEL_MKL. Add --config=mkl to build with MKL.";
-}
-
-#endif  // INTEL_MKL
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
index b7ce5bb..4198260 100644 (file)
@@ -35,7 +35,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
-#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
@@ -184,10 +183,6 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
-  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
-  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
-  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
-  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
index f9943f7..5cb1811 100644 (file)
@@ -189,9 +189,6 @@ message DebugOptions {
   // directory.
   string xla_dump_per_pass_hlo_proto_to = 96;
 
-  // Generate calls to MKL-DNN in the CPU backend.
-  bool xla_cpu_use_mkl_dnn = 97;
-
   // Extra options to pass to the compilation backend; specific interpretation
   // of these values is left to the backend.
   map<string, string> xla_backend_extra_options = 500;