* CUB updated to 1.8.0
author: A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Feb 2018 22:22:53 +0000 (14:22 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 22:30:35 +0000 (14:30 -0800)
* updated ShuffleIndex because of API change

PiperOrigin-RevId: 186822637

tensorflow/core/kernels/reduction_gpu_kernels.cu.h
tensorflow/workspace.bzl
third_party/cub/BUILD [deleted file]
third_party/cub/fix_compilation_in_clang.patch [deleted file]

index 15ae4c1fc53b2b9bfe1d6085d2ecbc3659705b47..9237fa51d885c633675146191dc384dd87d8ab22 100644 (file)
@@ -280,8 +280,8 @@ __global__ void ColumnReduceMax16ColumnsKernel(
   const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
   // not the most efficient way to do this sum
   for (int i = 1; i < rows_in_this_warp; ++i) {
-    value_type tmp =
-        cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff);
+    value_type tmp = cub::ShuffleIndex<32, value_type>(
+        sum, static_cast<int>(threadIdx.x + i * num_cols), 0xffffffff);
     if (lane < num_cols) sum = op(sum, tmp);
   }
 
index 2b370ffbacacd95202e3d7cf47175006c2e0efb4..d6ac7be8b5e5042a28d8a065f0b93328dda0847c 100644 (file)
@@ -666,15 +666,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   tf_http_archive(
       name = "cub_archive",
       urls = [
-          "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip",
-          "https://github.com/NVlabs/cub/archive/1.7.4.zip",
+          "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
+          "https://github.com/NVlabs/cub/archive/1.8.0.zip",
       ],
-      sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31",
-      strip_prefix = "cub-1.7.4",
+      sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+      strip_prefix = "cub-1.8.0",
       build_file = str(Label("//third_party:cub.BUILD")),
-      # TODO: remove the patch when upstream fix is accepted and released.
-      #       PR with a fix: https://github.com/NVlabs/cub/pull/125
-      patch_file = str(Label("//third_party/cub:fix_compilation_in_clang.patch")),
   )
 
   tf_http_archive(
diff --git a/third_party/cub/BUILD b/third_party/cub/BUILD
deleted file mode 100644 (file)
index e69de29..0000000
diff --git a/third_party/cub/fix_compilation_in_clang.patch b/third_party/cub/fix_compilation_in_clang.patch
deleted file mode 100644 (file)
index 384e674..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-From 565b77f7c82048871a4d5e3e506dc663d53cd469 Mon Sep 17 00:00:00 2001
-From: Ilya Biryukov <ibiryukov@google.com>
-Date: Fri, 26 Jan 2018 18:46:06 +0100
-Subject: [PATCH] Added missing 'template' keyword.
-
-To unbreak compilation with clang.
----
- cub/device/dispatch/dispatch_radix_sort.cuh | 2 +-
- 1 file changed, 1 insertion(+), 1 deletion(-)
-
-diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh
-index 7fbc621f..f622e212 100644
---- a/cub/device/dispatch/dispatch_radix_sort.cuh
-+++ b/cub/device/dispatch/dispatch_radix_sort.cuh
-@@ -104,7 +104,7 @@ __global__ void DeviceRadixSortUpsweepKernel(
-     CTA_SYNC();
-     // Write out digit counts (striped)
--    upsweep.ExtractCounts<IS_DESCENDING>(d_spine, gridDim.x, blockIdx.x);
-+    upsweep.template ExtractCounts<IS_DESCENDING>(d_spine, gridDim.x, blockIdx.x);
- }