const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
// not the most efficient way to do this sum
for (int i = 1; i < rows_in_this_warp; ++i) {
- value_type tmp =
- cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff);
+ value_type tmp = cub::ShuffleIndex<32, value_type>(
+ sum, static_cast<int>(threadIdx.x + i * num_cols), 0xffffffff);
if (lane < num_cols) sum = op(sum, tmp);
}
tf_http_archive(
name = "cub_archive",
urls = [
- "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip",
- "https://github.com/NVlabs/cub/archive/1.7.4.zip",
+ "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
+ "https://github.com/NVlabs/cub/archive/1.8.0.zip",
],
- sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31",
- strip_prefix = "cub-1.7.4",
+ sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+ strip_prefix = "cub-1.8.0",
build_file = str(Label("//third_party:cub.BUILD")),
- # TODO: remove the patch when upstream fix is accepted and released.
- # PR with a fix: https://github.com/NVlabs/cub/pull/125
- patch_file = str(Label("//third_party/cub:fix_compilation_in_clang.patch")),
)
tf_http_archive(
+++ /dev/null
-From 565b77f7c82048871a4d5e3e506dc663d53cd469 Mon Sep 17 00:00:00 2001
-From: Ilya Biryukov <ibiryukov@google.com>
-Date: Fri, 26 Jan 2018 18:46:06 +0100
-Subject: [PATCH] Added missing 'template' keyword.
-
-To unbreak compilation with clang.
----
- cub/device/dispatch/dispatch_radix_sort.cuh | 2 +-
- 1 file changed, 1 insertion(+), 1 deletion(-)
-
-diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh
-index 7fbc621f..f622e212 100644
---- a/cub/device/dispatch/dispatch_radix_sort.cuh
-+++ b/cub/device/dispatch/dispatch_radix_sort.cuh
-@@ -104,7 +104,7 @@ __global__ void DeviceRadixSortUpsweepKernel(
- CTA_SYNC();
-
- // Write out digit counts (striped)
-- upsweep.ExtractCounts<IS_DESCENDING>(d_spine, gridDim.x, blockIdx.x);
-+ upsweep.template ExtractCounts<IS_DESCENDING>(d_spine, gridDim.x, blockIdx.x);
- }
-
-