From: Tony Wang Date: Wed, 4 Apr 2018 17:00:39 +0000 (-0700) Subject: Automated g4 rollback of changelist 191527251 X-Git-Tag: tflite-v0.1.7~39^2^2~39 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f4dcfcaae4e85bbb727eb1f5bfc14f6fa3a055ed;p=platform%2Fupstream%2Ftensorflow.git Automated g4 rollback of changelist 191527251 PiperOrigin-RevId: 191605505 --- diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f037663..c8ed3e3 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -40,9 +40,6 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { flags->set_xla_cpu_multi_thread_eigen(true); flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); flags->set_xla_eliminate_hlo_implicit_broadcast(true); -#ifdef INTEL_MKL - flags->set_xla_cpu_use_mkl_dnn(true); -#endif // INTEL_MKL // Set cudnn batchnorm off by default; it does not provide a performance win // on average. 
@@ -291,10 +288,6 @@ void AllocateFlags() { flag_values->xla_gpu_use_cudnn_batchnorm(), "Allows the GPU backend to implement batchnorm HLOs using cudnn, " "rather than expanding them to a soup of HLOs."), - tensorflow::Flag("xla_cpu_use_mkl_dnn", - bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), - flag_values->xla_cpu_use_mkl_dnn(), - "Generate calls to MKL-DNN in the CPU backend."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d22c135..966e2d0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -18,10 +18,6 @@ load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) # Filegroup used to collect source files for dependency checking. 
filegroup( @@ -174,7 +170,6 @@ cc_library( ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", - ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", @@ -539,29 +534,11 @@ cc_library( ":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//third_party/eigen3", ], ) cc_library( - name = "runtime_matmul_mkl", - srcs = ["runtime_matmul_mkl.cc"], - hdrs = ["runtime_matmul_mkl.h"], - copts = runtime_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", - "//third_party/eigen3", - ] + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), -) - -cc_library( name = "runtime_single_threaded_conv2d", srcs = [ "runtime_conv2d_impl.h", @@ -607,12 +584,10 @@ cc_library( tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], - shard_count = 10, tags = ["optonly"], deps = [ ":cpu_runtime", ":runtime_matmul", - ":runtime_matmul_mkl", ":runtime_single_threaded_matmul", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 872b0be..9a3bd68 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -37,14 +37,6 @@ extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; -extern const char* const kMKLMatMulF32SymbolName = - "__xla_cpu_runtime_MKLMatMulF32"; -extern const char* const kMKLMatMulF64SymbolName = - "__xla_cpu_runtime_MKLMatMulF64"; -extern const char* const kMKLSingleThreadedMatMulF32SymbolName = - "__xla_cpu_runtime_MKLSingleThreadedMatMulF32"; 
-extern const char* const kMKLSingleThreadedMatMulF64SymbolName = - "__xla_cpu_runtime_MKLSingleThreadedMatMulF64"; extern const char* const kEigenConvF16SymbolName = "__xla_cpu_runtime_EigenConvF16"; extern const char* const kEigenConvF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index e392e23..e61d6ea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -44,10 +44,6 @@ namespace runtime { extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; -extern const char* const kMKLMatMulF32SymbolName; -extern const char* const kMKLMatMulF64SymbolName; -extern const char* const kMKLSingleThreadedMatMulF32SymbolName; -extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 9e04307..f385829 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" @@ -131,19 +130,21 @@ MatMulShape MatMulShapes[] = { // * transpose_lhs // * transpose_rhs // * single_threaded -using MatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>; +using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>; -class EigenMatMulTest : public CpuRuntimeTest, - public ::testing::WithParamInterface<MatMulTestParam> { +class EigenMatMulTest + : public CpuRuntimeTest, + public ::testing::WithParamInterface<EigenMatMulTestParam> { public: - static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) { + static string Name( + const ::testing::TestParamInfo<EigenMatMulTestParam>& info) { MatMulShape shape = std::get<0>(info.param); bool transpose_lhs = std::get<1>(info.param); bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); return tensorflow::strings::Printf( - "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, + "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", single_threaded ? 
"single" : "multi"); } @@ -168,74 +169,5 @@ INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, ::testing::Bool()), EigenMatMulTest::Name); -#ifdef INTEL_MKL -class MKLMatMulTest : public CpuRuntimeTest, - public ::testing::WithParamInterface<MatMulTestParam> { - public: - static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) { - MatMulShape shape = std::get<0>(info.param); - bool transpose_lhs = std::get<1>(info.param); - bool transpose_rhs = std::get<2>(info.param); - bool single_threaded = std::get<3>(info.param); - - return tensorflow::strings::Printf( - "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); - } -}; - -std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a, - const Array2D<float>& b, - bool transpose_lhs, - bool transpose_rhs, - bool single_threaded) { - CHECK_EQ(a.width(), b.height()); - int64 m = a.height(); - int64 n = b.width(); - int64 k = a.width(); - - // The MKL matmul runtime function expects the matrix to be in column major - // order and array2d is in row-major order. Create transposes of a and b. The - // 'data' buffer in the transposed array is the original array in column major - // order. - auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs); - auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs); - - // Since we're going to transpose c before returning it, swap the order of the - // dimension sizes to ensure the returned array is properly dimensioned. 
- auto c_transpose = MakeUnique<Array2D<float>>(n, m); - if (single_threaded) { - __xla_cpu_runtime_MKLSingleThreadedMatMulF32( - nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), - m, n, k, transpose_lhs, transpose_rhs); - } else { - __xla_cpu_runtime_MKLMatMulF32(nullptr, c_transpose->data(), - a_transpose->data(), b_transpose->data(), m, - n, k, transpose_lhs, transpose_rhs); - } - return MaybeTransposeArray2D(*c_transpose, true); -} - -TEST_P(MKLMatMulTest, DoIt) { - MatMulShape shape = std::get<0>(GetParam()); - bool transpose_lhs = std::get<1>(GetParam()); - bool transpose_rhs = std::get<2>(GetParam()); - bool single_threaded = std::get<3>(GetParam()); - - auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k); - auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n); - auto c = - MKLMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs, single_threaded); - CheckMatrixMultiply(*a, *b, *c); -} - -INSTANTIATE_TEST_CASE_P(MKLMatMulTestInstantiaion, MKLMatMulTest, - ::testing::Combine(::testing::ValuesIn(MatMulShapes), - ::testing::Bool(), ::testing::Bool(), - ::testing::Bool()), - MKLMatMulTest::Name); -#endif // INTEL_MKL - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 29afd8e..8b1e20d 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -918,35 +918,28 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. 
- bool multi_threaded = + bool multi_threaded_eigen = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); - bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; const char* fn_name; switch (type) { case F16: - fn_name = multi_threaded + fn_name = multi_threaded_eigen ? runtime::kEigenMatMulF16SymbolName : runtime::kEigenSingleThreadedMatMulF16SymbolName; float_type = ir_builder_->getHalfTy(); break; case F32: - fn_name = multi_threaded - ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName - : runtime::kEigenMatMulF32SymbolName) - : (use_mkl_dnn - ? runtime::kMKLSingleThreadedMatMulF32SymbolName - : runtime::kEigenSingleThreadedMatMulF32SymbolName); + fn_name = multi_threaded_eigen + ? runtime::kEigenMatMulF32SymbolName + : runtime::kEigenSingleThreadedMatMulF32SymbolName; float_type = ir_builder_->getFloatTy(); break; case F64: - fn_name = multi_threaded - ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName - : runtime::kEigenMatMulF64SymbolName) - : (use_mkl_dnn - ? runtime::kMKLSingleThreadedMatMulF64SymbolName - : runtime::kEigenSingleThreadedMatMulF64SymbolName); + fn_name = multi_threaded_eigen + ? runtime::kEigenMatMulF64SymbolName + : runtime::kEigenSingleThreadedMatMulF64SymbolName; float_type = ir_builder_->getDoubleTy(); break; default: diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc deleted file mode 100644 index 729a4e7..0000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef INTEL_MKL -#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" -#include "third_party/intel_mkl_ml/include/mkl_cblas.h" -#include "third_party/intel_mkl_ml/include/mkl_service.h" - -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" - -#define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" - -using tensorflow::int32; -using tensorflow::int64; - -namespace { -// BLAS GEMM API for 32-bit Matrix Multiplication. - -// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c. -// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0. -// Matrix lhs, rhs and out are all colum-major. -void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs, - int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - const float alpha = 1.0f, beta = 0.0f; - // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c, - // respectively. For column-major matrices, the leading dimension is the - // stride between consecutive columns (which equals the number of rows). If - // the matrix is transposed, the leading dimension is the stride between - // consecutive rows (which equals the number of columns). - int lda = transpose_lhs ? k : m; - int ldb = transpose_rhs ? n : k; - int ldc = m; - cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans, - transpose_rhs ? 
CblasTrans : CblasNoTrans, m, n, k, alpha, lhs, - lda, rhs, ldb, beta, out, ldc); -} - -// BLAS GEMM API for 64-bit Matrix Multiplication. - -// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c. -// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0. -// Matrix lhs, rhs and out are all colum-major. -void MatMulF64(const void* run_options_ptr, double* out, double* lhs, - double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - const float alpha = 1.0f, beta = 0.0f; - // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c, - // respectively. For a column-major matrix, the leading dimension is the - // stride between consecutive columns (which equals the number of rows). If - // the matrix is transposed, the leading dimension is the stride between - // consecutive rows (which equals the number of columns). - int lda = transpose_lhs ? k : m; - int ldb = transpose_rhs ? n : k; - int ldc = m; - cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans, - transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs, - lda, rhs, ldb, beta, out, ldc); -} - -} // namespace - -void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - const xla::ExecutableRunOptions* run_options = - static_cast<const xla::ExecutableRunOptions*>(run_options_ptr); - // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread - // number specified in intra_op_thread_pool to MKL. - int prev_num_threads = mkl_set_num_threads_local( - run_options->intra_op_thread_pool()->numThreads()); - MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - // Set thread number back to the previous number. 
- mkl_set_num_threads_local(prev_num_threads); -} -// BLAS GEMM API for 64-bit Matrix Multiplication -void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - const xla::ExecutableRunOptions* run_options = - static_cast<const xla::ExecutableRunOptions*>(run_options_ptr); - // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread - // number specified in intra_op_thread_pool to MKL. - int prev_num_threads = mkl_set_num_threads_local( - run_options->intra_op_thread_pool()->numThreads()); - MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - // Set thread number back to the previous number. - mkl_set_num_threads_local(prev_num_threads); -} -void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - // Set the thread number to 1 for single threaded excution. - int prev_num_threads = mkl_set_num_threads_local(1); - MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - // Set thread number back to the previous number. - mkl_set_num_threads_local(prev_num_threads); -} -void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - // Set the thread number to 1 for single threaded excution. - int prev_num_threads = mkl_set_num_threads_local(1); - MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - // Set thread number back to the previous number. 
- mkl_set_num_threads_local(prev_num_threads); -} -#endif // INTEL_MKL diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h deleted file mode 100644 index 9dbc506..0000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#ifdef INTEL_MKL -#include "third_party/intel_mkl_ml/include/mkl_cblas.h" - -extern void __xla_cpu_runtime_MKLMatMulF32( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, - float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); -extern void __xla_cpu_runtime_MKLMatMulF64( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, - double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); -extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32( - const void* /* xla::ExecutableRunOptions* */ 
run_options_ptr, float* out, - float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); -extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, - double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); - -#else -extern void __xla_cpu_runtime_MKLMatMulF32( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, - float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs) { - LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining " - "INTEL_MKL. Add --config=mkl to build with MKL."; -} -extern void __xla_cpu_runtime_MKLMatMulF64( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, - double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs) { - LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining " - "INTEL_MKL. Add --config=mkl to build with MKL."; -} -extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, - float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs) { - LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining " - "INTEL_MKL. 
Add --config=mkl to build with MKL."; -} -extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64( - const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, - double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, - tensorflow::int64 k, tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs) { - LOG(FATAL) << "Attempt to call MKL MatMul runtime library without defining " - "INTEL_MKL. Add --config=mkl to build with MKL."; -} - -#endif // INTEL_MKL -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b7ce5bb..4198260 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" @@ -184,10 +183,6 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); - REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32); - REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64); - REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32); - REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); diff --git a/tensorflow/compiler/xla/xla.proto 
b/tensorflow/compiler/xla/xla.proto index f9943f7..5cb1811 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -189,9 +189,6 @@ message DebugOptions { // directory. string xla_dump_per_pass_hlo_proto_to = 96; - // Generate calls to MKL-DNN in the CPU backend. - bool xla_cpu_use_mkl_dnn = 97; - // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map<string, string> xla_backend_extra_options = 500;