Change ConvPoolOp<Context>::SetOutputSize to ConvPoolOp<Context>::GetOutputSize ...
authorJerry Zhang <jerryzh@fb.com>
Fri, 8 Mar 2019 02:31:33 +0000 (18:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 8 Mar 2019 02:38:53 +0000 (18:38 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17764

Original commit changeset: f1923fdca4a1

reverted int8 ops fixes the original runtime regression.
We'll ignore the memory regression since it is flaky, see D14228484

Reviewed By: dzhulgakov

Differential Revision: D13885233

fbshipit-source-id: ccbe4b94acb44b7b4cb3ae4d73e3f6091e1e1195

16 files changed:
caffe2/cuda_rtc/pool_op_rtc_gpu.cc
caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm
caffe2/operators/conv_op_cudnn.cc
caffe2/operators/conv_pool_op_base.h
caffe2/operators/depthwise_3x3_conv_op_cudnn.cu
caffe2/operators/hip/conv_op_miopen.hip
caffe2/operators/hip/pool_op_miopen.hip
caffe2/operators/max_pool_with_index.cu
caffe2/operators/pad_op_gpu.cu
caffe2/operators/pool_op_cudnn.cc
caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_pool_dnnlowp_op_base.h
caffe2/quantization/server/dnnlowp_op.h
caffe2/quantization/server/pool_dnnlowp_op.cc
caffe2/share/contrib/depthwise/depthwise3x3_conv_op.cc

index f7e4e9d..5b45581 100644 (file)
@@ -196,8 +196,8 @@ class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
 
   bool RunOnDeviceWithOrderNCHW() override {
     auto& X = Input(0);
-    auto* Y = Output(0);
-    ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1));
+    auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
+    auto* Y = Output(0, output_sizes, at::dtype<float>());
 
     if (input_dims_ != X.sizes()) {
       // recompile
index 7494fe1..f556e9c 100644 (file)
@@ -257,11 +257,10 @@ void computeOutputHW(
     int* OH,
     int* OW) {
   Tensor input = caffe2::empty({1, 1, H, W}, at::dtype<float>().device(CPU));
-  Tensor output(CPU);
-  op->SetOutputSize(input, &output, 1);
-  CAFFE_ENFORCE_EQ(output.dim(), 4);
-  *OH = output.size(2);
-  *OW = output.size(3);
+  auto sizes = op->GetOutputSize(input, 1);
+  CAFFE_ENFORCE_EQ(sizes.size(), 4);
+  *OH = sizes[2];
+  *OW = sizes[3];
 }
 
 constexpr int computeMPSAlignOffset(int kernel, int pad) {
index c48c943..8c510f7 100644 (file)
@@ -514,13 +514,13 @@ template <typename T_X, typename T_W, typename T_B, typename T_Y>
 bool CudnnConvOp::DoRunWithType() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
-  auto* Y = Output(0);
 
   // Figure out the output shape
   CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
   CAFFE_ENFORCE(filter.dim() >= 3 && filter.dim() <= 5);
   const int M = filter.dim32(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, M);
+  auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, M);
+  auto* Y = Output(0, output_sizes, at::dtype<T_Y>());
 
   int N = 0, C = 0, H = 0, W = 0, D = 0, H_out = 0, W_out = 0, D_out = 0;
   int group_offset_X = 0, group_offset_Y = 0;
index e403bfa..58ca094 100644 (file)
@@ -208,7 +208,7 @@ class ConvPoolOpBase : public Operator<Context> {
     return size;
   }
 
-  // Sets the output size. The output channel is manually provided since
+  // Gets the output size. The output channel is manually provided since
   // it may not be identical to the input channels.
   // This function can be used in the forward functions to obtain the output
   // sizes.
@@ -216,7 +216,25 @@ class ConvPoolOpBase : public Operator<Context> {
   // implementations that do not use first-class Tensor objects, such as the
   // MKL operator. One can still call this function with dummy
   // Tensor objects in order to obtain the sizes.
-  // TODO: passing sizes directly rather than Tensor
+  std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
+    CAFFE_ENFORCE_GE(input.dim(), 2);
+    const int inner_size = input.size_from_dim(1);
+    CAFFE_ENFORCE_GT(inner_size, 0);
+    std::vector<int64_t> output_dims;
+    InferOutputSize64(
+        input.sizes(),
+        output_channel,
+        order_,
+        global_pooling_,
+        legacy_pad_,
+        dilation_,
+        stride_,
+        &kernel_,
+        &pads_,
+        &output_dims);
+    return output_dims;
+  }
+
   void SetOutputSize(const Tensor& input, Tensor* output, int output_channel) {
     const int inner_size = input.size_from_dim(1);
     CAFFE_ENFORCE_GT(inner_size, 0);
@@ -276,6 +294,45 @@ class ConvPoolOpBase : public Operator<Context> {
     }
   }
 
+  static void InferOutputSize64(
+      const at::IntList& input_dims,
+      const int output_channel,
+      const StorageOrder order,
+      const bool global_pooling,
+      const LegacyPadding legacy_pad,
+      const std::vector<int>& dilation,
+      const std::vector<int>& stride,
+      std::vector<int>* kernel,
+      std::vector<int>* pads,
+      std::vector<int64_t>* output_dims) {
+    CAFFE_ENFORCE_NE(order, StorageOrder::UNKNOWN);
+    const int ndim = input_dims.size() - 2;
+    output_dims->resize(ndim + 2);
+    output_dims->front() = input_dims.front();
+    if (order == StorageOrder::NCHW) {
+      output_dims->at(1) = output_channel;
+    } else {
+      output_dims->back() = output_channel;
+    }
+    const int offset = order == StorageOrder::NCHW ? 2 : 1;
+    if (global_pooling) {
+      std::copy_n(input_dims.cbegin() + offset, ndim, kernel->begin());
+      std::fill_n(output_dims->begin() + offset, ndim, 1LL);
+    } else {
+      for (int i = 0; i < ndim; ++i) {
+        ComputeSizeAndPad64(
+            input_dims[i + offset],
+            stride[i],
+            kernel->at(i),
+            dilation[i],
+            legacy_pad,
+            &pads->at(i),
+            &pads->at(i + ndim),
+            &output_dims->at(i + offset));
+      }
+    }
+  }
+
   // ComputePads could be used in backward functions to figure out the padding
   // values for the given input.
   void ComputePads(const vector<int>& dims) {
@@ -670,6 +727,85 @@ class ConvPoolOpBase : public Operator<Context> {
     }
   }
 
+  static inline void ComputeSizeAndPad64(
+      const int in_size,
+      const int stride,
+      const int kernel,
+      const int dilation,
+      LegacyPadding legacy_pad,
+      int* pad_head,
+      int* pad_tail,
+      int64_t* out_size) {
+    const int dkernel = dilation * (kernel - 1) + 1;
+    switch (legacy_pad) {
+      case LegacyPadding::NOTSET:
+        // We will just use the direct padding head and tail values, but we
+        // will verify that they are non-negative.
+        CAFFE_ENFORCE_GE(in_size + *pad_head + *pad_tail, dkernel);
+        *out_size = static_cast<int>(
+            static_cast<float>(in_size + *pad_head + *pad_tail - dkernel) /
+                stride +
+            1);
+        break;
+      case LegacyPadding::VALID:
+        *pad_head = 0;
+        *pad_tail = 0;
+        *out_size = (in_size - dkernel) / stride + 1;
+        break;
+      case LegacyPadding::SAME: {
+        CAFFE_ENFORCE(
+            1 == dilation, "Dilation not supported for legacy padding.");
+        int legacy_target_size = (in_size + stride - 1) / stride;
+        int pad_needed = (legacy_target_size - 1) * stride + kernel - in_size;
+        if (CAFFE2_PAD_HEAD_MORE) {
+          *pad_head = (pad_needed + 1) / 2;
+        } else {
+          *pad_head = pad_needed / 2;
+        }
+        *pad_tail = pad_needed - *pad_head;
+        *out_size = (in_size + pad_needed - dkernel) / stride + 1;
+        break;
+      }
+      case LegacyPadding::CAFFE_LEGACY_POOLING:
+        // This is in order to adapt Caffe's pooling padding case. In this case,
+        // we will only use pad_head and will compute pad_tail to match the
+        // old caffe pooling strategy. Also see caffe2_legacy.proto for more
+        // details.
+        CAFFE_ENFORCE_GE(*pad_head, 0);
+        // Here, notice that caffe casts UP while caffe2 casts DOWN for the
+        // output size computation.
+        *out_size = std::ceil(
+            static_cast<float>(in_size + *pad_head * 2 - kernel) / stride + 1);
+        // If we have padding, caffe also ensures that the last pooling starts
+        // strictly inside the image (instead of at the padding); otherwise clip
+        // the last.
+        if (*pad_head > 0 && (*out_size - 1) * stride >= in_size + *pad_head) {
+          --*out_size;
+        }
+        // Now, compare the output size with the standard Caffe2 output size.
+        // The
+        // caffe2 standard output size should always be no larger than the
+        // output
+        // size of caffe.
+        int standard_out_size = static_cast<int>(
+            static_cast<float>(in_size + *pad_head * 2 - kernel) / stride + 1);
+        CAFFE_ENFORCE_GE(
+            *out_size,
+            standard_out_size,
+            "This should never happen. If this happens, double check the logic "
+            "above.");
+        if (*out_size > standard_out_size) {
+          LOG(WARNING)
+              << "You are hitting a case where Caffe's legacy padding calculation "
+                 "is hit. This leads to inefficient and sometimes incorrect "
+                 "results. We are keeping this behavior for backward compatibility"
+                 ", but you are strongly recommended to move away from it.";
+        }
+        *pad_tail = *pad_head + stride * (*out_size - standard_out_size);
+        break;
+    }
+  }
+
   // Accessors for 2D conv params.
 
   inline int pad_t() const {
index 36b54b9..5b76149 100644 (file)
@@ -288,7 +288,6 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CUDAContext> {
   bool RunOnDeviceWithOrderNCHW() override {
     const Tensor& X = Input(0);
     auto& filter = Input(1);
-    Tensor* Y = Output(0);
     const int N = X.dim32(0), C = X.dim32(1);
     CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
     const int M = filter.dim32(0);
@@ -300,7 +299,8 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CUDAContext> {
     CAFFE_ENFORCE_EQ(this->kernel_w(), 3);
     CAFFE_ENFORCE_EQ(this->kernel_h(), 3);
     CAFFE_ENFORCE_EQ(this->stride_h(), this->stride_w());
-    ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, filter.dim32(0));
+    auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, filter.dim32(0));
+    Tensor* Y = Output(0, sizes, at::dtype<float>());
     DepthwiseArgs args;
     args.batch = X.dim32(0);
     args.in_rows = X.dim32(2);
@@ -458,7 +458,7 @@ class Depthwise3x3ConvGradientOp final : public ConvPoolOpBase<CUDAContext> {
           M,
           dY.dim32(2),
           dY.dim32(3)));
-      
+
       auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<float>());
       CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
           cudnn_wrapper_.inline_cudnn_handle(),
index d083611..1bdca48 100644 (file)
@@ -207,7 +207,6 @@ template <typename T_X, typename T_W, typename T_B, typename MATH, typename T_Y>
 bool MIOPENConvOp::DoRunWithType() {
   auto& X = Input(INPUT);
   auto& Weight = Input(FILTER);
-  auto* Y = Output(0);
 
   // Figure out the output shape
   CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5);
@@ -216,7 +215,8 @@ bool MIOPENConvOp::DoRunWithType() {
       "Conv op with MIOpen engine is supported only for 2D convolutions");
 
   const int M = Weight.dim32(0);
-  ConvPoolOpBase<HIPContext>::SetOutputSize(X, Y, M);
+  auto sizes = ConvPoolOpBase<HIPContext>::GetOutputSize(X, M);
+  auto* Y = Output(0, sizes, at::dtype<T_Y>());
 
   int N = X.dim32(0);
   int C = X.dim32(1);
index 614b6cf..c1d5ee3 100644 (file)
@@ -61,7 +61,6 @@ class MIOPENPoolOp : public ConvPoolOpBase<HIPContext> {
   template <typename T, typename M>
   bool DoRunWithType() {
     auto& X = Input(0);
-    auto* Y = Output(0);
     int N = 0, C = 0, H = 0, W = 0, D = 0;
     int N_out = 0, C_out = 0, H_out = 0, W_out = 0;
     CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
@@ -69,7 +68,8 @@ class MIOPENPoolOp : public ConvPoolOpBase<HIPContext> {
     C = X.dim32(1);
     H = X.dim32(2);
     W = X.ndim() > 3 ? X.dim32(3) : 1;
-    ConvPoolOpBase::SetOutputSize(X, Y, C);
+    auto sizes = ConvPoolOpBase::GetOutputSize(X, C);
+    auto* Y = Output(0, sizes, at::dtype<T>());
 
     N_out = Y->dim32(0);
     C_out = Y->dim32(1);
index 31513b5..cefa831 100644 (file)
@@ -108,9 +108,10 @@ __global__ void MaxPoolBackward(
 template <typename T>
 bool MaxPoolWithIndexOp::DoRunWithType() {
   auto& X = Input(0);
-  auto* Y = Output(0);
 
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
+  auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
+  auto* Y = Output(0, sizes, at::dtype<T>());
+
   int output_size = Y->numel();
   auto* mask = Output(1, {output_size}, at::dtype<int>());
 
index c623633..f9c37af 100644 (file)
@@ -251,12 +251,13 @@ __global__ void PadImageGradientEdgeNHWC(
 template <>
 bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
   auto& X = Input(0);
-  auto* Y = Output(0);
   const int num = X.dim32(0);
   const int channels = X.dim32(1);
   const int height = X.dim32(2);
   const int width = X.dim32(3);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, channels);
+  auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, channels);
+  auto* Y = Output(0, sizes, at::dtype<float>());
+
   const int output_size = Y->numel();
   const int padded_height = Y->dim32(2);
   const int padded_width = Y->dim32(3);
@@ -327,12 +328,13 @@ bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
 template<>
 bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
   auto& X = Input(0);
-  auto* Y = Output(0);
   const int num = X.dim32(0);
   const int height = X.dim32(1);
   const int width = X.dim32(2);
   const int channels = X.dim32(3);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, channels);
+  auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, channels);
+  auto* Y = Output(0, sizes, at::dtype<float>());
+
   const int output_size = Y->numel();
   const int padded_height = Y->dim32(1);
   const int padded_width = Y->dim32(2);
@@ -403,7 +405,7 @@ bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
 template<>
 bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
   auto& dY = Input(0);
-  
+
   auto* dX = Output(0, { dY.dim32(0),
       dY.dim32(1),
       dY.dim32(2) - pad_t() - pad_b(),
@@ -483,7 +485,7 @@ bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
 template<>
 bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
   auto& dY = Input(0);
-  
+
   auto* dX = Output(0, { dY.dim32(0),
       dY.dim32(1) - pad_t() - pad_b(),
       dY.dim32(2) - pad_l() - pad_r(),
index 0e1160a..e656801 100644 (file)
@@ -100,11 +100,11 @@ class CuDNNPoolOp final : public ConvPoolOpBase<CUDAContext> {
   template <typename T>
   bool DoRunWithType() {
     const auto& X = Input(0);
-    auto* Y = Output(0);
     const int ndim = X.dim();
     const int N = X.dim32(0);
     const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
-    ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, C);
+    auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, C);
+    auto* Y = Output(0, sizes, at::dtype<T>());
     const T* X_data = X.template data<T>();
     T* Y_data = Y->template mutable_data<T>();
 
index b371337..b339e52 100644 (file)
@@ -102,8 +102,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
     const Tensor& X = InputTensorCPU_(INPUT);
     int N = X.dim32(0);
 
-    Tensor* Y = OutputTensorCPU_(0);
-    this->SetOutputSize(X, Y, filter.dim32(0));
+    auto sizes = this->GetOutputSize(X, filter.dim32(0));
+    Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
     const int output_image_size = this->GetDimsSize(*Y);
 
     if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
@@ -228,7 +228,6 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
 
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
-  Tensor* Y = OutputTensorCPU_(0);
   const int N = X.dim32(0), C = X.dim32(1);
   CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
   const int M = filter.dim32(0);
@@ -246,7 +245,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
       0,
       "The number of output channels is not divisible by group.");
 
-  this->SetOutputSize(X, Y, filter.dim32(0));
+  auto sizes = this->GetOutputSize(X, filter.dim32(0));
+  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
 
   const vector<int> input_dims = GetDims(X);
   const vector<int> output_dims = GetDims(*Y);
@@ -618,14 +618,14 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
 
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
-  Tensor* Y = OutputTensorCPU_(0);
   const int N = X.dim32(0), C = X.dim32(X.ndim() - 1);
 
   CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
   const int M = filter.dim32(0);
   CAFFE_ENFORCE_EQ(filter.dim32(filter.ndim() - 1), C / group_);
 
-  this->SetOutputSize(X, Y, filter.dim32(0));
+  auto sizes = this->GetOutputSize(X, filter.dim32(0));
+  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
   // The dimension of each kernel
   const int kernel_dim = this->KernelDim_();
   // The output image size is the spatial size of the output.
index e75789f..bb2aeee 100644 (file)
@@ -560,7 +560,6 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
 
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
-  Tensor* Y = OutputTensorCPU_(0);
   const int N = X.dim32(0), C = X.dim32(1);
   CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
   const int M = filter.dim32(0);
@@ -578,7 +577,8 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
       0,
       "The number of output channels is not divisible by group.");
 
-  ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
+  auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
+  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
 
   const vector<int> input_dims = GetDims(X);
   const vector<int> output_dims = GetDims(*Y);
@@ -1418,7 +1418,6 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
 
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
-  Tensor* Y = OutputTensorCPU_(0);
   const int C = X.dim32(X.dim() - 1);
   const int G = group_;
   CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
@@ -1435,7 +1434,8 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
   CAFFE_ENFORCE_EQ(
       M % G, 0, "The number of output channels is not divisible by group.");
 
-  ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
+  auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
+  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
 
   // The col buffer is stored in HWC order as well - kernel_dim, and the height
   // and width.
index d635955..4a2c72b 100644 (file)
@@ -69,6 +69,12 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
     return &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
   }
 
+  Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) {
+    auto* t = &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
+    ReinitializeTensor(t, dims, options.device(CPU));
+    return t;
+  }
+
   T* GetQuantizedOutputData_() {
     return OutputTensorCPU_(0)->template mutable_data<T>();
   }
index 8327b27..88a5a1d 100644 (file)
@@ -122,6 +122,16 @@ class DNNLowPOp : public Operator<CPUContext> {
     }
   }
 
+  Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) {
+    if (dequantize_output_) {
+      return Output(idx, dims, options.device(CPU));
+    } else {
+      auto* t = &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
+      ReinitializeTensor(t, dims, options.device(CPU));
+      return t;
+    }
+  }
+
   T* GetQuantizedOutputData_() {
     if (dequantize_output_) {
       out_temp_.resize(Output(0)->numel());
index 0dda848..7d6ded9 100644 (file)
@@ -101,8 +101,8 @@ class AveragePoolDnnLowPOp final
     GetOutputQuantizationParams_();
 
     auto& X = InputTensorCPU_(0);
-    auto* Y = OutputTensorCPU_(0);
-    ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, X.dim32(1));
+    auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, X.dim32(1));
+    auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
 
     T* Ydata = GetQuantizedOutputData_();
 
@@ -239,9 +239,9 @@ class AveragePoolDnnLowPOp final
     GetOutputQuantizationParams_();
 
     auto& X = InputTensorCPU_(0);
-    auto* Y = OutputTensorCPU_(0);
     int channels = X.dim32(X.ndim() - 1);
-    ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, channels);
+    auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, channels);
+    auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
 
     T* Ydata = GetQuantizedOutputData_();
 
