[XLA:CPU] Implement vectorized Log in LLVM IR
author    Sanjoy Das <sanjoy@google.com>
          Tue, 13 Feb 2018 05:56:19 +0000 (21:56 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 13 Feb 2018 05:59:58 +0000 (21:59 -0800)
This was the last vectorized intrinsic for which we had to call into
C++, so this change also removes the associated machinery.

PiperOrigin-RevId: 185482962

17 files changed:
tensorflow/compiler/aot/tfcompile.bzl
tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/compiler_functor.cc
tensorflow/compiler/xla/service/cpu/compiler_functor.h
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc [deleted file]
tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h [deleted file]
tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc [deleted file]
tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h [deleted file]
tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc [deleted file]
tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h [deleted file]
tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h
tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
tensorflow/compiler/xla/service/cpu/vector_support_library.cc
tensorflow/compiler/xla/service/cpu/vector_support_library.h
tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc

diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index eb3e632..9dff1be 100644
@@ -224,9 +224,6 @@ def tf_library(name, graph, config,
           # TODO(cwhipkey): only depend on kernel code that the model actually needed.
           "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
           "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
-          "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
-          "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
-          "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
           "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
           "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
           "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 1a91dd8..c13a0b1 100644
@@ -159,9 +159,6 @@ cc_library(
     deps = [
         ":compiler_functor",
         ":cpu_runtime",
-        ":cpu_runtime_avx",
-        ":cpu_runtime_neon",
-        ":cpu_runtime_sse4_1",
         ":custom_call_target_registry",
         ":disassembler",
         ":external_constant_pool",
@@ -408,9 +405,6 @@ cc_library(
     hdrs = ["compiler_functor.h"],
     deps = [
         ":cpu_runtime",
-        ":cpu_runtime_avx",
-        ":cpu_runtime_neon",
-        ":cpu_runtime_sse4_1",
         ":disassembler",
         ":llvm_ir_runtime",
         "//tensorflow/compiler/xla:statusor",
@@ -431,43 +425,6 @@ cc_library(
 )
 
 cc_library(
-    name = "cpu_runtime_sse4_1",
-    srcs = ["cpu_runtime_sse4_1.cc"],
-    hdrs = ["cpu_runtime_sse4_1.h"],
-    copts = ["-DEIGEN_AVOID_STL_ARRAY"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-)
-
-cc_library(
-    name = "cpu_runtime_avx",
-    srcs = ["cpu_runtime_avx.cc"],
-    hdrs = ["cpu_runtime_avx.h"],
-    copts = ["-DEIGEN_AVOID_STL_ARRAY"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-)
-
-cc_library(
-    name = "cpu_runtime_neon",
-    srcs = ["cpu_runtime_neon.cc"],
-    hdrs = ["cpu_runtime_neon.h"],
-    # runtime_copts() enables -mfpu=neon
-    copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(),
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-)
-
-cc_library(
     name = "cpu_runtime",
     srcs = [
         "cpu_runtime.cc",
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 2723661..ed290fc 100644
@@ -37,9 +37,6 @@ limitations under the License.
 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -50,15 +47,6 @@ limitations under the License.
 namespace xla {
 namespace cpu {
 
-/* static */ CompilerFunctor::VectorIntrinsics
-CompilerFunctor::AllIntrinsics() {
-  VectorIntrinsics intrinsics;
-  intrinsics.sse_intrinsics = true;
-  intrinsics.avx_intrinsics = true;
-  intrinsics.neon_intrinsics = true;
-  return intrinsics;
-}
-
 /* Create filtered versions of the LLVM Pass Managers to filter out some
 of the expensive passes.
 Profiling:
@@ -192,31 +180,8 @@ operator()(llvm::Module& module) const {
       std::move(object_file), std::move(memory_buffer));
 }
 
-namespace {
-// Returns the set of vectorized library functions supported for the target.
-std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
-    llvm::Triple::ArchType arch, llvm::StringRef feature_string,
-    CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
-  std::vector<llvm::VecDesc> vector_functions;
-
-  const llvm::VecDesc four_wide_vector_functions_neon[] = {
-      {"logf", runtime::kLogV4F32NEONSymbolName, 4},
-      {"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4},
-  };
-
-  const llvm::VecDesc four_wide_vector_functions_sse[] = {
-      {"logf", runtime::kLogV4F32SSESymbolName, 4},
-      {"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4},
-  };
-
-  const llvm::VecDesc eight_wide_vector_functions_avx[] = {
-      {"logf", runtime::kLogV8F32AVXSymbolName, 8},
-      {"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8},
-  };
-
-  // These functions are generated by XLA as LLVM IR, so they're always
-  // available.
-  const llvm::VecDesc ir_vector_functions[] = {
+static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
+  std::vector<llvm::VecDesc> result = {
       {"tanhf", runtime::kTanhV4F32SymbolName, 4},
       {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4},
 
@@ -228,50 +193,15 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
 
       {"expf", runtime::kExpV8F32SymbolName, 8},
       {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
-  };
 
-  llvm::SmallVector<llvm::StringRef, 32> features;
-  feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
-  auto has_feature = [&features](const llvm::StringRef feature) {
-    return std::find(features.begin(), features.end(), feature) !=
-           features.end();
-  };
-
-  switch (arch) {
-    case llvm::Triple::x86:
-    case llvm::Triple::x86_64: {
-      if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) {
-        vector_functions.insert(vector_functions.end(),
-                                std::begin(four_wide_vector_functions_sse),
-                                std::end(four_wide_vector_functions_sse));
-      }
-      if (has_feature("+avx") && available_intrinsics.avx_intrinsics) {
-        vector_functions.insert(vector_functions.end(),
-                                std::begin(eight_wide_vector_functions_avx),
-                                std::end(eight_wide_vector_functions_avx));
-      }
-      break;
-    }
-    case llvm::Triple::arm:
-    case llvm::Triple::aarch64: {
-      if (has_feature("+neon") && available_intrinsics.neon_intrinsics) {
-        vector_functions.insert(vector_functions.end(),
-                                std::begin(four_wide_vector_functions_neon),
-                                std::end(four_wide_vector_functions_neon));
-      }
-      break;
-    }
-    default:
-      break;
-  }
+      {"logf", runtime::kLogV4F32SymbolName, 4},
+      {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
 
-  vector_functions.insert(vector_functions.end(),
-                          std::begin(ir_vector_functions),
-                          std::end(ir_vector_functions));
-
-  return vector_functions;
+      {"logf", runtime::kLogV8F32SymbolName, 8},
+      {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
+  };
+  return result;
 }
-}  // namespace
 
 void CompilerFunctor::AddTargetInfoPasses(
     llvm::legacy::PassManagerBase* passes) const {
@@ -279,9 +209,7 @@ void CompilerFunctor::AddTargetInfoPasses(
   auto target_library_info_impl =
       MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
   target_library_info_impl->addVectorizableFunctions(
-      VectorFunctionsForTargetLibraryInfoImpl(
-          target_triple.getArch(), target_machine_->getTargetFeatureString(),
-          available_intrinsics_));
+      VectorFunctionsForTargetLibraryInfoImpl());
   passes->add(
       new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
   passes->add(createTargetTransformInfoWrapperPass(
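
The hunk above is the consumer side of LLVM's vectorizable-function registry: each VecDesc entry tells LLVM's vectorizer passes that a scalar call (for example logf) may be rewritten as a call to a wide symbol at a given vectorization factor. Below is a minimal sketch of that mechanism, assuming the LLVM legacy pass-manager API of this era; the vector_logf symbol name is hypothetical, not one XLA registers.

    #include "llvm/ADT/Triple.h"
    #include "llvm/Analysis/TargetLibraryInfo.h"
    #include "llvm/IR/LegacyPassManager.h"

    // Registers a 4-wide vector form of logf so the vectorizers may turn
    // "call float @logf(...)" into "call <4 x float> @vector_logf(...)".
    void AddVectorizableLogf(llvm::legacy::PassManagerBase* passes,
                             const llvm::Triple& target_triple) {
      llvm::TargetLibraryInfoImpl tli(target_triple);
      const llvm::VecDesc descs[] = {
          // {scalar function, vector function, vectorization factor}
          {"logf", "vector_logf", 4},
          {"llvm.log.f32", "vector_logf", 4},
      };
      tli.addVectorizableFunctions(descs);
      passes->add(new llvm::TargetLibraryInfoWrapperPass(tli));
    }

After this change the registration is unconditional: the wide log bodies are emitted as LLVM IR into every module, so there is no longer a per-target check for intrinsic availability.
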
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h
index 8cdd049..1a8283a 100644
@@ -31,21 +31,10 @@ namespace cpu {
 // Orc JIT compile layer.
 class CompilerFunctor {
  public:
-  // Describes the set of vector intrinsics available to the generated code.
-  struct VectorIntrinsics {
-    bool sse_intrinsics;
-    bool avx_intrinsics;
-    bool neon_intrinsics;
-  };
-
-  // Returns a VectorIntrinsics where all intrinsics are available.
-  static VectorIntrinsics AllIntrinsics();
-
   explicit CompilerFunctor(
       llvm::TargetMachine* target_machine, const Disassembler* disassembler,
       int opt_level, bool optimize_for_size, bool enable_fast_math,
       bool disable_expensive_passes,
-      const VectorIntrinsics& available_intrinsics,
       LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
       LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
       : target_machine_(target_machine),
@@ -54,7 +43,6 @@ class CompilerFunctor {
         optimize_for_size_(optimize_for_size),
         enable_fast_math_(enable_fast_math),
         disable_expensive_passes_(disable_expensive_passes),
-        available_intrinsics_(available_intrinsics),
         pre_optimization_hook_(pre_optimization_hook),
         post_optimization_hook_(post_optimization_hook) {}
 
@@ -78,7 +66,6 @@ class CompilerFunctor {
   const bool optimize_for_size_;
   const bool enable_fast_math_;
   const bool disable_expensive_passes_;
-  const VectorIntrinsics available_intrinsics_;
   LLVMCompiler::ModuleHook pre_optimization_hook_;
   LLVMCompiler::ModuleHook post_optimization_hook_;
 };
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index d13a97b..f9cc965 100644
@@ -888,8 +888,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
         options::OptimizeForSizeRequested(module->config()),
         module->config().debug_options().xla_enable_fast_math(),
         module->config().debug_options().xla_llvm_disable_expensive_passes(),
-        CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
-        post_optimization_ir_dump_hook);
+        pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook);
     llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
         compiler_functor(llvm_module);
     llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc
deleted file mode 100644
index 62bb87f..0000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/Eigen/Core"
-
-#ifdef TF_XLA_HAS_AVX
-xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
-    xla::cpu::runtime::V8F32AVX x) {
-  return Eigen::internal::plog(x);
-}
-#endif  // TF_XLA_HAS_AVX
-
-namespace xla {
-namespace cpu {
-namespace runtime {
-
-const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX";
-
-}  // namespace runtime
-}  // namespace cpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h
deleted file mode 100644
index f473c68..0000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This header declares functions which may be called by the generated code on
-// the CPU. Calls to these functions must be resolved explicitly in the JIT in
-// xla::cpu::SimpleResolver.  It also defines a per-CpuExecutable context
-// which is used to cache expensive state and resources utilized by the
-// aforementioned functions.
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
-
-#include "tensorflow/core/platform/macros.h"
-
-#if defined(__AVX__)
-#include <immintrin.h>
-#define TF_XLA_HAS_AVX
-#endif
-
-namespace xla {
-namespace cpu {
-namespace runtime {
-
-extern const char *const kLogV8F32AVXSymbolName;
-
-#ifdef TF_XLA_HAS_AVX
-typedef __m256 V8F32AVX;
-#endif
-}  // namespace runtime
-}  // namespace cpu
-}  // namespace xla
-
-extern "C" {
-
-#ifdef TF_XLA_HAS_AVX
-// The following functions are vectorized versions of a selection of libm
-// library functions.
-// References to these functions are created by the LLVM vectorizer.
-xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
-    xla::cpu::runtime::V8F32AVX x);
-
-xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
-    xla::cpu::runtime::V8F32AVX x);
-#endif
-}
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc
deleted file mode 100644
index 8099b72..0000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/Eigen/Core"
-
-#ifdef TF_XLA_HAS_NEON
-
-xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
-    xla::cpu::runtime::V4F32NEON x) {
-  return Eigen::internal::pexp(x);
-}
-
-xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
-    xla::cpu::runtime::V4F32NEON x) {
-  Eigen::internal::Packet4f p = x;
-  return Eigen::internal::plog(p);
-}
-
-#endif  // TF_XLA_HAS_NEON
-
-namespace xla {
-namespace cpu {
-namespace runtime {
-
-const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON";
-const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON";
-
-}  // namespace runtime
-}  // namespace cpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h
deleted file mode 100644
index 2f5d1a8..0000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
-
-// This header declares functions which may be called by the generated code on
-// the CPU. Calls to these functions must be resolved explicitly in the JIT in
-// xla::cpu::SimpleResolver.
-
-#include "tensorflow/core/platform/macros.h"
-
-#ifdef __ARM_NEON__
-// For the other runtimes (AVX, SSE4.1) we define the vector type directly using
-// __attribute__((__vector_size__(*))).  Unfortunately, the typedef for the ARM
-// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
-#include <arm_neon.h>
-#define TF_XLA_HAS_NEON
-#endif  // __ARM_NEON__
-
-namespace xla {
-namespace cpu {
-namespace runtime {
-
-extern const char *const kExpV4F32NEONSymbolName;
-extern const char *const kLogV4F32NEONSymbolName;
-
-#ifdef TF_XLA_HAS_NEON
-typedef float32x4_t V4F32NEON;
-#endif  // TF_XLA_HAS_NEON
-
-}  // namespace runtime
-}  // namespace cpu
-}  // namespace xla
-
-extern "C" {
-
-#ifdef TF_XLA_HAS_NEON
-// The following functions are vectorized versions of a selection of libm
-// library functions.
-// References to these functions are created by the LLVM vectorizer.
-xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
-    xla::cpu::runtime::V4F32NEON x);
-
-xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
-    xla::cpu::runtime::V4F32NEON x);
-#endif  // TF_XLA_HAS_NEON
-}
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc
deleted file mode 100644
index 1d5b5c2..0000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/Eigen/Core"
-
-#ifdef TF_XLA_HAS_SSE4_1
-
-xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
-    xla::cpu::runtime::V4F32SSE x) {
-  Eigen::internal::Packet4f p = x;
-  return Eigen::internal::plog(p);
-}
-
-#endif  // TF_XLA_HAS_SSE4_1
-
-namespace xla {
-namespace cpu {
-namespace runtime {
-
-const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE";
-
-}  // namespace runtime
-}  // namespace cpu
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h
deleted file mode 100644
index 3b3d181..0000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This header declares functions which may be called by the generated code on
-// the CPU. Calls to these functions must be resolved explicitly in the JIT in
-// xla::cpu::SimpleResolver.  It also defines a per-CpuExecutable context
-// which is used to cache expensive state and resources utilized by the
-// aforementioned functions.
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
-
-#include "tensorflow/core/platform/macros.h"
-
-// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1
-// when __AVX__ is defined, we should do the same.
-#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
-#include <smmintrin.h>
-#define TF_XLA_HAS_SSE4_1
-#endif
-
-namespace xla {
-namespace cpu {
-namespace runtime {
-
-extern const char *const kLogV4F32SSESymbolName;
-
-#ifdef TF_XLA_HAS_SSE4_1
-typedef __m128 V4F32SSE;
-#endif
-
-}  // namespace runtime
-}  // namespace cpu
-}  // namespace xla
-
-extern "C" {
-
-#ifdef TF_XLA_HAS_SSE4_1
-// The following functions are vectorized versions of a selection of libm
-// library functions.
-// References to these functions are created by the LLVM vectorizer.
-xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
-    xla::cpu::runtime::V4F32SSE x);
-#endif
-}
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
index 38fcd27..ee213e0 100644
@@ -21,6 +21,7 @@ limitations under the License.
 #include "llvm/IR/Verifier.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
+#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
@@ -31,6 +32,8 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
 const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
 const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
 const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
+const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
+const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
 
 namespace {
 llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
@@ -116,19 +119,19 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
 
   // This implements the same polynomial approximation as implemented in Eigen3.
 
-  const double exp_hi = 88.3762626647950;
-  const double exp_lo = -88.3762626647949;
+  const float exp_hi = 88.3762626647950;
+  const float exp_lo = -88.3762626647949;
 
-  const double cephes_LOG2EF = 1.44269504088896341;
-  const double cephes_exp_C1 = 0.693359375;
-  const double cephes_exp_C2 = -2.12194440e-4;
+  const float cephes_LOG2EF = 1.44269504088896341;
+  const float cephes_exp_C1 = 0.693359375;
+  const float cephes_exp_C2 = -2.12194440e-4;
 
-  const double cephes_exp_p0 = 1.9875691500E-4;
-  const double cephes_exp_p1 = 1.3981999507E-3;
-  const double cephes_exp_p2 = 8.3334519073E-3;
-  const double cephes_exp_p3 = 4.1665795894E-2;
-  const double cephes_exp_p4 = 1.6666665459E-1;
-  const double cephes_exp_p5 = 5.0000001201E-1;
+  const float cephes_exp_p0 = 1.9875691500E-4;
+  const float cephes_exp_p1 = 1.3981999507E-3;
+  const float cephes_exp_p2 = 8.3334519073E-3;
+  const float cephes_exp_p3 = 4.1665795894E-2;
+  const float cephes_exp_p4 = 1.6666665459E-1;
+  const float cephes_exp_p5 = 5.0000001201E-1;
 
   llvm::Value* input = &*vector_exp_function->arg_begin();
   llvm::Value* input_clamped =
@@ -146,7 +149,7 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
   y = vsl.MulAdd(y, x, cephes_exp_p4);
   y = vsl.MulAdd(y, x, cephes_exp_p5);
   y = vsl.MulAdd(y, z, x);
-  y = vsl.Add(1.0, y);
+  y = vsl.Add(1.0f, y);
 
   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
   // time so drop down to IRBuilder for this bit.
@@ -167,9 +170,133 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
 
   ir_builder.CreateRet(result);
 
-  CHECK(!llvm::verifyFunction(*vector_exp_function));
+  DCHECK(!llvm::verifyFunction(*vector_exp_function));
   return vector_exp_function;
 }
+
+llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
+                                         llvm::StringRef function_name,
+                                         int vector_width,
+                                         bool enable_fast_math) {
+  llvm::Function* vector_log_function = module->getFunction(function_name);
+  if (vector_log_function == nullptr) {
+    // If the function declaration is not present in the module, there can't be
+    // any calls to resolve.  Don't emit the function in this case.
+    return nullptr;
+  }
+
+  llvm::LLVMContext* context = &module->getContext();
+
+  llvm::BasicBlock* vector_log_body =
+      llvm::BasicBlock::Create(*context, "body", vector_log_function);
+
+  llvm::IRBuilder<> ir_builder(vector_log_body);
+  llvm::FastMathFlags fast_math_flags;
+  fast_math_flags.setFast();
+  ir_builder.setFastMathFlags(fast_math_flags);
+
+  llvm::Value* input = &*vector_log_function->arg_begin();
+  VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
+
+  const float half = 0.5;
+
+  // This implements the same polynomial approximation as implemented in Eigen3.
+  // Returns NaN for x < 0, -INF for x = 0
+  const float cephes_SQRTHF = 0.707106781186547524;
+  const float cephes_log_p0 = 7.0376836292E-2;
+  const float cephes_log_p1 = -1.1514610310E-1;
+  const float cephes_log_p2 = 1.1676998740E-1;
+  const float cephes_log_p3 = -1.2420140846E-1;
+  const float cephes_log_p4 = +1.4249322787E-1;
+  const float cephes_log_p5 = -1.6668057665E-1;
+  const float cephes_log_p6 = +2.0000714765E-1;
+  const float cephes_log_p7 = -2.4999993993E-1;
+  const float cephes_log_p8 = +3.3333331174E-1;
+  const float cephes_log_q1 = -2.12194440e-4;
+  const float cephes_log_q2 = 0.693359375;
+
+  // The smallest non denormalized float number.
+  const float min_norm_pos = tensorflow::bit_cast<float, int32>(0x00800000);
+  const float minus_inf = tensorflow::bit_cast<float, int32>(0xff800000);
+
+  // NB! This number is denormal, and since TF sets the denormals-are-zero
+  // flag (and if TF didn't, -ffast-math would), operating on this float with
+  // C++ operations (including, for instance, implicit conversion to double)
+  // will coerce it to zero.
+  const float inv_mant_mask = tensorflow::bit_cast<float, int32>(~0x7f800000);
+
+  // invalid_mask is set if x is negative or NaN (and therefore output
+  // must be NaN).
+  llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
+  llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
+
+  // Cut off denormalized stuff.
+  input = vsl.Max(min_norm_pos, input);
+
+  // VectorSupportLibrary (intentionally) can't juggle more than one type at a
+  // time so drop down to IRBuilder for this bit.
+  llvm::Value* vector_constant_0x7f =
+      ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
+  llvm::Value* vector_constant_23 =
+      ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
+  llvm::Type* i32_vector_type =
+      llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
+
+  llvm::Value* emm0 = ir_builder.CreateLShr(
+      ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23);
+
+  // Keep only the fractional part.
+  input = vsl.FloatAnd(input, inv_mant_mask);
+  input = vsl.FloatOr(input, half);
+
+  emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
+  llvm::Value* e =
+      vsl.Add(1.0f, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
+
+  // part2:
+  //   if( x < SQRTHF ) {
+  //     e -= 1;
+  //     x = x + x - 1.0;
+  //   } else { x = x - 1.0; }
+  llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
+  llvm::Value* tmp = vsl.FloatAnd(input, mask);
+  input = vsl.Sub(input, 1.0);
+  e = vsl.Sub(e, vsl.FloatAnd(mask, 1.0));
+  input = vsl.Add(input, tmp);
+
+  llvm::Value* x2 = vsl.Mul(input, input);
+  llvm::Value* x3 = vsl.Mul(x2, input);
+
+  llvm::Value *y, *y1, *y2;
+  y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1);
+  y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4);
+  y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7);
+  y = vsl.MulAdd(y, input, cephes_log_p2);
+  y1 = vsl.MulAdd(y1, input, cephes_log_p5);
+  y2 = vsl.MulAdd(y2, input, cephes_log_p8);
+  y = vsl.MulAdd(y, x3, y1);
+  y = vsl.MulAdd(y, x3, y2);
+  y = vsl.Mul(y, x3);
+
+  y1 = vsl.Mul(cephes_log_q1, e);
+  tmp = vsl.Mul(half, x2);
+  y = vsl.Add(y, y1);
+  input = vsl.Sub(input, tmp);
+  y2 = vsl.Mul(cephes_log_q2, e);
+  input = vsl.Add(input, y);
+  input = vsl.Add(input, y2);
+
+  // Negative arg will be NAN, 0 will be -INF.
+  llvm::Value* or_lhs =
+      vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask));
+  llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf);
+  llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs);
+
+  ir_builder.CreateRet(result);
+
+  DCHECK(!llvm::verifyFunction(*vector_log_function));
+  return vector_log_function;
+}
 }  // namespace
 
 void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
@@ -187,11 +314,21 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
       EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName,
                                /*vector_width=*/8, enable_fast_math);
 
+  auto* log_v4f32 =
+      EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName,
+                               /*vector_width=*/4, enable_fast_math);
+  auto* log_v8f32 =
+      EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName,
+                               /*vector_width=*/8, enable_fast_math);
+
   // Gather all the call sites, force inline them and then delete the vector
   // function bodies.
+  //
+  // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases?
 
   std::vector<llvm::CallInst*> calls_to_inline;
-  for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) {
+  for (auto* function :
+       {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
     if (function != nullptr) {
       for (auto* user : function->users()) {
         calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
@@ -204,7 +341,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
     CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
   }
 
-  for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) {
+  for (auto* function :
+       {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
     if (function != nullptr) {
       function->eraseFromParent();
     }
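
For readers following the new EmitVectorF32LogIfNeeded above: it lowers the classic Cephes recipe to vector IR. The input is decomposed as x = m * 2^e with m shifted into [sqrt(1/2), sqrt(2)), log(m) is approximated by a degree-8 polynomial, and e * ln(2) is added back split into two constants (cephes_log_q2 + cephes_log_q1) for extra precision. Here is a scalar C++ sketch of the same recipe, illustrative only: it skips the denormal clamp and assumes 32-bit IEEE floats.

    #include <cmath>
    #include <cstdint>
    #include <cstring>
    #include <initializer_list>

    // Scalar model of the vectorized log above; not a replacement for std::log.
    float CephesLogSketch(float x) {
      if (std::isnan(x) || x < 0.0f) return NAN;  // The invalid_mask lanes.
      if (x == 0.0f) return -INFINITY;            // The iszero_mask lanes.
      // Decompose x = m * 2^e with m in [0.5, 1): read the biased exponent,
      // then overwrite the exponent field with that of 0.5 (0x3f000000).
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof bits);
      int e = static_cast<int>(bits >> 23) - 0x7f + 1;
      bits = (bits & ~0x7f800000u) | 0x3f000000u;
      std::memcpy(&x, &bits, sizeof x);
      // Shift m into [sqrt(1/2), sqrt(2)) so one polynomial covers the range.
      if (x < 0.707106781f) {
        e -= 1;
        x = x + x - 1.0f;
      } else {
        x -= 1.0f;
      }
      // Degree-8 polynomial P(x) with the cephes_log_p0..p8 coefficients.
      float y = 7.0376836292e-2f;
      for (float c : {-1.1514610310e-1f, 1.1676998740e-1f, -1.2420140846e-1f,
                      1.4249322787e-1f, -1.6668057665e-1f, 2.0000714765e-1f,
                      -2.4999993993e-1f, 3.3333331174e-1f}) {
        y = y * x + c;
      }
      float x2 = x * x;
      y = y * x * x2 - 0.5f * x2 + x;  // log(1 + x) ~ x - x^2/2 + x^3 * P(x)
      // Add e * ln(2), split as q2 + q1 = 0.693359375 - 2.12194440e-4.
      return y + e * -2.12194440e-4f + e * 0.693359375f;
    }
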
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h
index 90050c4..5553972 100644
@@ -27,6 +27,8 @@ extern const char* const kTanhV4F32SymbolName;
 extern const char* const kTanhV8F32SymbolName;
 extern const char* const kExpV4F32SymbolName;
 extern const char* const kExpV8F32SymbolName;
+extern const char* const kLogV4F32SymbolName;
+extern const char* const kLogV8F32SymbolName;
 
 // The following CPU runtime functions have LLVM-IR only implementations:
 //
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 2f4468c..64d3a51 100644
@@ -28,9 +28,6 @@ limitations under the License.
 #include "llvm/Support/Host.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
@@ -101,27 +98,6 @@ llvm::StringRef GetHostCpuName() {
   cpu_name.consume_back("-avx512");
   return cpu_name;
 }
-
-CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
-  CompilerFunctor::VectorIntrinsics intrinsics;
-#ifdef TF_XLA_HAS_SSE4_1
-  intrinsics.sse_intrinsics = true;
-#else
-  intrinsics.sse_intrinsics = false;
-#endif
-#ifdef TF_XLA_HAS_AVX
-  intrinsics.avx_intrinsics = true;
-#else
-  intrinsics.avx_intrinsics = false;
-#endif
-#ifdef TF_XLA_HAS_NEON
-  intrinsics.neon_intrinsics = true;
-#else
-  intrinsics.neon_intrinsics = false;
-#endif
-  return intrinsics;
-}
-
 }  // namespace
 
 SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
@@ -169,13 +145,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
                 orc_jit_memory_mapper::GetInstance());
           },
           [this](llvm::orc::VModuleKey K) { return symbol_resolver_; }),
-      compile_layer_(
-          object_layer_,
-          CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
-                          optimize_for_size, enable_fast_math,
-                          disable_expensive_passes, GetAvailableIntrinsics(),
-                          std::move(pre_optimization_hook),
-                          std::move(post_optimization_hook))) {
+      compile_layer_(object_layer_,
+                     CompilerFunctor(target_machine_.get(), &disassembler_,
+                                     opt_level, optimize_for_size,
+                                     enable_fast_math, disable_expensive_passes,
+                                     std::move(pre_optimization_hook),
+                                     std::move(post_optimization_hook))) {
   VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
           << " features: " << target_machine_->getTargetFeatureString().str();
 }
@@ -240,15 +215,6 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
-#ifdef TF_XLA_HAS_NEON
-  REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON);
-#endif
-#ifdef TF_XLA_HAS_SSE4_1
-  REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE);
-#endif
-#ifdef TF_XLA_HAS_AVX
-  REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX);
-#endif
   REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index ec4215b..0596e80 100644
@@ -103,15 +103,92 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
   }
 }
 
-llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, double low,
-                                         double high) {
+llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, float low,
+                                         float high) {
   AssertCorrectTypes({a});
   llvm::Type* type = a->getType();
   CHECK_LT(low, high);
   CHECK(scalar_type_->isFloatingPointTy());
   return llvm_ir::EmitFloatMin(
-      llvm_ir::EmitFloatMax(a, llvm::ConstantFP::get(type, low), ir_builder_),
-      llvm::ConstantFP::get(type, high), ir_builder_);
+      llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
+      GetConstantFloat(type, high), ir_builder_);
+}
+
+llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
+                                              llvm::Value* rhs) {
+  AssertCorrectTypes({lhs, rhs});
+  return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name()));
+}
+
+llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
+                                               llvm::Value* rhs) {
+  AssertCorrectTypes({lhs, rhs});
+  return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name()));
+}
+
+llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
+                                               llvm::Value* rhs) {
+  AssertCorrectTypes({lhs, rhs});
+  return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name()));
+}
+
+llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
+  bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
+  llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
+  return ir_builder()->CreateBitCast(
+      ir_builder()->CreateSExt(i1, integer_type, name()),
+      is_vector ? vector_type() : scalar_type(), name());
+}
+
+llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
+  CHECK(scalar_type()->isFloatingPointTy());
+  const llvm::DataLayout& data_layout =
+      ir_builder()->GetInsertBlock()->getModule()->getDataLayout();
+  int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
+  llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits);
+  if (vector) {
+    return llvm::VectorType::get(scalar_int_type, vector_size());
+  } else {
+    return scalar_int_type;
+  }
+}
+
+llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
+  CHECK_EQ(x->getType(), scalar_type());
+  return ir_builder()->CreateVectorSplat(vector_size(), x, name());
+}
+
+llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
+                                            llvm::Value* rhs) {
+  AssertCorrectTypes({lhs, rhs});
+  llvm::Type* int_type =
+      IntegerTypeForFloatSize(lhs->getType() == vector_type());
+  return ir_builder()->CreateBitCast(
+      ir_builder()->CreateAnd(
+          ir_builder()->CreateBitCast(lhs, int_type, name()),
+          ir_builder()->CreateBitCast(rhs, int_type, name()), name()),
+      vector_type());
+}
+
+llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
+  AssertCorrectTypes({lhs});
+  llvm::Type* int_type =
+      IntegerTypeForFloatSize(lhs->getType() == vector_type());
+  return ir_builder()->CreateBitCast(
+      ir_builder()->CreateNot(
+          ir_builder()->CreateBitCast(lhs, int_type, name()), name()),
+      vector_type());
+}
+
+llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
+  AssertCorrectTypes({lhs, rhs});
+  llvm::Type* int_type =
+      IntegerTypeForFloatSize(lhs->getType() == vector_type());
+  return ir_builder()->CreateBitCast(
+      ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()),
+                             ir_builder()->CreateBitCast(rhs, int_type, name()),
+                             name()),
+      vector_type(), name());
 }
 
 llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index 5c5d703..010c82f 100644
@@ -41,40 +41,82 @@ class VectorSupportLibrary {
   llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
     return Mul(ir_builder()->getInt64(lhs), rhs);
   }
-  llvm::Value* Mul(double lhs, llvm::Value* rhs) {
-    return Mul(llvm::ConstantFP::get(rhs->getType(), lhs), rhs);
+  llvm::Value* Mul(float lhs, llvm::Value* rhs) {
+    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
   }
 
   llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
   llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
     return Add(ir_builder()->getInt64(lhs), rhs);
   }
-  llvm::Value* Add(double lhs, llvm::Value* rhs) {
-    return Add(llvm::ConstantFP::get(vector_type(), lhs), rhs);
+  llvm::Value* Add(float lhs, llvm::Value* rhs) {
+    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
   }
 
   llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* Sub(llvm::Value* lhs, float rhs) {
+    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
+  }
   llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* Max(float lhs, llvm::Value* rhs) {
+    return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
+  }
   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
 
   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
     return Add(c, Mul(a, b));
   }
 
-  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, double c) {
-    return Add(llvm::ConstantFP::get(vector_type(), c), Mul(a, b));
+  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, float c) {
+    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
   }
 
-  llvm::Value* MulAdd(llvm::Value* a, double b, double c) {
-    return Add(llvm::ConstantFP::get(a->getType(), c),
-               Mul(a, llvm::ConstantFP::get(a->getType(), b)));
+  llvm::Value* MulAdd(llvm::Value* a, float b, float c) {
+    return Add(GetConstantFloat(a->getType(), c),
+               Mul(a, GetConstantFloat(a->getType(), b)));
   }
 
   llvm::Value* Floor(llvm::Value* a);
 
-  llvm::Value* Clamp(llvm::Value* a, double low, double high);
-  llvm::Value* SplatFloat(double d) {
-    return llvm::ConstantFP::get(vector_type(), d);
+  llvm::Value* Clamp(llvm::Value* a, float low, float high);
+  llvm::Value* SplatFloat(float d) {
+    return GetConstantFloat(vector_type(), d);
+  }
+
+  // These compare instructions return a floating point typed mask instead of an
+  // i1.  For instance, on a vector typed input, lanes where the predicate is
+  // true get a float with all ones and other lanes get a float with all zeros.
+  // This is slightly odd from the perspective of LLVM's type system, but it
+  // makes kernel IR generation code written using VectorSupportLibrary (its
+  // raison d'etre) less cluttered.
+
+  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* FCmpOLTMask(llvm::Value* lhs, float rhs) {
+    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
+  }
+
+  // These boolean operations operate on the bitwise values of the floating
+  // point inputs.  They return a (vector of) float(s) but like in the mask
+  // generating predicates above this type system oddity makes the kernel IR
+  // generation code less cluttered.
+  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* FloatAnd(llvm::Value* lhs, float rhs) {
+    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
+  }
+  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* FloatOr(llvm::Value* lhs, float rhs) {
+    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
+  }
+  llvm::Value* FloatNot(llvm::Value* lhs);
+  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
+    return FloatAnd(FloatNot(lhs), rhs);
+  }
+
+  llvm::Value* BroadcastScalar(llvm::Value* x);
+  llvm::Value* BroadcastScalar(float d) {
+    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
   }
 
   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
@@ -194,6 +236,17 @@ class VectorSupportLibrary {
   std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
       std::vector<llvm::Value*> vectors, llvm::Value* init_values);
 
+  llvm::Type* IntegerTypeForFloatSize(bool vector);
+  llvm::Value* I1ToFloat(llvm::Value* i1);
+  llvm::Value* GetConstantFloat(llvm::Type* type, float f) {
+    llvm::Constant* scalar_value =
+        llvm::ConstantFP::get(type->getContext(), llvm::APFloat(f));
+    if (llvm::isa<llvm::VectorType>(type)) {
+      return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
+    }
+    return scalar_value;
+  }
+
   int64 vector_size_;
   PrimitiveType primitive_type_;
   llvm::IRBuilder<>* ir_builder_;
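
A note on the mask helpers the two files above add: the float-typed masks work on bit patterns, not values. Sign-extending the i1 compare result gives an integer lane of all-zero or all-one bits, and a bitcast reinterprets it as a float, so select logic becomes FloatAnd / FloatAndNot / FloatOr with no branches. A scalar C++ illustration of the trick (assumes 32-bit IEEE floats; the names here are hypothetical, not part of XLA):

    #include <cstdint>
    #include <cstring>

    static uint32_t FloatBits(float f) {
      uint32_t b;
      std::memcpy(&b, &f, sizeof b);
      return b;
    }

    static float BitsToFloat(uint32_t b) {
      float f;
      std::memcpy(&f, &b, sizeof f);
      return f;
    }

    // Mirrors FCmpOLTMask + I1ToFloat: the i1 compare result is sign-extended
    // to 0x00000000 or 0xFFFFFFFF and reinterpreted as a float.
    float LessThanMask(float lhs, float rhs) {
      return BitsToFloat(lhs < rhs ? 0xFFFFFFFFu : 0u);
    }

    // Branch-free select built from FloatAnd / FloatAndNot / FloatOr; this is
    // how the log emitter merges the NaN and -inf lanes into its result.
    float Select(float mask, float if_true, float if_false) {
      const uint32_t m = FloatBits(mask);
      return BitsToFloat((m & FloatBits(if_true)) | (~m & FloatBits(if_false)));
    }
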
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 87ac773..7e90050 100644
@@ -2121,6 +2121,44 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
                              error_spec_);
 }
 
+XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
+  // The input tensor is large enough to exercise the vectorized log
+  // implementation on XLA CPU.
+  ComputationBuilder builder(client_, TestName());
+
+  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
+      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
+       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
+       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
+       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
+       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
+       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
+       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
+       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
+       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
+       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
+       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
+       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
+       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
+       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
+       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
+                          client_->TransferToServer(*input_literal));
+
+  auto input = builder.Parameter(0, input_literal->shape(), "input");
+  builder.Log(input);
+
+  std::vector<float> expected_result;
+  int64 input_size = input_literal->shape().dimensions(0);
+  expected_result.reserve(input_size);
+  for (int64 i = 0; i < input_size; i++) {
+    expected_result.push_back(std::log(input_literal->Get<float>({i})));
+  }
+
+  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
+                             error_spec_);
+}
+
 XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
   // a ------ (add) --------- (add)
   //         /               /