simplify lambda function use in conv dnnlowp ops to fix #15911 (#15996)
author    Jongsoo Park <jongsoo@fb.com>
          Mon, 14 Jan 2019 07:30:09 +0000 (23:30 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Mon, 14 Jan 2019 07:32:48 +0000 (23:32 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15996

As reported in issue #15911, gcc 4.9 hit an internal compiler error on a complex use of nested lambda functions in conv_dnnlowp_op.cc and conv_dnnlowp_acc16_op.cc. This diff simplifies them: the nested col-buffer and Y_int32 lambdas are flattened into a single lambda that receives both buffers, and the shared-buffer dispatch previously repeated at each call site is factored into a new RunWithSharedBuffer_ helper in conv_pool_dnnlowp_op_base.h.
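
In outline, the change replaces the nested-lambda pattern with a flat one. A minimal standalone sketch of the two shapes (not the operator code itself; Tensor is a stand-in for caffe2::Tensor and the bodies are elided):

#include <cstdint>
#include <vector>

struct Tensor {}; // stand-in for caffe2::Tensor

// Before: a lambda over Y_int32 nesting a second lambda over the col
// buffer -- the pattern that triggered the gcc 4.9 ICE.
void run_before(Tensor* col_buffer, std::vector<std::int32_t>* Y_int32) {
  auto f = [&](std::vector<std::int32_t>* y) {
    auto f2 = [&](Tensor* cb) {
      // ... im2col + gemm using cb and y ...
    };
    f2(col_buffer); // or dispatch through runWithSharedBuffer
  };
  f(Y_int32); // or dispatch through RunWithSharedInt32Buffer_
}

// After: one flat lambda taking both buffers, dispatched through the
// single new helper.
void run_after(Tensor* col_buffer, std::vector<std::int32_t>* Y_int32) {
  auto f = [&](Tensor* cb, std::vector<std::int32_t>* y) {
    // ... im2col + gemm using cb and y ...
  };
  // in the diff: this->RunWithSharedBuffer_(&col_buffer_, &(this->Y_int32_), f)
  f(col_buffer, Y_int32);
}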

Reviewed By: viswanathgs

Differential Revision: D13648264

fbshipit-source-id: 1551ae8a0a7653749185dca51ccceb2471b96b82

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_pool_dnnlowp_op_base.h

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
index a32a7b0..b371337 100644
@@ -280,10 +280,10 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
   // height and width.
   const uint8_t* Xdata = X.template data<uint8_t>();
 
-  col_buffer_.Resize(buffer_shape);
-  uint8_t* col_buffer_data = col_buffer_.template mutable_data<uint8_t>();
+  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
+    col_buffer->Resize(buffer_shape);
+    uint8_t* col_buffer_data = col_buffer->template mutable_data<uint8_t>();
 
-  auto f = [&](vector<int32_t>* Y_int32) {
     Y_int32->resize(M * output_image_size * dnnlowp_get_max_threads());
     vector<int> buffer_shape_per_thread(
         buffer_shape.begin() + 1, buffer_shape.end());
@@ -387,11 +387,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
     } // for each image_id
   }; // f
 
-  if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
-    this->RunWithSharedInt32Buffer_(f);
-  } else {
-    f(&(this->Y_int32_));
-  }
+  this->RunWithSharedBuffer_(&col_buffer_, &(this->Y_int32_), f);
 
   PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
 
@@ -637,7 +633,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
   // The col buffer is stored in HWC order as well - kernel_dim, and the height
   // and width.
 
-  auto f = [&](vector<int32_t>* Y_int32) {
+  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
     Y_int32->resize(Y->numel());
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -647,119 +643,107 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
     bool no_im2col = this->NoIm2ColNHWC_();
 
     // Im2Col, followed by gemm.
-    auto f2 = [&](Tensor* col_buffer_) {
-      const uint8_t* Xdata = X.template data<uint8_t>();
-      const uint8_t* col_buffer_data =
-          no_im2col ? Xdata : this->Im2ColNHWC_(col_buffer_);
+    const uint8_t* Xdata = X.template data<uint8_t>();
+    const uint8_t* col_buffer_data =
+        no_im2col ? Xdata : this->Im2ColNHWC_(col_buffer);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      t_end = chrono::system_clock::now();
-      dt = chrono::duration<double>(t_end - t_begin).count();
-      LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
-      t_begin = chrono::system_clock::now();
+    t_end = chrono::system_clock::now();
+    dt = chrono::duration<double>(t_end - t_begin).count();
+    LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
+    t_begin = chrono::system_clock::now();
 #endif
 
-      using namespace fbgemm;
-      int row_offset_size_per_thread = -1;
-      int x_pack_buf_size_per_thread = -1;
-      if (Wq_acc16_packed_) {
-        row_offset_size_per_thread =
-            PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize();
-        x_pack_buf_size_per_thread =
-            PackAWithRowOffset<uint8_t, int16_t>::packedBufferSize();
-        row_offsets_.resize(
-            dnnlowp_get_max_threads() * row_offset_size_per_thread);
-        X_pack_buf_.resize(
-            dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
-      }
+    using namespace fbgemm;
+    int row_offset_size_per_thread = -1;
+    int x_pack_buf_size_per_thread = -1;
+    if (Wq_acc16_packed_) {
+      row_offset_size_per_thread =
+          PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize();
+      x_pack_buf_size_per_thread =
+          PackAWithRowOffset<uint8_t, int16_t>::packedBufferSize();
+      row_offsets_.resize(
+          dnnlowp_get_max_threads() * row_offset_size_per_thread);
+      X_pack_buf_.resize(
+          dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
+    }
 
-      uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();
+    uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();
 
-      // Main GEMM for non-outlier
-      if (Wq_acc16_packed_)
+    // Main GEMM for non-outlier
+    if (Wq_acc16_packed_)
 #ifdef _OPENMP
 #pragma omp parallel
 #endif
-      {
-        // fast path
-        int tid = dnnlowp_get_thread_num();
-
-        // no im2col fusion
-        PackAWithRowOffset<uint8_t, int16_t> packA(
-            matrix_op_t::NoTranspose,
-            N * output_image_size,
-            group_ * kernel_dim,
-            col_buffer_data,
-            group_ * kernel_dim,
-            X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
-            group_,
-            row_offsets_.data() + tid * row_offset_size_per_thread);
+    {
+      // fast path
+      int tid = dnnlowp_get_thread_num();
 
-        if (this->quantize_groupwise_) {
-          DispatchFBGEMM_<QuantizationGranularity::GROUP>(
-              packA, col_buffer_data, Y_int32, Y_uint8_data);
-        } else {
-          DispatchFBGEMM_<QuantizationGranularity::TENSOR>(
-              packA, col_buffer_data, Y_int32, Y_uint8_data);
-        }
+      // no im2col fusion
+      PackAWithRowOffset<uint8_t, int16_t> packA(
+          matrix_op_t::NoTranspose,
+          N * output_image_size,
+          group_ * kernel_dim,
+          col_buffer_data,
+          group_ * kernel_dim,
+          X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
+          group_,
+          row_offsets_.data() + tid * row_offset_size_per_thread);
+
+      if (this->quantize_groupwise_) {
+        DispatchFBGEMM_<QuantizationGranularity::GROUP>(
+            packA, col_buffer_data, Y_int32, Y_uint8_data);
       } else {
-        // slow path
-        conv_nhwc_acc16_ref_(
-            group_,
-            N,
-            output_image_size,
-            M,
-            kernel_dim,
-            col_buffer_data,
-            W_quantized_.data(),
-            Y_int32->data()
+        DispatchFBGEMM_<QuantizationGranularity::TENSOR>(
+            packA, col_buffer_data, Y_int32, Y_uint8_data);
+      }
+    } else {
+      // slow path
+      conv_nhwc_acc16_ref_(
+          group_,
+          N,
+          output_image_size,
+          M,
+          kernel_dim,
+          col_buffer_data,
+          W_quantized_.data(),
+          Y_int32->data()
 #ifdef DNNLOWP_ACC16_IN_SLOW_PATH
-                ,
-            this
+              ,
+          this
 #endif
-        );
-      } // slow path
+      );
+    } // slow path
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      t_end = chrono::system_clock::now();
-      dt = chrono::duration<double>(t_end - t_begin).count();
-      double ops = 2. * N * output_image_size * M * kernel_dim;
-      double gops = ops / dt / 1e9;
-      LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms " << gops
-                << " gops";
-      t_begin = chrono::system_clock::now();
+    t_end = chrono::system_clock::now();
+    dt = chrono::duration<double>(t_end - t_begin).count();
+    double ops = 2. * N * output_image_size * M * kernel_dim;
+    double gops = ops / dt / 1e9;
+    LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms " << gops
+              << " gops";
+    t_begin = chrono::system_clock::now();
 #endif
 
-      if (!Wq_acc16_packed_) {
-        ConvOutlier_(col_buffer_data, Y_int32);
-      }
+    if (!Wq_acc16_packed_) {
+      ConvOutlier_(col_buffer_data, Y_int32);
+    }
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      t_end = chrono::system_clock::now();
-      dt = chrono::duration<double>(t_end - t_begin).count();
-      LOG(INFO) << "this=" << this << " out-lier: " << dt * 1e3 << " ms";
-      t_begin = chrono::system_clock::now();
+    t_end = chrono::system_clock::now();
+    dt = chrono::duration<double>(t_end - t_begin).count();
+    LOG(INFO) << "this=" << this << " out-lier: " << dt * 1e3 << " ms";
+    t_begin = chrono::system_clock::now();
 #endif
 
-      if (!Wq_acc16_packed_) {
-        this->RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
-      } else {
-        PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-      }
-    }; // f2
-
-    if (FLAGS_caffe2_force_shared_col_buffer || this->shared_buffer_) {
-      runWithSharedBuffer<CPUContext>(this->ws_, f2);
+    if (!Wq_acc16_packed_) {
+      this->RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
     } else {
-      f2(&(this->col_buffer_));
+      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
     }
   }; // f
 
-  if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
-    this->RunWithSharedInt32Buffer_(f);
-  } else {
-    f(&(this->Y_int32_));
-  }
+  this->RunWithSharedBuffer_(&col_buffer_, &(this->Y_int32_), f);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
   t_end = chrono::system_clock::now();
caffe2/quantization/server/conv_dnnlowp_op.cc
index 4192440..7a18c7a 100644
@@ -563,103 +563,91 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
   T* Y_data_T = Y->template mutable_data<T>();
   column_offsets_->resize(Y_HxW * dnnlowp_get_max_threads());
 
-  auto f = [&](Tensor* col_buffer) {
+  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
     col_buffer->Resize(buffer_shape);
     vector<int> buffer_shape_per_thread(
         buffer_shape.begin() + 1, buffer_shape.end());
     T* col_buffer_data = col_buffer->template mutable_data<T>();
 
-    auto f2 = [&](vector<int32_t>* Y_int32) {
-      Y_int32->resize(M * Y_HxW * dnnlowp_get_max_threads());
+    Y_int32->resize(M * Y_HxW * dnnlowp_get_max_threads());
 
-      // Im2Col, followed by gemm.
+    // Im2Col, followed by gemm.
 #ifdef _OPENMP
 #pragma omp parallel for
 #endif
-      for (int image_id = 0; image_id < N; ++image_id) {
-        int tid = dnnlowp_get_thread_num();
-        for (int group_id = 0; group_id < group_; ++group_id) {
-          if (this->kernel_.size() == 2) {
-            math::Im2ColNCHW<T>(
-                C / group_,
-                input_dims[0],
-                input_dims[1],
-                kernel_h(),
-                kernel_w(),
-                dilation_h(),
-                dilation_w(),
-                pad_t(),
-                pad_l(),
-                pad_b(),
-                pad_r(),
-                stride_h(),
-                stride_w(),
-                Xdata + (group_ * image_id + group_id) * input_offset,
-                col_buffer_data + tid * col_buffer_size,
-                &context_,
-                in_qparams_[INPUT].zero_point);
-          } else {
-            math::Im2ColNdNCHW<T>(
-                this->kernel_.size(),
-                C * X_HxW,
-                col_buffer_size,
-                img_shape.data(),
-                buffer_shape_per_thread.data(),
-                this->kernel_.data(),
-                this->stride_.data(),
-                this->dilation_.data(),
-                this->pads_.data(),
-                Xdata + (group_ * image_id + group_id) * input_offset,
-                col_buffer_data + tid * col_buffer_size,
-                &context_,
-                in_qparams_[INPUT].zero_point);
-          }
-
-          // quantize col_buffer
-          T* col_buffer_private = col_buffer_data + tid * col_buffer_size;
-
-          int32_t* Y_int32_temp =
-              Y_int32->data() + ((M / group_) * group_id + M * tid) * Y_HxW;
-          T_signed* W_quantized_group =
-              W_quantized_.data() + (M / group_) * group_id * kernel_dim;
-
-          for (int i = 0; i < M / group_; ++i) {
-            for (int j = 0; j < Y_HxW; ++j) {
-              int32_t sum = 0;
-              for (int k = 0; k < kernel_dim; ++k) {
-                int w = W_quantized_group[i * kernel_dim + k];
-                int x = col_buffer_private[k * Y_HxW + j];
-                sum += w * x;
-              }
-              Y_int32_temp[i * Y_HxW + j] = sum;
-            } // j
-          } // i
-
-          RunOnDeviceEpilogueNCHW_(
-              col_buffer_private,
-              Y_int32_temp,
-              Y_data_T + (M * image_id + M / group_ * group_id) * Y_HxW,
-              M / group_ * group_id,
-              group_id);
-        } // for each group
-      } // for each image_id
+    for (int image_id = 0; image_id < N; ++image_id) {
+      int tid = dnnlowp_get_thread_num();
+      for (int group_id = 0; group_id < group_; ++group_id) {
+        if (this->kernel_.size() == 2) {
+          math::Im2ColNCHW<T>(
+              C / group_,
+              input_dims[0],
+              input_dims[1],
+              kernel_h(),
+              kernel_w(),
+              dilation_h(),
+              dilation_w(),
+              pad_t(),
+              pad_l(),
+              pad_b(),
+              pad_r(),
+              stride_h(),
+              stride_w(),
+              Xdata + (group_ * image_id + group_id) * input_offset,
+              col_buffer_data + tid * col_buffer_size,
+              &context_,
+              in_qparams_[INPUT].zero_point);
+        } else {
+          math::Im2ColNdNCHW<T>(
+              this->kernel_.size(),
+              C * X_HxW,
+              col_buffer_size,
+              img_shape.data(),
+              buffer_shape_per_thread.data(),
+              this->kernel_.data(),
+              this->stride_.data(),
+              this->dilation_.data(),
+              this->pads_.data(),
+              Xdata + (group_ * image_id + group_id) * input_offset,
+              col_buffer_data + tid * col_buffer_size,
+              &context_,
+              in_qparams_[INPUT].zero_point);
+        }
 
-      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-      MeasureQuantizationError_();
-    }; // f2
+        // quantize col_buffer
+        T* col_buffer_private = col_buffer_data + tid * col_buffer_size;
+
+        int32_t* Y_int32_temp =
+            Y_int32->data() + ((M / group_) * group_id + M * tid) * Y_HxW;
+        T_signed* W_quantized_group =
+            W_quantized_.data() + (M / group_) * group_id * kernel_dim;
+
+        for (int i = 0; i < M / group_; ++i) {
+          for (int j = 0; j < Y_HxW; ++j) {
+            int32_t sum = 0;
+            for (int k = 0; k < kernel_dim; ++k) {
+              int w = W_quantized_group[i * kernel_dim + k];
+              int x = col_buffer_private[k * Y_HxW + j];
+              sum += w * x;
+            }
+            Y_int32_temp[i * Y_HxW + j] = sum;
+          } // j
+        } // i
+
+        RunOnDeviceEpilogueNCHW_(
+            col_buffer_private,
+            Y_int32_temp,
+            Y_data_T + (M * image_id + M / group_ * group_id) * Y_HxW,
+            M / group_ * group_id,
+            group_id);
+      } // for each group
+    } // for each image_id
 
-    if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
-      this->RunWithSharedInt32Buffer_(f2);
-    } else {
-      f2(&Y_int32_);
-    }
+    PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
+    MeasureQuantizationError_();
   }; // f
 
-  if (FLAGS_caffe2_force_shared_col_buffer || this->shared_buffer_) {
-    runWithSharedBuffer<CPUContext>(this->ws_, f);
-  } else {
-    f(&col_buffer_);
-  }
+  this->RunWithSharedBuffer_(&col_buffer_, &Y_int32_, f);
 
   return true;
 } // RunOnDeviceWithOrderNCHW
@@ -1324,68 +1312,55 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
 #endif
 
   bool no_im2col = NoIm2ColNHWC_();
-  auto f = [&](vector<int32_t>* Y_int32) {
+  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
     if (!TakeDepthWise3x3FastPath_() && !TakeDepthWise3x3x3FastPath_()) {
       Y_int32->resize(Y->numel());
     }
 
     // Im2col, followed by gemm.
-    auto f2 = [&](Tensor* col_buffer_) {
-      const T* Xdata = X.template data<T>();
-      const T* col_buffer_data = no_im2col ? Xdata : Im2ColNHWC_(col_buffer_);
+    const T* Xdata = X.template data<T>();
+    const T* col_buffer_data = no_im2col ? Xdata : Im2ColNHWC_(col_buffer);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      /*if (VLOG_IS_ON(3)) */ {
-        t_end = chrono::system_clock::now();
-        double dt = chrono::duration<double>(t_end - t_begin).count();
-        LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
-        t_begin = chrono::system_clock::now();
-      }
+    /*if (VLOG_IS_ON(3)) */ {
+      t_end = chrono::system_clock::now();
+      double dt = chrono::duration<double>(t_end - t_begin).count();
+      LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
+      t_begin = chrono::system_clock::now();
+    }
 #endif
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      /*if (VLOG_IS_ON(3)) */ {
-        t_end = chrono::system_clock::now();
-        double dt = chrono::duration<double>(t_end - t_begin).count();
-        LOG(INFO) << "this=" << this << " quantize col_buf: " << dt * 1e3
-                  << " ms";
-        t_begin = chrono::system_clock::now();
-      }
+    /*if (VLOG_IS_ON(3)) */ {
+      t_end = chrono::system_clock::now();
+      double dt = chrono::duration<double>(t_end - t_begin).count();
+      LOG(INFO) << "this=" << this << " quantize col_buf: " << dt * 1e3
+                << " ms";
+      t_begin = chrono::system_clock::now();
+    }
 #endif
 
-      ConvNHWCCore_(col_buffer_data, Y_int32);
+    ConvNHWCCore_(col_buffer_data, Y_int32);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      /*if (VLOG_IS_ON(3)) */ {
-        t_end = chrono::system_clock::now();
-        double dt = chrono::duration<double>(t_end - t_begin).count();
-        LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms";
-        t_begin = chrono::system_clock::now();
-      }
+    /*if (VLOG_IS_ON(3)) */ {
+      t_end = chrono::system_clock::now();
+      double dt = chrono::duration<double>(t_end - t_begin).count();
+      LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms";
+      t_begin = chrono::system_clock::now();
+    }
 #endif
 
-      if (Wq_packed_ || Wq_depthwise_3x3_packed_ ||
-          Wq_depthwise_3x3x3_packed_) {
-        // In fast path with fbgemm except when
-        // rescaling quantized numbers should've been already done.
-        PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-      } else {
-        RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
-      }
-    }; // f2
-
-    if (FLAGS_caffe2_force_shared_col_buffer || this->shared_buffer_) {
-      runWithSharedBuffer<CPUContext>(this->ws_, f2);
+    if (Wq_packed_ || Wq_depthwise_3x3_packed_ || Wq_depthwise_3x3x3_packed_) {
+      // In fast path with fbgemm except when
+      // rescaling quantized numbers should've been already done.
+      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
     } else {
-      f2(&col_buffer_);
+      RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
     }
   }; // f
 
-  if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
-    this->RunWithSharedInt32Buffer_(f);
-  } else {
-    f(&Y_int32_);
-  }
+  this->RunWithSharedBuffer_(&col_buffer_, &Y_int32_, f);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
   /*if (VLOG_IS_ON(3)) */ {
caffe2/quantization/server/conv_pool_dnnlowp_op_base.h
index 10ee04a..ef2df60 100644
@@ -5,6 +5,7 @@
 #endif
 
 #include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_op_shared.h"
 #include "caffe2/operators/conv_pool_op_base.h"
 #include "caffe2/quantization/server/fbgemm_pack_blob.h"
 #include "caffe2/quantization/server/op_wrapper.h"
@@ -12,6 +13,8 @@
 #ifdef _OPENMP
 C10_DECLARE_int(caffe2_omp_num_threads);
 #endif
+C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
+C10_DECLARE_bool(caffe2_force_shared_col_buffer);
 
 namespace caffe2 {
 
@@ -157,18 +160,34 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
     ws_->CreateBlob("__CAFFE2_DNNLOWP_SHARED_INT32_BUFFER_CPU__");
   }
 
-  void RunWithSharedInt32Buffer_(
-      std::function<void(vector<int32_t>* Y_int32)> f) {
-    auto* mutexBlob =
-        ws_->GetBlob("__CAFFE2_DNNLOWP_SHARED_INT32_BUFFER_CPU_MUTEX__");
-    CAFFE_ENFORCE(mutexBlob, "Must call CreateSharedInt32Buffer() first");
-
-    auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
-    std::lock_guard<std::mutex> g(**mutexPtr);
+  void RunWithSharedBuffer_(
+      Tensor* col_buffer,
+      vector<int32_t>* Y_int32,
+      std::function<
+          void(Tensor* col_buffer_shared, vector<int32_t>* Y_int32_shared)> f) {
+    auto f2 = [this, Y_int32, f](Tensor* col_buffer_shared) {
+      if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
+        auto* mutexBlob =
+            ws_->GetBlob("__CAFFE2_DNNLOWP_SHARED_INT32_BUFFER_CPU_MUTEX__");
+        CAFFE_ENFORCE(mutexBlob, "Must call CreateSharedInt32Buffer() first");
+
+        auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
+        std::lock_guard<std::mutex> g(**mutexPtr);
+
+        auto* Y_int32_shared =
+            ws_->GetBlob("__CAFFE2_DNNLOWP_SHARED_INT32_BUFFER_CPU__")
+                ->template GetMutable<vector<int32_t>>();
+        f(col_buffer_shared, Y_int32_shared);
+      } else {
+        f(col_buffer_shared, Y_int32);
+      }
+    };
 
-    auto* Y_int32 = ws_->GetBlob("__CAFFE2_DNNLOWP_SHARED_INT32_BUFFER_CPU__")
-                        ->template GetMutable<vector<int32_t>>();
-    f(Y_int32);
+    if (FLAGS_caffe2_force_shared_col_buffer || this->shared_buffer_) {
+      runWithSharedBuffer<CPUContext>(this->ws_, f2);
+    } else {
+      f2(col_buffer);
+    }
   }
 
   bool measure_quantization_error_{false};
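
For reference, the combined dispatch performed by the new helper can be read as the following standalone sketch. The FLAGS_ checks become plain booleans and the workspace blobs become statics; note the real runWithSharedBuffer<CPUContext> also serializes access to the shared col buffer, which is elided here:

#include <cstdint>
#include <functional>
#include <mutex>
#include <vector>

struct Tensor {}; // stand-in for caffe2::Tensor

// Standalone model of RunWithSharedBuffer_'s buffer selection.
void run_with_shared_buffer_sketch(
    bool use_shared_col_buffer,   // FLAGS_caffe2_force_shared_col_buffer || shared_buffer_
    bool use_shared_int32_buffer, // FLAGS_caffe2_dnnlowp_shared_int32_buffer
    Tensor* own_col_buffer,
    std::vector<std::int32_t>* own_Y_int32,
    const std::function<void(Tensor*, std::vector<std::int32_t>*)>& f) {
  static Tensor shared_col_buffer;                 // workspace col-buffer blob
  static std::vector<std::int32_t> shared_Y_int32; // workspace int32 blob
  static std::mutex shared_Y_int32_mutex;          // workspace mutex blob

  Tensor* col = use_shared_col_buffer ? &shared_col_buffer : own_col_buffer;
  if (use_shared_int32_buffer) {
    std::lock_guard<std::mutex> g(shared_Y_int32_mutex);
    f(col, &shared_Y_int32);
  } else {
    f(col, own_Y_int32);
  }
}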