Enabling fp16 for NCCL 1 and 2.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 08:59:50 +0000 (01:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 09:02:12 +0000 (02:02 -0700)
PiperOrigin-RevId: 192096789

tensorflow/contrib/nccl/kernels/nccl_manager.cc
tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
tensorflow/contrib/nccl/ops/nccl_ops.cc
tensorflow/contrib/nccl/python/ops/nccl_ops_test.py

index 913935b..b9b482a 100644 (file)
@@ -76,6 +76,8 @@ struct NcclManager::Communicator {
 namespace {
 ncclDataType_t ToNcclType(DataType t) {
   switch (t) {
+    case DT_HALF:
+      return ncclHalf;
     case DT_FLOAT:
       return ncclFloat;
     case DT_DOUBLE:
index 985b2ba..06ca65e 100644 (file)
@@ -48,35 +48,9 @@ static std::vector<BaseGPUDevice*> GetGPUDevices() {
   return gpus;
 }
 
+template <typename Scalar>
 class NcclManagerTest : public ::testing::Test {
- protected:
-  static void SetUpTestCase() {
-    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
-    devices = new std::vector<BaseGPUDevice*>(GetGPUDevices());
-    CHECK(!devices->empty());
-    LOG(ERROR) << "Running test with " << devices->size() << " gpus";
-  }
-  static void TearDownTestCase() {
-    for (auto device : *devices) delete device;
-    delete devices;
-  }
-
-  static Allocator* gpu_allocator(BaseGPUDevice* device) {
-    return device->GetStepAllocator(AllocatorAttributes(),
-                                    nullptr /* step_resource_manager */);
-  }
-
-  static std::vector<BaseGPUDevice*>* devices;
-
-  template <typename Scalar>
-  perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
-      const Scalar* cuda_memory) {
-    perftools::gputools::DeviceMemoryBase wrapped(
-        const_cast<Scalar*>(cuda_memory));
-    perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
-    return typed;
-  }
-
+ public:
   // A single all-reduce to apply.
   struct TestCase {
     string key;
@@ -89,42 +63,52 @@ class NcclManagerTest : public ::testing::Test {
     int num_completed = 0;
   };
 
+  static void SetUpTestCase() {
+    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
+    devices_ = new std::vector<BaseGPUDevice*>(GetGPUDevices());
+    CHECK(!devices_->empty());
+    LOG(ERROR) << "Running test with " << devices_->size() << " gpus";
+  }
+
+  static void TearDownTestCase() {
+    for (auto device : *devices_) delete device;
+    delete devices_;
+  }
+
   TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op,
                          TensorShape shape, float value_offset) {
     TestCase* test_case = new TestCase();
-    test_case->expected = Tensor(DT_FLOAT, shape);
+    test_case->expected = Tensor(data_type_, shape);
     if (reduction_op == ncclProd) {
-      test::FillFn<float>(&test_case->expected, [](int) { return 1; });
+      test::FillFn<Scalar>(&test_case->expected,
+                           [](int) { return static_cast<Scalar>(1); });
     } else if (reduction_op == ncclSum) {
-      test::FillFn<float>(&test_case->expected, [](int) { return 0; });
+      test::FillFn<Scalar>(&test_case->expected,
+                           [](int) { return static_cast<Scalar>(0); });
     } else if (reduction_op == ncclMax) {
-      test::FillFn<float>(&test_case->expected, [](int) {
-        return -1 * std::numeric_limits<float>::max();
-      });
+      test::FillFn<Scalar>(&test_case->expected, [](int) { return -max_; });
     } else if (reduction_op == ncclMin) {
-      test::FillFn<float>(&test_case->expected, [](int) {
-        return std::numeric_limits<float>::max();
-      });
+      test::FillFn<Scalar>(&test_case->expected, [](int) { return max_; });
     } else {
       LOG(FATAL) << "Invalid reduction_op " << reduction_op;
     }
 
-    int mult = 1;
-    for (int i = 0; i < num_ranks; ++i) {
-      auto* device = devices->at(i % devices->size());
+    float value_scale = 0.01;  // Small scale to avoid fp16 overflow.
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      auto* device = GetDevice(rank);
       auto* stream = device->tensorflow_gpu_device_info()->stream;
 
-      Tensor in_cpu(DT_FLOAT, shape);
-      test::FillFn<float>(&in_cpu, [mult, value_offset](int index) {
-        return value_offset + (index + 1) * mult;
+      Tensor in_cpu(data_type_, shape);
+      test::FillFn<Scalar>(&in_cpu, [&](int index) {
+        return static_cast<Scalar>((index + 1) * value_scale + value_offset);
       });
       for (int j = 0; j < shape.num_elements(); ++j) {
-        auto in_val = in_cpu.flat<float>()(j);
-        auto out_expr = test_case->expected.flat<float>();
+        auto in_val = in_cpu.flat<Scalar>()(j);
+        auto out_expr = test_case->expected.template flat<Scalar>();
         if (reduction_op == ncclProd) {
-          out_expr(j) *= in_val;
+          out_expr(j) = out_expr(j) * in_val;
         } else if (reduction_op == ncclSum) {
-          out_expr(j) += in_val;
+          out_expr(j) = out_expr(j) + in_val;
         } else if (reduction_op == ncclMax) {
           if (in_val > out_expr(j)) {
             out_expr(j) = in_val;
@@ -136,26 +120,18 @@ class NcclManagerTest : public ::testing::Test {
         }
       }
 
-      mult *= 10;
-      test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
-      test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
+      value_scale *= 10;
+      test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape);
+      test_case->outs.emplace_back(GpuAllocator(device), data_type_, shape);
 
       const Tensor& in_gpu = test_case->ins.back();
-      auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data());
-      stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(),
+      auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<Scalar>().data());
+      stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<Scalar>().data(),
                          in_cpu.TotalBytes());
     }
     return test_case;
   }
 
-  NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
-    return [this, test_case](Status s) {
-      mutex_lock l(test_case->mu);
-      ++test_case->num_completed;
-      test_case->final_status.Update(s);
-    };
-  }
-
   void VerifyResults(const string& case_label, TestCase* test_case) {
     // Wait for the done callback to be called.
     {
@@ -168,41 +144,84 @@ class NcclManagerTest : public ::testing::Test {
       test_case->mu.unlock();
     }
     // Copy memory to host and verify.
-    for (int i = 0; i < test_case->outs.size(); ++i) {
-      auto* device = devices->at(i % devices->size());
+    for (int rank = 0; rank < test_case->outs.size(); ++rank) {
+      auto* device = GetDevice(rank);
       auto* stream = device->tensorflow_gpu_device_info()->stream;
-      const Tensor& out_gpu = test_case->outs[i];
-      Tensor out_cpu(DT_FLOAT, out_gpu.shape());
-      auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data());
-      stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem,
+      const Tensor& out_gpu = test_case->outs[rank];
+      Tensor out_cpu(data_type_, out_gpu.shape());
+      auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<Scalar>().data());
+      stream->ThenMemcpy(out_cpu.flat<Scalar>().data(), out_gpu_mem,
                          out_cpu.TotalBytes());
       SE_ASSERT_OK(stream->BlockHostUntilDone());
-      test::ExpectTensorEqual<float>(test_case->expected, out_cpu);
+      test::ExpectTensorNear<Scalar>(test_case->expected, out_cpu, 0.01);
     }
   }
+
+  NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
+    return [this, test_case](Status s) {
+      mutex_lock l(test_case->mu);
+      ++test_case->num_completed;
+      test_case->final_status.Update(s);
+    };
+  }
+
+  static BaseGPUDevice* GetDevice(size_t rank) {
+    return devices_->at(rank % devices_->size());
+  }
+
+ private:
+  static Allocator* GpuAllocator(BaseGPUDevice* device) {
+    return device->GetStepAllocator(AllocatorAttributes(),
+                                    nullptr /* step_resource_manager */);
+  }
+
+  static perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
+      const Scalar* cuda_memory) {
+    perftools::gputools::DeviceMemoryBase wrapped(
+        const_cast<Scalar*>(cuda_memory));
+    perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+    return typed;
+  }
+
+ private:
+  static std::vector<BaseGPUDevice*>* devices_;
+  static const DataType data_type_;
+  static const Scalar max_;
 };
-std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr;
+
+template <typename Scalar>
+std::vector<BaseGPUDevice*>* NcclManagerTest<Scalar>::devices_ = nullptr;
+template <typename Scalar>
+const DataType NcclManagerTest<Scalar>::data_type_ =
+    DataTypeToEnum<Scalar>::value;
+template <typename Scalar>
+const Scalar NcclManagerTest<Scalar>::max_ =
+    Eigen::NumTraits<Scalar>::highest();
+
+// Instantiate tests for float and half.
+using TypeList = ::testing::Types<float, Eigen::half>;
+TYPED_TEST_CASE(NcclManagerTest, TypeList);
 
 // Test basic sum reduction.
-TEST_F(NcclManagerTest, BasicSumReduction) {
+TYPED_TEST(NcclManagerTest, BasicSumReduction) {
   const int num_ranks = 3;
 
   for (int op = 0; op < 4; ++op) {
     ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
-    std::unique_ptr<TestCase> test_case(
-        MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0));
-    for (int device_num = 0; device_num < num_ranks; ++device_num) {
-      auto* device = devices->at(device_num % devices->size());
+    std::unique_ptr<typename TestFixture::TestCase> test_case(
+        this->MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0.0f));
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      auto* device = this->GetDevice(rank);
       auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
       auto* stream = device->tensorflow_gpu_device_info()->stream;
       NcclManager::instance()->AddToAllReduce(
           num_ranks, "allreduce", reduction_op, device->executor(),
-          device->gpu_id(), event_mgr, stream, &test_case->ins[device_num],
-          &test_case->outs[device_num], CreateDoneCallback(test_case.get()));
+          device->gpu_id(), event_mgr, stream, &test_case->ins[rank],
+          &test_case->outs[rank], this->CreateDoneCallback(test_case.get()));
     }
 
     LOG(ERROR) << "Verifying results";
-    VerifyResults("test_case", test_case.get());
+    this->VerifyResults("test_case", test_case.get());
   }
 }
 