@@ -398,8 +398,8 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
     const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
 
     auto& X = InputTensorCPU_(0);
-    auto* Y = OutputTensorCPU_(0);
-    ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, X.dim32(1));
+    auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, X.dim32(1));
+    auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
 
     T* Ydata = GetQuantizedOutputData_();
 
@@ -544,9 +544,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
     const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
 
     auto& X = InputTensorCPU_(0);
-    auto* Y = OutputTensorCPU_(0);
     int channels = X.dim32(X.ndim() - 1);
-    ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, channels);
+    auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, channels);
+    auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
 
     T* Ydata = GetQuantizedOutputData_();
 
index 37460c8..b7bab4f 100644 (file)
@@ -442,7 +442,6 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CPUContext> {
   bool RunOnDeviceWithOrderNCHW() override {
     const Tensor& X = Input(0);
     auto& filter = Input(1);
-    Tensor* Y = Output(0);
     const int N = X.dim32(0), C = X.dim32(1);
     CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
     const int M = filter.dim32(0);
@@ -452,8 +451,8 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CPUContext> {
     CAFFE_ENFORCE_EQ(C, this->group_);
     CAFFE_ENFORCE_EQ(M, this->group_);
 
-    ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
-    Y->mutable_data<float>();
+    auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
+    Tensor* Y = Output(0, sizes, at::dtype<float>());
 
     DepthwiseArgs args;
     args.batch = X.dim32(0);