Register tf.concat with uint8 data type.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 18:14:25 +0000 (11:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 18:16:51 +0000 (11:16 -0700)
PiperOrigin-RevId: 192154998

tensorflow/core/kernels/concat_lib_gpu.cc
tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
tensorflow/core/kernels/concat_op.cc
tensorflow/core/kernels/concat_op_test.cc

index d8643c0..93e392d 100644 (file)
@@ -118,6 +118,7 @@ TF_CALL_complex128(REGISTER);
 TF_CALL_int64(REGISTER);
 TF_CALL_bfloat16(REGISTER);
 TF_CALL_bool(REGISTER);
+TF_CALL_uint8(REGISTER);
 
 #undef REGISTER
 
index 0f7adaf..a561d91 100644 (file)
@@ -202,6 +202,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
 TF_CALL_complex64(REGISTER_GPUCONCAT32);
 TF_CALL_complex128(REGISTER_GPUCONCAT32);
 TF_CALL_int64(REGISTER_GPUCONCAT32);
+TF_CALL_uint8(REGISTER_GPUCONCAT32);
 REGISTER_GPUCONCAT32(bfloat16);
 REGISTER_GPUCONCAT32(bool);
 
@@ -209,6 +210,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
 TF_CALL_complex64(REGISTER_GPUCONCAT64);
 TF_CALL_complex128(REGISTER_GPUCONCAT64);
 TF_CALL_int64(REGISTER_GPUCONCAT64);
+TF_CALL_uint8(REGISTER_GPUCONCAT64);
 REGISTER_GPUCONCAT64(bfloat16);
 REGISTER_GPUCONCAT64(bool);
 
@@ -216,6 +218,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
 TF_CALL_complex64(REGISTER_GPU32);
 TF_CALL_complex128(REGISTER_GPU32);
 TF_CALL_int64(REGISTER_GPU32);
+TF_CALL_uint8(REGISTER_GPU32);
 REGISTER_GPU32(bfloat16);
 REGISTER_GPU32(bool);
 
@@ -223,6 +226,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
 TF_CALL_complex64(REGISTER_GPU64);
 TF_CALL_complex128(REGISTER_GPU64);
 TF_CALL_int64(REGISTER_GPU64);
+TF_CALL_uint8(REGISTER_GPU64);
 REGISTER_GPU64(bfloat16);
 REGISTER_GPU64(bool);
 
index f167663..a87b63f 100644 (file)
@@ -212,6 +212,7 @@ REGISTER_CONCAT(qint32);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 REGISTER_GPU(bfloat16);
+TF_CALL_uint8(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
index e3ba8ae..39b44b2 100644 (file)
@@ -78,6 +78,9 @@ static void BM_ConcatDim1Float(int iters, int dim2) {
 BENCHMARK(BM_ConcatDim0Float)->Arg(1000)->Arg(100000)->Arg(1000000);
 BENCHMARK(BM_ConcatDim1Float)->Arg(1000)->Arg(100000)->Arg(1000000);
 
+static void BM_ConcatDim1uint8(int iters, int dim2) {
+  ConcatHelper<uint8>(iters, 1, dim2);
+}
 static void BM_ConcatDim1int16(int iters, int dim2) {
   ConcatHelper<int16>(iters, 1, dim2);
 }
@@ -85,6 +88,7 @@ static void BM_ConcatDim1bfloat16(int iters, int dim2) {
   ConcatHelper<bfloat16>(iters, 1, dim2);
 }
 
+BENCHMARK(BM_ConcatDim1uint8)->Arg(1000)->Arg(100000)->Arg(1000000);
 BENCHMARK(BM_ConcatDim1int16)->Arg(1000)->Arg(100000)->Arg(1000000);
 BENCHMARK(BM_ConcatDim1bfloat16)->Arg(1000)->Arg(100000)->Arg(1000000);