Add bfloat16 support for CPU ops.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 19 Mar 2018 20:26:19 +0000 (13:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 20:31:10 +0000 (13:31 -0700)
PiperOrigin-RevId: 189631659

tensorflow/core/kernels/cwise_op_div.cc
tensorflow/core/kernels/cwise_op_less.cc
tensorflow/core/kernels/cwise_op_less_equal.cc
tensorflow/core/kernels/cwise_op_minimum.cc
tensorflow/core/kernels/cwise_op_sqrt.cc

index c71c756..b12652f 100644 (file)
@@ -16,14 +16,14 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER5(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double,
-          complex64, complex128);
+REGISTER6(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double,
+          bfloat16, complex64, complex128);
 REGISTER5(BinaryOp, CPU, "Div", functor::safe_div, uint8, uint16, int16, int32,
           int64);
 REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
           int32, int64);
-REGISTER5(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
-          complex64, complex128);
+REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
+          bfloat16, complex64, complex128);
 #if GOOGLE_CUDA
 REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
           uint16, int16, int64, complex64, complex128);
index 00cdecd..5759681 100644 (file)
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER8(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
-          int32, int64, uint8, int8, int16);
+REGISTER9(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
+          bfloat16, int32, int64, uint8, int8, int16);
 #if GOOGLE_CUDA
 REGISTER7(BinaryOp, GPU, "Less", functor::less, float, Eigen::half, double,
           int64, uint8, int8, int16);
index 11806c5..499200d 100644 (file)
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER8(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
-          double, int32, int64, uint8, int8, int16);
+REGISTER9(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
+          bfloat16, double, int32, int64, uint8, int8, int16);
 #if GOOGLE_CUDA
 REGISTER7(BinaryOp, GPU, "LessEqual", functor::less_equal, float, Eigen::half,
           double, int64, uint8, int8, int16);
index dff83df..9bc3700 100644 (file)
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER5(BinaryOp, CPU, "Minimum", functor::minimum, float, Eigen::half,
-          double, int32, int64);
+REGISTER6(BinaryOp, CPU, "Minimum", functor::minimum, float, Eigen::half,
+          bfloat16, double, int32, int64);
 #if GOOGLE_CUDA
 REGISTER4(BinaryOp, GPU, "Minimum", functor::minimum, float, Eigen::half,
           double, int64);
index 4977561..2050707 100644 (file)
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double,
-          complex64, complex128);
+REGISTER6(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double,
+          bfloat16, complex64, complex128);
 
 #if GOOGLE_CUDA
 REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
@@ -27,8 +27,8 @@ REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
 REGISTER2(UnaryOp, SYCL, "Sqrt", functor::sqrt, float, double);
 #endif  // TENSORFLOW_USE_SYCL
 
-REGISTER5(SimpleBinaryOp, CPU, "SqrtGrad", functor::sqrt_grad, float,
-          Eigen::half, double, complex64, complex128);
+REGISTER6(SimpleBinaryOp, CPU, "SqrtGrad", functor::sqrt_grad, float,
+          Eigen::half, bfloat16, double, complex64, complex128);
 #if GOOGLE_CUDA
 REGISTER3(SimpleBinaryOp, GPU, "SqrtGrad", functor::sqrt_grad, float,
           Eigen::half, double);