From: A. Unique TensorFlower Date: Mon, 9 Apr 2018 08:59:50 +0000 (-0700) Subject: Enabling fp16 for NCCL 1 and 2. X-Git-Tag: tflite-v0.1.7~16^2^2~46 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5c469e6bafb479ef110b2f02f070507a3711664d;p=platform%2Fupstream%2Ftensorflow.git Enabling fp16 for NCCL 1 and 2. PiperOrigin-RevId: 192096789 --- diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index 913935b..b9b482a 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -76,6 +76,8 @@ struct NcclManager::Communicator { namespace { ncclDataType_t ToNcclType(DataType t) { switch (t) { + case DT_HALF: + return ncclHalf; case DT_FLOAT: return ncclFloat; case DT_DOUBLE: diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc index 985b2ba..06ca65e 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -48,35 +48,9 @@ static std::vector GetGPUDevices() { return gpus; } +template class NcclManagerTest : public ::testing::Test { - protected: - static void SetUpTestCase() { - setenv("NCCL_DEBUG", "INFO", 1 /* replace */); - devices = new std::vector(GetGPUDevices()); - CHECK(!devices->empty()); - LOG(ERROR) << "Running test with " << devices->size() << " gpus"; - } - static void TearDownTestCase() { - for (auto device : *devices) delete device; - delete devices; - } - - static Allocator* gpu_allocator(BaseGPUDevice* device) { - return device->GetStepAllocator(AllocatorAttributes(), - nullptr /* step_resource_manager */); - } - - static std::vector* devices; - - template - perftools::gputools::DeviceMemory AsDeviceMemory( - const Scalar* cuda_memory) { - perftools::gputools::DeviceMemoryBase wrapped( - const_cast(cuda_memory)); - perftools::gputools::DeviceMemory typed(wrapped); - return typed; - } - + public: // A single all-reduce to apply. struct TestCase { string key; @@ -89,42 +63,52 @@ class NcclManagerTest : public ::testing::Test { int num_completed = 0; }; + static void SetUpTestCase() { + setenv("NCCL_DEBUG", "INFO", 1 /* replace */); + devices_ = new std::vector(GetGPUDevices()); + CHECK(!devices_->empty()); + LOG(ERROR) << "Running test with " << devices_->size() << " gpus"; + } + + static void TearDownTestCase() { + for (auto device : *devices_) delete device; + delete devices_; + } + TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op, TensorShape shape, float value_offset) { TestCase* test_case = new TestCase(); - test_case->expected = Tensor(DT_FLOAT, shape); + test_case->expected = Tensor(data_type_, shape); if (reduction_op == ncclProd) { - test::FillFn(&test_case->expected, [](int) { return 1; }); + test::FillFn(&test_case->expected, + [](int) { return static_cast(1); }); } else if (reduction_op == ncclSum) { - test::FillFn(&test_case->expected, [](int) { return 0; }); + test::FillFn(&test_case->expected, + [](int) { return static_cast(0); }); } else if (reduction_op == ncclMax) { - test::FillFn(&test_case->expected, [](int) { - return -1 * std::numeric_limits::max(); - }); + test::FillFn(&test_case->expected, [](int) { return -max_; }); } else if (reduction_op == ncclMin) { - test::FillFn(&test_case->expected, [](int) { - return std::numeric_limits::max(); - }); + test::FillFn(&test_case->expected, [](int) { return max_; }); } else { LOG(FATAL) << "Invalid reduction_op " << reduction_op; } - int mult = 1; - for (int i = 0; i < num_ranks; ++i) { - auto* device = devices->at(i % devices->size()); + float value_scale = 0.01; // Small scale to avoid fp16 overflow. + for (int rank = 0; rank < num_ranks; ++rank) { + auto* device = GetDevice(rank); auto* stream = device->tensorflow_gpu_device_info()->stream; - Tensor in_cpu(DT_FLOAT, shape); - test::FillFn(&in_cpu, [mult, value_offset](int index) { - return value_offset + (index + 1) * mult; + Tensor in_cpu(data_type_, shape); + test::FillFn(&in_cpu, [&](int index) { + return static_cast((index + 1) * value_scale + value_offset); }); for (int j = 0; j < shape.num_elements(); ++j) { - auto in_val = in_cpu.flat()(j); - auto out_expr = test_case->expected.flat(); + auto in_val = in_cpu.flat()(j); + auto out_expr = test_case->expected.template flat(); if (reduction_op == ncclProd) { - out_expr(j) *= in_val; + out_expr(j) = out_expr(j) * in_val; } else if (reduction_op == ncclSum) { - out_expr(j) += in_val; + out_expr(j) = out_expr(j) + in_val; } else if (reduction_op == ncclMax) { if (in_val > out_expr(j)) { out_expr(j) = in_val; @@ -136,26 +120,18 @@ class NcclManagerTest : public ::testing::Test { } } - mult *= 10; - test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape); - test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape); + value_scale *= 10; + test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape); + test_case->outs.emplace_back(GpuAllocator(device), data_type_, shape); const Tensor& in_gpu = test_case->ins.back(); - auto in_gpu_mem = AsDeviceMemory(in_gpu.flat().data()); - stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat().data(), + auto in_gpu_mem = AsDeviceMemory(in_gpu.flat().data()); + stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat().data(), in_cpu.TotalBytes()); } return test_case; } - NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) { - return [this, test_case](Status s) { - mutex_lock l(test_case->mu); - ++test_case->num_completed; - test_case->final_status.Update(s); - }; - } - void VerifyResults(const string& case_label, TestCase* test_case) { // Wait for the done callback to be called. { @@ -168,41 +144,84 @@ class NcclManagerTest : public ::testing::Test { test_case->mu.unlock(); } // Copy memory to host and verify. - for (int i = 0; i < test_case->outs.size(); ++i) { - auto* device = devices->at(i % devices->size()); + for (int rank = 0; rank < test_case->outs.size(); ++rank) { + auto* device = GetDevice(rank); auto* stream = device->tensorflow_gpu_device_info()->stream; - const Tensor& out_gpu = test_case->outs[i]; - Tensor out_cpu(DT_FLOAT, out_gpu.shape()); - auto out_gpu_mem = AsDeviceMemory(out_gpu.flat().data()); - stream->ThenMemcpy(out_cpu.flat().data(), out_gpu_mem, + const Tensor& out_gpu = test_case->outs[rank]; + Tensor out_cpu(data_type_, out_gpu.shape()); + auto out_gpu_mem = AsDeviceMemory(out_gpu.flat().data()); + stream->ThenMemcpy(out_cpu.flat().data(), out_gpu_mem, out_cpu.TotalBytes()); SE_ASSERT_OK(stream->BlockHostUntilDone()); - test::ExpectTensorEqual(test_case->expected, out_cpu); + test::ExpectTensorNear(test_case->expected, out_cpu, 0.01); } } + + NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) { + return [this, test_case](Status s) { + mutex_lock l(test_case->mu); + ++test_case->num_completed; + test_case->final_status.Update(s); + }; + } + + static BaseGPUDevice* GetDevice(size_t rank) { + return devices_->at(rank % devices_->size()); + } + + private: + static Allocator* GpuAllocator(BaseGPUDevice* device) { + return device->GetStepAllocator(AllocatorAttributes(), + nullptr /* step_resource_manager */); + } + + static perftools::gputools::DeviceMemory AsDeviceMemory( + const Scalar* cuda_memory) { + perftools::gputools::DeviceMemoryBase wrapped( + const_cast(cuda_memory)); + perftools::gputools::DeviceMemory typed(wrapped); + return typed; + } + + private: + static std::vector* devices_; + static const DataType data_type_; + static const Scalar max_; }; -std::vector* NcclManagerTest::devices = nullptr; + +template +std::vector* NcclManagerTest::devices_ = nullptr; +template +const DataType NcclManagerTest::data_type_ = + DataTypeToEnum::value; +template +const Scalar NcclManagerTest::max_ = + Eigen::NumTraits::highest(); + +// Instantiate tests for float and half. +using TypeList = ::testing::Types; +TYPED_TEST_CASE(NcclManagerTest, TypeList); // Test basic sum reduction. -TEST_F(NcclManagerTest, BasicSumReduction) { +TYPED_TEST(NcclManagerTest, BasicSumReduction) { const int num_ranks = 3; for (int op = 0; op < 4; ++op) { ncclRedOp_t reduction_op = static_cast(op); - std::unique_ptr test_case( - MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0)); - for (int device_num = 0; device_num < num_ranks; ++device_num) { - auto* device = devices->at(device_num % devices->size()); + std::unique_ptr test_case( + this->MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0.0f)); + for (int rank = 0; rank < num_ranks; ++rank) { + auto* device = this->GetDevice(rank); auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr; auto* stream = device->tensorflow_gpu_device_info()->stream; NcclManager::instance()->AddToAllReduce( num_ranks, "allreduce", reduction_op, device->executor(), - device->gpu_id(), event_mgr, stream, &test_case->ins[device_num], - &test_case->outs[device_num], CreateDoneCallback(test_case.get())); + device->gpu_id(), event_mgr, stream, &test_case->ins[rank], + &test_case->outs[rank], this->CreateDoneCallback(test_case.get())); } LOG(ERROR) << "Verifying results"; - VerifyResults("test_case", test_case.get()); + this->VerifyResults("test_case", test_case.get()); } } @@ -213,7 +232,7 @@ TEST_F(NcclManagerTest, BasicSumReduction) { // with num_ranks > devices->size(), for some GPUs (e.g. K20m). // To test the higher settings, increase num_ranks, // num_collectives_per_iteration and time_limit_micros. -TEST_F(NcclManagerTest, MultipleCallers) { +TYPED_TEST(NcclManagerTest, MultipleCallers) { const int num_ranks = 1; // 2; const int num_collectives_per_iteration = 1; // 1000; const int num_threads = 3; @@ -223,49 +242,49 @@ TEST_F(NcclManagerTest, MultipleCallers) { srand(Env::Default()->NowMicros()); for (;;) { - std::vector> case_and_device_num; - std::vector> test_cases; + std::vector> case_and_rank; + std::vector> test_cases; for (int i = 0; i < num_collectives_per_iteration; ++i) { - test_cases.emplace_back( - MakeTestCase(num_ranks, ncclSum, - TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i)); + test_cases.emplace_back(this->MakeTestCase( + num_ranks, ncclSum, TensorShape({100, i % 5 + 1, i % 3 + 1}), + 1.1f * i)); for (int j = 0; j < num_ranks; ++j) { - case_and_device_num.emplace_back(i, j); + case_and_rank.emplace_back(i, j); } } - for (int i = 0; i < num_ranks; ++i) { - auto* device = devices->at(i % devices->size()); + for (int rank = 0; rank < num_ranks; ++rank) { + auto* device = this->GetDevice(rank); auto* stream = device->tensorflow_gpu_device_info()->stream; SE_ASSERT_OK(stream->BlockHostUntilDone()); } - std::shuffle(case_and_device_num.begin(), case_and_device_num.end(), + std::shuffle(case_and_rank.begin(), case_and_rank.end(), std::mt19937(std::random_device()())); - mutex mu; // guards case_and_device_num. + mutex mu; // guards case_and_rank. std::unique_ptr pool( new thread::ThreadPool(Env::Default(), "test", num_threads)); - const int to_schedule = case_and_device_num.size(); + const int to_schedule = case_and_rank.size(); for (int i = 0; i < to_schedule; ++i) { auto fn = [&]() { - int device_num; + int rank; int test_num; { mutex_lock l(mu); - test_num = case_and_device_num.back().first; - device_num = case_and_device_num.back().second; - case_and_device_num.pop_back(); + test_num = case_and_rank.back().first; + rank = case_and_rank.back().second; + case_and_rank.pop_back(); } - auto* device = devices->at(device_num % devices->size()); + auto* device = this->GetDevice(rank); auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr; auto* stream = device->tensorflow_gpu_device_info()->stream; - TestCase* test_case = test_cases[test_num].get(); + typename TestFixture::TestCase* test_case = test_cases[test_num].get(); NcclManager::instance()->AddToAllReduce( num_ranks, strings::StrCat("allreduce", test_num), ncclSum, device->executor(), device->gpu_id(), event_mgr, stream, - &test_case->ins[device_num], &test_case->outs[device_num], - CreateDoneCallback(test_case)); + &test_case->ins[rank], &test_case->outs[rank], + this->CreateDoneCallback(test_case)); }; pool->Schedule(fn); } @@ -274,7 +293,8 @@ TEST_F(NcclManagerTest, MultipleCallers) { LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration << " collectives"; for (int i = 0; i < test_cases.size(); ++i) { - VerifyResults(strings::StrCat("collective", i), test_cases[i].get()); + this->VerifyResults(strings::StrCat("collective", i), + test_cases[i].get()); } int64 delta = Env::Default()->NowMicros() - start; diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc index 8eb804c..a353a34 100644 --- a/tensorflow/contrib/nccl/ops/nccl_ops.cc +++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc @@ -25,7 +25,7 @@ REGISTER_OP("NcclAllReduce") .Input("input: T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("num_devices: int") .Attr("shared_name: string") .SetIsStateful() @@ -51,7 +51,7 @@ REGISTER_OP("NcclReduce") .Input("input: num_devices * T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("num_devices: int") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape) @@ -69,7 +69,7 @@ reduction: the reduction operation to perform. REGISTER_OP("_NcclReduceSend") .Input("input: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("num_devices: int") .Attr("shared_name: string") .SetIsStateful() @@ -92,7 +92,7 @@ REGISTER_OP("_NcclReduceRecv") .Input("input: T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("num_devices: int") .Attr("shared_name: string") .SetIsStateful() @@ -118,7 +118,7 @@ shared_name: Identifier that is shared between ops of the same reduce. REGISTER_OP("NcclBroadcast") .Input("input: T") .Output("output: T") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("shape: shape") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape) @@ -135,7 +135,7 @@ shape: The shape of the input tensor. REGISTER_OP("_NcclBroadcastSend") .Input("input: T") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("num_devices: int") .Attr("shared_name: string") .SetIsStateful() @@ -157,7 +157,7 @@ shared_name: Identifier that is shared between ops of the same broadcast. REGISTER_OP("_NcclBroadcastRecv") .Input("shape: int32") .Output("output: T") - .Attr("T: {float, float64, int32, int64}") + .Attr("T: {half, float, float64, int32, int64}") .Attr("num_devices: int") .Attr("shared_name: string") .SetIsStateful() diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 98fe394..423a868 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -72,7 +72,7 @@ class NcclTestCase(test.TestCase): two. device_sets: Tuple of virtual devices to run test on. """ - for dtype in [np.float32, np.int32, np.int64, np.float64]: + for dtype in [np.float16, np.float32, np.int32, np.int64, np.float64]: # Create session inside outer loop to test use of # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: