bool RunOnDeviceWithOrderNCHW() override {
auto& X = Input(0);
- auto* Y = Output(0);
- ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1));
+ auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
+ auto* Y = Output(0, output_sizes, at::dtype<float>());
if (input_dims_ != X.sizes()) {
// recompile
int* OH,
int* OW) {
Tensor input = caffe2::empty({1, 1, H, W}, at::dtype<float>().device(CPU));
- Tensor output(CPU);
- op->SetOutputSize(input, &output, 1);
- CAFFE_ENFORCE_EQ(output.dim(), 4);
- *OH = output.size(2);
- *OW = output.size(3);
+ auto sizes = op->GetOutputSize(input, 1);
+ CAFFE_ENFORCE_EQ(sizes.size(), 4);
+ *OH = sizes[2];
+ *OW = sizes[3];
}
constexpr int computeMPSAlignOffset(int kernel, int pad) {
bool CudnnConvOp::DoRunWithType() {
auto& X = Input(INPUT);
auto& filter = Input(FILTER);
- auto* Y = Output(0);
// Figure out the output shape
CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
CAFFE_ENFORCE(filter.dim() >= 3 && filter.dim() <= 5);
const int M = filter.dim32(0);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, M);
+ auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, M);
+ auto* Y = Output(0, output_sizes, at::dtype<T_Y>());
int N = 0, C = 0, H = 0, W = 0, D = 0, H_out = 0, W_out = 0, D_out = 0;
int group_offset_X = 0, group_offset_Y = 0;
return size;
}
- // Sets the output size. The output channel is manually provided since
+ // Gets the output size. The output channel is manually provided since
// it may not be identical to the input channels.
// This function can be used in the forward functions to obtain the output
// sizes.
// This also helps implementations that do not use first-class Tensor
// objects, such as the MKL operator: one can still call this function
// with a dummy input Tensor in order to obtain the sizes.
- // TODO: passing sizes directly rather than Tensor
+ std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
+ CAFFE_ENFORCE_GE(input.dim(), 2);
+ const int inner_size = input.size_from_dim(1);
+ CAFFE_ENFORCE_GT(inner_size, 0);
+ std::vector<int64_t> output_dims;
+ InferOutputSize64(
+ input.sizes(),
+ output_channel,
+ order_,
+ global_pooling_,
+ legacy_pad_,
+ dilation_,
+ stride_,
+ &kernel_,
+ &pads_,
+ &output_dims);
+ return output_dims;
+ }
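+ // Usage sketch (illustrative only, mirroring the call sites updated
+ // alongside this function): instead of SetOutputSize, an operator can now
+ // allocate its output directly from the returned sizes, e.g.
+ //   auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, num_channels);
+ //   auto* Y = Output(0, sizes, at::dtype<float>());
+ // where num_channels stands in for the operator's output channel count.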
+
void SetOutputSize(const Tensor& input, Tensor* output, int output_channel) {
const int inner_size = input.size_from_dim(1);
CAFFE_ENFORCE_GT(inner_size, 0);
}
}
+ static void InferOutputSize64(
+ const at::IntList& input_dims,
+ const int output_channel,
+ const StorageOrder order,
+ const bool global_pooling,
+ const LegacyPadding legacy_pad,
+ const std::vector<int>& dilation,
+ const std::vector<int>& stride,
+ std::vector<int>* kernel,
+ std::vector<int>* pads,
+ std::vector<int64_t>* output_dims) {
+ CAFFE_ENFORCE_NE(order, StorageOrder::UNKNOWN);
+ const int ndim = input_dims.size() - 2;
+ output_dims->resize(ndim + 2);
+ output_dims->front() = input_dims.front();
+ if (order == StorageOrder::NCHW) {
+ output_dims->at(1) = output_channel;
+ } else {
+ output_dims->back() = output_channel;
+ }
+ const int offset = order == StorageOrder::NCHW ? 2 : 1;
+ if (global_pooling) {
+ std::copy_n(input_dims.cbegin() + offset, ndim, kernel->begin());
+ std::fill_n(output_dims->begin() + offset, ndim, 1LL);
+ } else {
+ for (int i = 0; i < ndim; ++i) {
+ ComputeSizeAndPad64(
+ input_dims[i + offset],
+ stride[i],
+ kernel->at(i),
+ dilation[i],
+ legacy_pad,
+ &pads->at(i),
+ &pads->at(i + ndim),
+ &output_dims->at(i + offset));
+ }
+ }
+ }
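+ // Worked example (illustrative values): for input_dims = {1, 3, 8, 8} in
+ // NCHW order, output_channel = 16, a 3x3 kernel, stride 2, dilation 1 and
+ // NOTSET padding of 1 on every side, output_dims becomes {1, 16, 4, 4}.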
+
// ComputePads could be used in backward functions to figure out the padding
// values for the given input.
void ComputePads(const vector<int>& dims) {
}
}
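// Illustrative (hypothetical backward-pass caller): given an NCHW input X,
// the padding could be recomputed from its spatial dims with
//   ComputePads(std::vector<int>{X.dim32(2), X.dim32(3)});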
+ static inline void ComputeSizeAndPad64(
+ const int in_size,
+ const int stride,
+ const int kernel,
+ const int dilation,
+ LegacyPadding legacy_pad,
+ int* pad_head,
+ int* pad_tail,
+ int64_t* out_size) {
+ const int dkernel = dilation * (kernel - 1) + 1;
+ switch (legacy_pad) {
+ case LegacyPadding::NOTSET:
+ // We will just use the direct padding head and tail values, but we
+ // will verify that the padded input is at least as large as the
+ // effective (dilated) kernel.
+ CAFFE_ENFORCE_GE(in_size + *pad_head + *pad_tail, dkernel);
+ *out_size = static_cast<int>(
+ static_cast<float>(in_size + *pad_head + *pad_tail - dkernel) /
+ stride +
+ 1);
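+ // Worked example (illustrative values): in_size = 7, pad_head = pad_tail
+ // = 1, kernel = 3, dilation = 1 (so dkernel = 3), stride = 2 gives
+ // out_size = (7 + 1 + 1 - 3) / 2 + 1 = 4.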
+ break;
+ case LegacyPadding::VALID:
+ *pad_head = 0;
+ *pad_tail = 0;
+ *out_size = (in_size - dkernel) / stride + 1;
+ break;
+ case LegacyPadding::SAME: {
+ CAFFE_ENFORCE(
+ 1 == dilation, "Dilation not supported for legacy padding.");
+ int legacy_target_size = (in_size + stride - 1) / stride;
+ int pad_needed = (legacy_target_size - 1) * stride + kernel - in_size;
+ if (CAFFE2_PAD_HEAD_MORE) {
+ *pad_head = (pad_needed + 1) / 2;
+ } else {
+ *pad_head = pad_needed / 2;
+ }
+ *pad_tail = pad_needed - *pad_head;
+ *out_size = (in_size + pad_needed - dkernel) / stride + 1;
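+ // Worked example (illustrative values): in_size = 7, stride = 2,
+ // kernel = 3: legacy_target_size = 4, pad_needed = 2, so pad_head =
+ // pad_tail = 1 and out_size = (7 + 2 - 3) / 2 + 1 = 4.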
+ break;
+ }
+ case LegacyPadding::CAFFE_LEGACY_POOLING:
+ // This is in order to adapt Caffe's pooling padding case. In this case,
+ // we will only use pad_head and will compute pad_tail to match the
+ // old caffe pooling strategy. Also see caffe2_legacy.proto for more
+ // details.
+ CAFFE_ENFORCE_GE(*pad_head, 0);
+ // Here, notice that caffe casts UP while caffe2 casts DOWN for the
+ // output size computation.
+ *out_size = std::ceil(
+ static_cast<float>(in_size + *pad_head * 2 - kernel) / stride + 1);
+ // If we have padding, caffe also ensures that the last pooling starts
+ // strictly inside the image (instead of at the padding); otherwise clip
+ // the last.
+ if (*pad_head > 0 && (*out_size - 1) * stride >= in_size + *pad_head) {
+ --*out_size;
+ }
+ // Now, compare the output size with the standard Caffe2 output size.
+ // The caffe2 standard output size should always be no larger than the
+ // output size of caffe.
+ int standard_out_size = static_cast<int>(
+ static_cast<float>(in_size + *pad_head * 2 - kernel) / stride + 1);
+ CAFFE_ENFORCE_GE(
+ *out_size,
+ standard_out_size,
+ "This should never happen. If this happens, double check the logic "
+ "above.");
+ if (*out_size > standard_out_size) {
+ LOG(WARNING)
+ << "You are hitting a case where Caffe's legacy padding calculation "
+ "is hit. This leads to inefficient and sometimes incorrect "
+ "results. We are keeping this behavior for backward compatibility"
+ ", but you are strongly recommended to move away from it.";
+ }
+ *pad_tail = *pad_head + stride * (*out_size - standard_out_size);
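+ // Worked example (illustrative values): in_size = 6, pad_head = 1,
+ // kernel = 3, stride = 2: caffe's ceil gives out_size = 4 while the
+ // standard caffe2 size is 3, so the warning fires and pad_tail becomes
+ // 1 + 2 * (4 - 3) = 3.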
+ break;
+ }
+ }
+
// Accessors for 2D conv params.
inline int pad_t() const {
bool RunOnDeviceWithOrderNCHW() override {
const Tensor& X = Input(0);
auto& filter = Input(1);
- Tensor* Y = Output(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
const int M = filter.dim32(0);
CAFFE_ENFORCE_EQ(this->kernel_w(), 3);
CAFFE_ENFORCE_EQ(this->kernel_h(), 3);
CAFFE_ENFORCE_EQ(this->stride_h(), this->stride_w());
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, filter.dim32(0));
+ auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = Output(0, sizes, at::dtype<float>());
DepthwiseArgs args;
args.batch = X.dim32(0);
args.in_rows = X.dim32(2);
M,
dY.dim32(2),
dY.dim32(3)));
-
+
auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<float>());
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
cudnn_wrapper_.inline_cudnn_handle(),
bool MIOPENConvOp::DoRunWithType() {
auto& X = Input(INPUT);
auto& Weight = Input(FILTER);
- auto* Y = Output(0);
// Figure out the output shape
CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5);
"Conv op with MIOpen engine is supported only for 2D convolutions");
const int M = Weight.dim32(0);
- ConvPoolOpBase<HIPContext>::SetOutputSize(X, Y, M);
+ auto sizes = ConvPoolOpBase<HIPContext>::GetOutputSize(X, M);
+ auto* Y = Output(0, sizes, at::dtype<T_Y>());
int N = X.dim32(0);
int C = X.dim32(1);
template <typename T, typename M>
bool DoRunWithType() {
auto& X = Input(0);
- auto* Y = Output(0);
int N = 0, C = 0, H = 0, W = 0, D = 0;
int N_out = 0, C_out = 0, H_out = 0, W_out = 0;
CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
C = X.dim32(1);
H = X.dim32(2);
W = X.ndim() > 3 ? X.dim32(3) : 1;
- ConvPoolOpBase::SetOutputSize(X, Y, C);
+ auto sizes = ConvPoolOpBase::GetOutputSize(X, C);
+ auto* Y = Output(0, sizes, at::dtype<T>());
N_out = Y->dim32(0);
C_out = Y->dim32(1);
template <typename T>
bool MaxPoolWithIndexOp::DoRunWithType() {
auto& X = Input(0);
- auto* Y = Output(0);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
+ auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
+ auto* Y = Output(0, sizes, at::dtype<T>());
+
int output_size = Y->numel();
auto* mask = Output(1, {output_size}, at::dtype<int>());
template <>
bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& X = Input(0);
- auto* Y = Output(0);
const int num = X.dim32(0);
const int channels = X.dim32(1);
const int height = X.dim32(2);
const int width = X.dim32(3);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, channels);
+ auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, channels);
+ auto* Y = Output(0, sizes, at::dtype<float>());
+
const int output_size = Y->numel();
const int padded_height = Y->dim32(2);
const int padded_width = Y->dim32(3);
template<>
bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& X = Input(0);
- auto* Y = Output(0);
const int num = X.dim32(0);
const int height = X.dim32(1);
const int width = X.dim32(2);
const int channels = X.dim32(3);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, channels);
+ auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, channels);
+ auto* Y = Output(0, sizes, at::dtype<float>());
+
const int output_size = Y->numel();
const int padded_height = Y->dim32(1);
const int padded_width = Y->dim32(2);
template<>
bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& dY = Input(0);
-
+
auto* dX = Output(0, { dY.dim32(0),
dY.dim32(1),
dY.dim32(2) - pad_t() - pad_b(),
template<>
bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& dY = Input(0);
-
+
auto* dX = Output(0, { dY.dim32(0),
dY.dim32(1) - pad_t() - pad_b(),
dY.dim32(2) - pad_l() - pad_r(),
template <typename T>
bool DoRunWithType() {
const auto& X = Input(0);
- auto* Y = Output(0);
const int ndim = X.dim();
const int N = X.dim32(0);
const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, C);
+ auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, C);
+ auto* Y = Output(0, sizes, at::dtype<T>());
const T* X_data = X.template data<T>();
T* Y_data = Y->template mutable_data<T>();
const Tensor& X = InputTensorCPU_(INPUT);
int N = X.dim32(0);
- Tensor* Y = OutputTensorCPU_(0);
- this->SetOutputSize(X, Y, filter.dim32(0));
+ auto sizes = this->GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
const int output_image_size = this->GetDimsSize(*Y);
if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
- Tensor* Y = OutputTensorCPU_(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
const int M = filter.dim32(0);
0,
"The number of output channels is not divisible by group.");
- this->SetOutputSize(X, Y, filter.dim32(0));
+ auto sizes = this->GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
const vector<int> input_dims = GetDims(X);
const vector<int> output_dims = GetDims(*Y);
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
- Tensor* Y = OutputTensorCPU_(0);
const int N = X.dim32(0), C = X.dim32(X.ndim() - 1);
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
const int M = filter.dim32(0);
CAFFE_ENFORCE_EQ(filter.dim32(filter.ndim() - 1), C / group_);
- this->SetOutputSize(X, Y, filter.dim32(0));
+ auto sizes = this->GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
// The dimension of each kernel
const int kernel_dim = this->KernelDim_();
// The output image size is the spatial size of the output.
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
- Tensor* Y = OutputTensorCPU_(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
const int M = filter.dim32(0);
0,
"The number of output channels is not divisible by group.");
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
const vector<int> input_dims = GetDims(X);
const vector<int> output_dims = GetDims(*Y);
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
- Tensor* Y = OutputTensorCPU_(0);
const int C = X.dim32(X.dim() - 1);
const int G = group_;
CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
CAFFE_ENFORCE_EQ(
M % G, 0, "The number of output channels is not divisible by group.");
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
// The col buffer is stored in HWC order as well - kernel_dim, and the height
// and width.
return &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
}
+ Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) {
+ auto* t = &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
+ ReinitializeTensor(t, dims, options.device(CPU));
+ return t;
+ }
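+ // Usage sketch (illustrative, matching the quantized conv call sites in
+ // this change):
+ //   auto sizes = this->GetOutputSize(X, filter.dim32(0));
+ //   Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());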
+
T* GetQuantizedOutputData_() {
return OutputTensorCPU_(0)->template mutable_data<T>();
}
}
}
+ Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) {
+ if (dequantize_output_) {
+ return Output(idx, dims, options.device(CPU));
+ } else {
+ auto* t = &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
+ ReinitializeTensor(t, dims, options.device(CPU));
+ return t;
+ }
+ }
+
T* GetQuantizedOutputData_() {
if (dequantize_output_) {
out_temp_.resize(Output(0)->numel());
GetOutputQuantizationParams_();
auto& X = InputTensorCPU_(0);
- auto* Y = OutputTensorCPU_(0);
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, X.dim32(1));
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, X.dim32(1));
+ auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
GetOutputQuantizationParams_();
auto& X = InputTensorCPU_(0);
- auto* Y = OutputTensorCPU_(0);
int channels = X.dim32(X.ndim() - 1);
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, channels);
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, channels);
+ auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
auto& X = InputTensorCPU_(0);
- auto* Y = OutputTensorCPU_(0);
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, X.dim32(1));
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, X.dim32(1));
+ auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
auto& X = InputTensorCPU_(0);
- auto* Y = OutputTensorCPU_(0);
int channels = X.dim32(X.ndim() - 1);
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, channels);
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, channels);
+ auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
bool RunOnDeviceWithOrderNCHW() override {
const Tensor& X = Input(0);
auto& filter = Input(1);
- Tensor* Y = Output(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
const int M = filter.dim32(0);
CAFFE_ENFORCE_EQ(C, this->group_);
CAFFE_ENFORCE_EQ(M, this->group_);
- ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
- Y->mutable_data<float>();
+ auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
+ Tensor* Y = Output(0, sizes, at::dtype<float>());
DepthwiseArgs args;
args.batch = X.dim32(0);