Adding BFP16 quantization/dequantization support to OSS (#63059)
author Marjan Fariborz <marjanf@fb.com>
Thu, 26 Aug 2021 06:40:09 +0000 (23:40 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 06:41:34 +0000 (23:41 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63059

Adds support for the BFP16 quantization method to OSS. Currently only CPU is supported.
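
A rough sketch of the round trip through the two operators this patch registers (op names taken from init.cpp below; the shapes and values are illustrative only, and a build with distributed support is assumed so the q::* ops are available):

    import torch
    import torch.distributed  # the c10d extension below registers the q::* ops

    x = torch.randn(4, 8, dtype=torch.float32)      # BFP16 currently expects 2D float32 CPU tensors
    q = torch.ops.q._FloatToBfloat16Quantized(x)    # bf16 bit patterns carried in a float16 tensor
    y = torch.ops.q._Bfloat16QuantizedToFloat(q)    # back to float32, with bf16 precision loss
    assert y.shape == x.shape and y.dtype == torch.float32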
ghstack-source-id: 136639528

Test Plan: Imported from OSS

Reviewed By: wanchaol

Differential Revision: D30194538

fbshipit-source-id: ac248567ad8028457c2a91b77ef2ce81709fce53

test/distributed/algorithms/quantization/test_quantization.py
tools/build_variables.bzl
torch/csrc/distributed/c10d/init.cpp
torch/csrc/distributed/c10d/quantization/quantization.cpp [new file with mode: 0644]
torch/csrc/distributed/c10d/quantization/quantization.h [new file with mode: 0644]
torch/csrc/distributed/c10d/quantization/quantization_gpu.cu [new file with mode: 0644]
torch/csrc/distributed/c10d/quantization/quantization_gpu.h [new file with mode: 0644]
torch/csrc/distributed/c10d/quantization/quantization_utils.h [new file with mode: 0644]
torch/distributed/algorithms/quantization/quantization.py

diff --git a/test/distributed/algorithms/quantization/test_quantization.py b/test/distributed/algorithms/quantization/test_quantization.py
index 7872920..505f805 100644 (file)
@@ -8,6 +8,7 @@ from torch.distributed.algorithms.quantization.quantization import DQuantType
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     requires_gloo,
+    skip_if_rocm,
     skip_if_lt_x_gpu,
     requires_nccl,
 )
@@ -26,9 +27,9 @@ def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
     if value is None:
         value = size
     if device_id is None:
-        return torch.empty(size, size, size, dtype=dtype).fill_(value)
+        return torch.empty(size, dtype=dtype).fill_(value)
     else:
-        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)
+        return torch.empty(size, dtype=dtype).fill_(value).cuda(device_id)
 if TEST_WITH_DEV_DBG_ASAN:
     print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
     sys.exit(0)
@@ -38,7 +39,6 @@ if NO_MULTIPROCESSING_SPAWN:
     sys.exit(0)
 
 BACKEND = os.environ["BACKEND"]
-
 if BACKEND == "gloo" or BACKEND == "nccl":
     class DistQuantizationTests(MultiProcessTestCase):
 
@@ -60,7 +60,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
 
         @property
         def world_size(self):
-            return 2
+            return int(os.environ["WORLD_SIZE"])
 
         def _init_multigpu_helper(self):
             """Multigpu tests are designed to simulate the multi nodes with multi
@@ -69,7 +69,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
             divided to subsets, each process only uses a subset.
             """
             nGPUs = torch.cuda.device_count()
-            world_size = dist.get_world_size()
+            world_size = self.world_size
             visible_devices = range(nGPUs)
 
             if BACKEND == "nccl":
@@ -91,18 +91,29 @@ if BACKEND == "gloo" or BACKEND == "nccl":
         @requires_gloo()
         @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports all_gather_fp16")
         def test_all_gather_fp16(self):
-            store = dist.FileStore(self.file_name, int(self.world_size))
+            store = dist.FileStore(self.file_name, self.world_size)
             dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
             device = torch.device(f"cuda:{self.rank}")
             group = list(range(0, self.world_size))
             group_id = dist.group.WORLD
             self._test_all_gather(group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16)
 
+        @requires_gloo()
+        @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports all_gather_fp16")
+        def test_all_gather_bfp16(self):
+            store = dist.FileStore(self.file_name, self.world_size)
+            dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
+            device = torch.device(f"cuda:{self.rank}")
+            group = list(range(0, self.world_size))
+            group_id = dist.group.WORLD
+            self._test_all_gather(group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16)
+
         @requires_nccl()
         @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16")
         @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        @skip_if_rocm
         def test_all_to_all_fp16(self):
-            store = dist.FileStore(self.file_name, int(self.world_size))
+            store = dist.FileStore(self.file_name, self.world_size)
             dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
             device = torch.device(f"cuda:{self.rank}")
             group = list(range(0, self.world_size))
@@ -117,16 +128,34 @@ if BACKEND == "gloo" or BACKEND == "nccl":
                 dtype=torch.float32,
                 qtype=DQuantType.FP16)
 
+        @requires_nccl()
+        @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16")
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        @skip_if_rocm
+        def test_all_to_all_bfp16(self):
+            store = dist.FileStore(self.file_name, self.world_size)
+            dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
+            device = torch.device(f"cuda:{self.rank}")
+            group = list(range(0, self.world_size))
+            group_id = dist.new_group(range(self.world_size))
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all(
+                group,
+                group_id,
+                self.rank,
+                cuda=True,
+                rank_to_GPU=rank_to_GPU,
+                dtype=torch.float32,
+                qtype=DQuantType.BFP16)
+
         def _test_all_gather(
                 self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=None):
             for dest in group:
-                tensor = _build_tensor(dest + 1, rank, dtype=dtype)
-                tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
-                expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
-                if (qtype is not None):
-                    allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
-                else:
-                    allgather = dist.all_gather
+                tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype)
+                tensors = [_build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group]
+                expected_tensors = [
+                    _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group
+                ]
                 if cuda:
                     tensor = tensor.cuda(rank_to_GPU[rank][0])
                     tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
@@ -134,6 +163,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
                     tensor_shapes = [torch.view_as_real(tensors[0]).shape]
                 else:
                     tensor_shapes = [tensors[0].shape]
+                allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
                 allgather(tensors, tensor, group=group_id, async_op=False)
 
                 for t1, t2 in zip(tensors, expected_tensors):
@@ -168,11 +198,8 @@ if BACKEND == "gloo" or BACKEND == "nccl":
                         t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
                     ]
                     out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
-                if(qtype is not None):
-                    quantize_alltoall = quant.auto_quantize(dist.all_to_all, qtype, quant_loss=None)
-                    quantize_alltoall(out_tensors, in_tensors, group=group_id)
-                else:
-                    dist.all_to_all(out_tensors, in_tensors, group=group_id)
+                quantize_alltoall = quant.auto_quantize(dist.all_to_all, qtype, quant_loss=None)
+                quantize_alltoall(out_tensors, in_tensors, group=group_id)
                 for t1, t2 in zip(out_tensors, expected_tensors):
                     self.assertEqual(t1, t2)
 
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 5f4cc0d..3f62253 100644 (file)
@@ -551,6 +551,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/NCCLUtils.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
+    "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
 ]
 
 libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + libtorch_cuda_distributed_extra_sources
@@ -737,6 +738,7 @@ libtorch_python_distributed_core_sources = [
     "torch/csrc/distributed/c10d/frontend.cpp",
     "torch/csrc/distributed/c10d/init.cpp",
     "torch/csrc/distributed/c10d/python_comm_hook.cpp",
+    "torch/csrc/distributed/c10d/quantization/quantization.cpp",
 ]
 
 libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 201f0c2..6b52d3c 100644 (file)
@@ -17,6 +17,7 @@
 
 #ifdef USE_C10D_NCCL
 #include <c10d/ProcessGroupNCCL.hpp>
+#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
 #endif
 
 #ifdef USE_C10D_MPI
 #include <c10d/frontend.hpp>
 #include <c10d/logger.hpp>
 #include <c10d/reducer.hpp>
+
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/distributed/c10d/python_comm_hook.h>
+#include <torch/csrc/distributed/c10d/quantization/quantization.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/utils/object_ptr.h>
 #include <torch/csrc/utils/pybind.h>
@@ -1644,6 +1647,27 @@ PyMethodDef* python_functions() {
   return methods;
 }
 
+namespace quantization {
+TORCH_LIBRARY(q, m) {
+    m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
+    m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor");
+}
+    TORCH_LIBRARY_IMPL(q, CPU, m) {
+        m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu);
+        m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu);
+    }
+
+#ifdef USE_C10D_NCCL
+    #define DISPATCH_TO_CUDA(name, function) \
+        m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))
+    TORCH_LIBRARY_IMPL(q, CUDA, m) {
+        DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
+        DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
+    }
+#endif
+
+} // namespace quantization
+
 } // namespace c10d
 } // namespace distributed
 } // namespace torch
diff --git a/torch/csrc/distributed/c10d/quantization/quantization.cpp b/torch/csrc/distributed/c10d/quantization/quantization.cpp
new file mode 100644 (file)
index 0000000..b9682d7
--- /dev/null
@@ -0,0 +1,93 @@
+#include <torch/csrc/distributed/c10d/quantization/quantization.h>
+#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+namespace quantization {
+
+void FloatToBFloat16Quantized_ref(
+    const float* const input,
+    const size_t nrows,
+    const size_t ncols,
+    uint16_t* const output){
+  for (const auto row : c10::irange(nrows)) {
+    const float* input_row = input + row * ncols;
+    uint16_t* output_row = output + row * ncols;
+
+    for (const auto col : c10::irange(ncols)) {
+      output_row[col] =
+          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
+          16;
+    }
+  }
+}
+
+void BFloat16QuantizedToFloat_ref(
+    const at::BFloat16* const input,
+    const size_t nrows,
+    const size_t ncols,
+    float* const output){
+  const int32_t output_columns = ncols;
+
+  for (const auto row : c10::irange(nrows)) {
+    const at::BFloat16* input_row = input + row * ncols;
+    float* output_row = output + row * output_columns;
+
+    for (const auto col : c10::irange(ncols)) {
+      uint32_t val_fp32 = static_cast<uint32_t>(
+                              reinterpret_cast<const uint16_t*>(input_row)[col])
+          << 16;
+      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
+    }
+  }
+}
+
+at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
+  TENSOR_ON_CPU(input);
+  // Currently it supports 2D inputs
+  TENSOR_NDIM_EQUALS(input, 2);
+
+  const auto input_sizes = input.sizes();
+  const int32_t nrows = input_sizes[0];
+  const int32_t ncols = input_sizes[1];
+  const int32_t output_columns = ncols;
+  auto output = at::empty(
+      {nrows, output_columns},
+      input.options().dtype(at::kHalf));
+
+  FloatToBFloat16Quantized_ref(
+      input.data_ptr<float>(),
+      nrows,
+      ncols,
+      reinterpret_cast<uint16_t*>(output.data_ptr<at::Half>()));
+
+  return output;
+}
+
+at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
+  TENSOR_ON_CPU(input);
+  // Currently it supports 2D inputs
+  TENSOR_NDIM_EQUALS(input, 2);
+
+  const auto input_sizes = input.sizes();
+  const int32_t nrows = input_sizes[0];
+  const int32_t ncols = input_sizes[1];
+  const int32_t output_columns = ncols;
+
+  auto output = at::empty(
+      {nrows, output_columns},
+      input.options().dtype(at::kFloat));
+  BFloat16QuantizedToFloat_ref(
+      reinterpret_cast<at::BFloat16*>(input.data_ptr<at::Half>()),
+      nrows,
+      ncols,
+      output.data_ptr<float>());
+
+  return output;
+}
+
+} // namespace quantization
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
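
For reference, the add-2^15-then-truncate rounding used in FloatToBFloat16Quantized_ref above can be mimicked with a few lines of tensor bit manipulation. This is only a sanity-check sketch (helper names are made up, and non-negative inputs are used so the signed shifts stay well defined; the real kernels operate on 32-bit words directly):

    import torch

    def fp32_to_bf16_bits(x: torch.Tensor) -> torch.Tensor:
        # Reinterpret the float32 bits, add 2^15 to round to nearest, keep the high 16 bits.
        bits = x.view(torch.int32)
        return ((bits + (1 << 15)) >> 16) & 0xFFFF

    def bf16_bits_to_fp32(b: torch.Tensor) -> torch.Tensor:
        # Move the 16-bit pattern back into the high half of a float32 word.
        return (b << 16).view(torch.float32)

    x = torch.tensor([[0.5, 1.0, 3.14159]], dtype=torch.float32)
    print(bf16_bits_to_fp32(fp32_to_bf16_bits(x)))  # bf16-rounded copies of x, e.g. 3.14159 -> 3.140625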
diff --git a/torch/csrc/distributed/c10d/quantization/quantization.h b/torch/csrc/distributed/c10d/quantization/quantization.h
new file mode 100644 (file)
index 0000000..658fa75
--- /dev/null
@@ -0,0 +1,20 @@
+// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
+
+#pragma once
+
+
+#include <ATen/ATen.h>
+#include <vector>
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+namespace quantization {
+
+at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input);
+at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input);
+
+} // namespace quantization
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu
new file mode 100644 (file)
index 0000000..5590e03
--- /dev/null
@@ -0,0 +1,148 @@
+#include <c10/cuda/CUDAGuard.h>
+#include <c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
+#include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
+
+// FP32 -> BF16 kernel
+__global__ inline void _float_to_bfloat16_cuda_kernel(
+    const float* __restrict__ input,
+    const int nrows,
+    const int ncols,
+    uint16_t* __restrict__ output) {
+  const int row_incre = blockDim.y * gridDim.y;
+  const int col_incre = blockDim.x * gridDim.x;
+  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+       row += row_incre) {
+    const float* input_row = input + row * ncols;
+    uint16_t* output_row = output + row * ncols;
+    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+         col += col_incre) {
+      // Add 2^15 and right shift 16 to do round-nearest
+      output_row[col] =
+          (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
+          16;
+    }
+  }
+}
+
+// BF16 -> FP32 kernel
+__global__ inline void _bfloat16_to_float_cuda_kernel(
+    const uint16_t* __restrict__ input,
+    const int nrows,
+    const int ncols,
+    float* __restrict__ output) {
+  const int row_incre = blockDim.y * gridDim.y;
+  const int col_incre = blockDim.x * gridDim.x;
+  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+       row += row_incre) {
+    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+         col += col_incre) {
+      const uint16_t* input_row = input + row * ncols;
+      float* output_row = output + row * ncols;
+      uint32_t val_fp32 = static_cast<uint32_t>(
+                              reinterpret_cast<const uint16_t*>(input_row)[col])
+          << 16;
+      reinterpret_cast<uint32_t*>(output_row)[col] = val_fp32;
+    }
+  }
+}
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+namespace quantization {
+
+at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
+  TENSOR_ON_CUDA_GPU(input);
+  // Currently it supports 2D inputs
+  TENSOR_NDIM_EQUALS(input, 2);
+
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(input.get_device());
+
+  const int nrows = input.size(0);
+  const int ncols = input.size(1);
+  const int output_columns = ncols;
+
+  auto output = at::empty(
+      {nrows, output_columns},
+      input.options().dtype(at::kHalf)); // at::kHalf
+
+  if (nrows == 0 || output_columns == 0) {
+    return output;
+  }
+
+  // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
+  // NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16
+
+  constexpr int threads_per_block = 256;
+  const int blockDim_x = std::min(output_columns, threads_per_block);
+  dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+  const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
+  const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u);
+  dim3 gridDim(gridDim_x, gridDim_y);
+
+  _float_to_bfloat16_cuda_kernel<<<
+      gridDim,
+      blockDim,
+      0,
+      at::cuda::getCurrentCUDAStream()>>>(
+      input.data_ptr<float>(),
+      nrows,
+      ncols,
+      // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
+      // NCCL
+      reinterpret_cast<uint16_t*>(output.data_ptr<at::Half>()));
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return output;
+}
+
+at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
+  TENSOR_ON_CUDA_GPU(input);
+  // Currently it supports 2D inputs
+  TENSOR_NDIM_EQUALS(input, 2);
+
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(input.get_device());
+
+  const int nrows = input.size(0);
+  const int ncols = input.size(1);
+  const int output_columns = ncols;
+
+  auto output = at::empty(
+      {nrows, output_columns},
+      input.options().dtype(at::kFloat));
+
+  if (nrows == 0 || output_columns == 0) {
+    return output;
+  }
+
+  constexpr int threads_per_block = 256;
+
+  const int blockDim_x = std::min(output_columns, threads_per_block);
+  dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+  const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
+  const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u);
+  dim3 gridDim(gridDim_x, gridDim_y);
+
+  _bfloat16_to_float_cuda_kernel<<<
+      gridDim,
+      blockDim,
+      0,
+      at::cuda::getCurrentCUDAStream()>>>(
+      // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
+      // NCCL
+      reinterpret_cast<uint16_t*>(input.data_ptr<at::Half>()),
+      nrows,
+      ncols,
+      output.data_ptr<float>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return output;
+}
+
+} // namespace quantization
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h
new file mode 100644 (file)
index 0000000..2a0c8f8
--- /dev/null
@@ -0,0 +1,20 @@
+// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
+
+#pragma once
+
+
+#include <ATen/ATen.h>
+#include <vector>
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+namespace quantization {
+
+at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input);
+at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input);
+
+} // namespace quantization
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/c10d/quantization/quantization_utils.h b/torch/csrc/distributed/c10d/quantization/quantization_utils.h
new file mode 100644 (file)
index 0000000..0467ba2
--- /dev/null
@@ -0,0 +1,31 @@
+// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
+
+#pragma once
+
+#include <ATen/ATen.h>
+
+#include <typeinfo>
+
+inline std::string torch_tensor_device_name(const at::Tensor& ten) {
+  return c10::DeviceTypeName(ten.device().type());
+}
+
+#define TENSOR_NDIM_EQUALS(ten, dims)      \
+  TORCH_CHECK(                             \
+      (ten).ndimension() == (dims),        \
+      "Tensor '" #ten "' must have " #dims \
+      " dimension(s). "                    \
+      "Found ",                            \
+      (ten).ndimension())
+
+#define TENSOR_ON_CPU(x)                                      \
+  TORCH_CHECK(                                                \
+      !x.is_cuda(),                           \
+      #x " must be a CPU tensor; it is currently on device ", \
+      torch_tensor_device_name(x))
+
+#define TENSOR_ON_CUDA_GPU(x)                                  \
+  TORCH_CHECK(                                                 \
+      x.is_cuda(),                                             \
+      #x " must be a CUDA tensor; it is currently on device ", \
+      torch_tensor_device_name(x))
diff --git a/torch/distributed/algorithms/quantization/quantization.py b/torch/distributed/algorithms/quantization/quantization.py
index 724d6aa..d58c58c 100644 (file)
@@ -10,7 +10,12 @@ TORCH_HALF_MIN = torch.finfo(torch.float16).min
 TORCH_HALF_MAX = torch.finfo(torch.float16).max
 
 class DQuantType(Enum):
-    FP16 = "fp16"
+    """
+    Different quantization methods for the auto_quantize API are identified here.
+    The auto_quantize API currently supports the fp16 and bfp16 methods.
+    """
+    FP16 = "fp16",
+    BFP16 = "bfp16"
 
     def __str__(self) -> str:
         return self.value
@@ -26,6 +31,8 @@ def _quantize_tensor(tensor, qtype):
         )
     if (qtype == DQuantType.FP16):
         return _fp32_to_fp16_with_clamp(tensor)
+    elif (qtype == DQuantType.BFP16):
+        return torch.ops.q._FloatToBfloat16Quantized(tensor)
     else:
         raise RuntimeError(
             f'Quantization type {qtype} is not supported'
@@ -38,13 +45,8 @@ def _quantize_tensor_list(tensor_list, qtype):
         raise RuntimeError(
             f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
         )
-    if (qtype == DQuantType.FP16):
-        quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
-        return quantized_tensor_list
-    else:
-        raise RuntimeError(
-            f'Quantization type {qtype} is not supported'
-        )
+    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
+    return quantized_tensor_list
 
 def _dequantize_tensor(tensor, qtype, quant_loss=None):
     if not isinstance(tensor, torch.Tensor):
@@ -60,6 +62,13 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None):
             return tensor.float()
         else:
             return tensor.float() / quant_loss
+    elif (qtype == DQuantType.BFP16):
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        else:
+            return torch.ops.q._Bfloat16QuantizedToFloat(tensor)
     else:
         raise RuntimeError(
             f'Quantization type {qtype} is not supported'
@@ -73,26 +82,26 @@ def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
         raise RuntimeError(
             f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
         )
-    elif (qtype == DQuantType.FP16):
-        dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list]
-        return dequantized_tensor_list
-    else:
-        raise RuntimeError(
-            f'Quantization type {qtype} is not supported'
-        )
+    dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list]
+    return dequantized_tensor_list
 
 
 def auto_quantize(func, qtype, quant_loss=None):
     """
     This is a prototype API that automatically quantize the input tensors, choose the precision types, and
     pass other necessary arguments and then dequantizes the output.
+
     Currently it only supports:
-        . FP16 quantization method
+        . FP16 and BFP16 quantization methods for the gloo and nccl backends
         . all_gather, all_to_all collective ops
+
+    Note: BFP16 only supports 2D tensors.
+
     Args:
         func (callable): A function representing collective operations.
         qtype (QuantType): Quantization method
         quant_loss (float, optional): This can be used to improve accuracy in the dequantization.
+
     Returns:
         (callable): the same collective as func but enables automatic quantization/dequantization.
     """