@@ -213,7 +232,7 @@ TEST_F(NcclManagerTest, BasicSumReduction) {
 // with num_ranks > devices->size(), for some GPUs (e.g. K20m).
 // To test the higher settings, increase num_ranks,
 // num_collectives_per_iteration and time_limit_micros.
-TEST_F(NcclManagerTest, MultipleCallers) {
+TYPED_TEST(NcclManagerTest, MultipleCallers) {
   const int num_ranks = 1;                      // 2;
   const int num_collectives_per_iteration = 1;  // 1000;
   const int num_threads = 3;
@@ -223,49 +242,49 @@ TEST_F(NcclManagerTest, MultipleCallers) {
   srand(Env::Default()->NowMicros());
 
   for (;;) {
-    std::vector<std::pair<int, int>> case_and_device_num;
-    std::vector<std::unique_ptr<TestCase>> test_cases;
+    std::vector<std::pair<int, int>> case_and_rank;
+    std::vector<std::unique_ptr<typename TestFixture::TestCase>> test_cases;
     for (int i = 0; i < num_collectives_per_iteration; ++i) {
-      test_cases.emplace_back(
-          MakeTestCase(num_ranks, ncclSum,
-                       TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i));
+      test_cases.emplace_back(this->MakeTestCase(
+          num_ranks, ncclSum, TensorShape({100, i % 5 + 1, i % 3 + 1}),
+          1.1f * i));
       for (int j = 0; j < num_ranks; ++j) {
-        case_and_device_num.emplace_back(i, j);
+        case_and_rank.emplace_back(i, j);
       }
     }
 
-    for (int i = 0; i < num_ranks; ++i) {
-      auto* device = devices->at(i % devices->size());
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      auto* device = this->GetDevice(rank);
       auto* stream = device->tensorflow_gpu_device_info()->stream;
       SE_ASSERT_OK(stream->BlockHostUntilDone());
     }
 
-    std::shuffle(case_and_device_num.begin(), case_and_device_num.end(),
+    std::shuffle(case_and_rank.begin(), case_and_rank.end(),
                  std::mt19937(std::random_device()()));
 
-    mutex mu;  // guards case_and_device_num.
+    mutex mu;  // guards case_and_rank.
     std::unique_ptr<thread::ThreadPool> pool(
         new thread::ThreadPool(Env::Default(), "test", num_threads));
-    const int to_schedule = case_and_device_num.size();
+    const int to_schedule = case_and_rank.size();
     for (int i = 0; i < to_schedule; ++i) {
       auto fn = [&]() {
-        int device_num;
+        int rank;
         int test_num;
         {
           mutex_lock l(mu);
-          test_num = case_and_device_num.back().first;
-          device_num = case_and_device_num.back().second;
-          case_and_device_num.pop_back();
+          test_num = case_and_rank.back().first;
+          rank = case_and_rank.back().second;
+          case_and_rank.pop_back();
         }
-        auto* device = devices->at(device_num % devices->size());
+        auto* device = this->GetDevice(rank);
         auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
         auto* stream = device->tensorflow_gpu_device_info()->stream;
-        TestCase* test_case = test_cases[test_num].get();
+        typename TestFixture::TestCase* test_case = test_cases[test_num].get();
         NcclManager::instance()->AddToAllReduce(
             num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
             device->executor(), device->gpu_id(), event_mgr, stream,
-            &test_case->ins[device_num], &test_case->outs[device_num],
-            CreateDoneCallback(test_case));
+            &test_case->ins[rank], &test_case->outs[rank],
+            this->CreateDoneCallback(test_case));
       };
       pool->Schedule(fn);
     }
@@ -274,7 +293,8 @@ TEST_F(NcclManagerTest, MultipleCallers) {
     LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration
                << " collectives";
     for (int i = 0; i < test_cases.size(); ++i) {
-      VerifyResults(strings::StrCat("collective", i), test_cases[i].get());
+      this->VerifyResults(strings::StrCat("collective", i),
+                          test_cases[i].get());
     }
 
     int64 delta = Env::Default()->NowMicros() - start;
index 8eb804c..a353a34 100644 (file)
@@ -25,7 +25,7 @@ REGISTER_OP("NcclAllReduce")
     .Input("input: T")
     .Output("data: T")
     .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("num_devices: int")
     .Attr("shared_name: string")
     .SetIsStateful()
@@ -51,7 +51,7 @@ REGISTER_OP("NcclReduce")
     .Input("input: num_devices * T")
     .Output("data: T")
     .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("num_devices: int")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnchangedShape)
@@ -69,7 +69,7 @@ reduction: the reduction operation to perform.
 REGISTER_OP("_NcclReduceSend")
     .Input("input: T")
     .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("num_devices: int")
     .Attr("shared_name: string")
     .SetIsStateful()
@@ -92,7 +92,7 @@ REGISTER_OP("_NcclReduceRecv")
     .Input("input: T")
     .Output("data: T")
     .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("num_devices: int")
     .Attr("shared_name: string")
     .SetIsStateful()
@@ -118,7 +118,7 @@ shared_name: Identifier that is shared between ops of the same reduce.
 REGISTER_OP("NcclBroadcast")
     .Input("input: T")
     .Output("output: T")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("shape: shape")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnchangedShape)
@@ -135,7 +135,7 @@ shape: The shape of the input tensor.
 
 REGISTER_OP("_NcclBroadcastSend")
     .Input("input: T")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("num_devices: int")
     .Attr("shared_name: string")
     .SetIsStateful()
@@ -157,7 +157,7 @@ shared_name: Identifier that is shared between ops of the same broadcast.
 REGISTER_OP("_NcclBroadcastRecv")
     .Input("shape: int32")
     .Output("output: T")
-    .Attr("T: {float, float64, int32, int64}")
+    .Attr("T: {half, float, float64, int32, int64}")
     .Attr("num_devices: int")
     .Attr("shared_name: string")
     .SetIsStateful()
index 98fe394..423a868 100644 (file)
@@ -72,7 +72,7 @@ class NcclTestCase(test.TestCase):
           two.
       device_sets: Tuple of virtual devices to run test on.
     """
-    for dtype in [np.float32, np.int32, np.int64, np.float64]:
+    for dtype in [np.float16, np.float32, np.int32, np.int64, np.float64]:
       # Create session inside outer loop to test use of
       # same communicator across multiple sessions.
       with self.test_session(use_gpu=True) as sess: