From a7b82a44c40825ccc02187c294a58d49d54d1448 Mon Sep 17 00:00:00 2001 From: "Gu, Jinghui" Date: Wed, 3 Apr 2019 10:29:19 -0700 Subject: [PATCH] Upgrade mkldnn-bridge for dnnlowp support (#16308) Summary: The mkldnn-bridge is upgraded in this PR to support DNNLOWP operators. Meanwhile, APIs have been updated in caffe2 to use latest version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16308 Differential Revision: D14697018 Pulled By: yinghai fbshipit-source-id: ca952589098accb08295fd5aa92924c61e74d69c --- caffe2/ideep/ideep_utils.h | 27 ++- caffe2/ideep/operators/conv_fusion_op.cc | 219 --------------------- caffe2/ideep/operators/conv_op.cc | 237 ++++++++++++++++++----- caffe2/ideep/operators/conv_pool_base_op.h | 9 +- caffe2/ideep/operators/conv_transpose_op.cc | 4 +- caffe2/ideep/operators/operator_fallback_ideep.h | 27 ++- caffe2/ideep/operators/pool_op.cc | 23 ++- caffe2/ideep/operators/utility_ops.cc | 4 +- caffe2/ideep/utils/ideep_operator.h | 21 +- caffe2/opt/optimize_ideep.cc | 37 ++-- caffe2/python/ideep/convfusion_op_test.py | 108 +++++++++++ caffe2/python/pybind_state_ideep.cc | 8 +- 12 files changed, 405 insertions(+), 319 deletions(-) delete mode 100644 caffe2/ideep/operators/conv_fusion_op.cc diff --git a/caffe2/ideep/ideep_utils.h b/caffe2/ideep/ideep_utils.h index db4195c..11adf8c 100644 --- a/caffe2/ideep/ideep_utils.h +++ b/caffe2/ideep/ideep_utils.h @@ -12,16 +12,41 @@ namespace caffe2 { enum ConvAlgorithm { CONV_ALGORITHM_AUTO = 0, CONV_ALGORITHM_WINOGRAD = 1, - CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1 + CONV_ALGORITHM_MAX +}; + +enum FusionType { + FUSION_UNKNOWN = 0, + FUSION_CONV_RELU = 1, + FUSION_CONV_SUM = 2, + FUSION_CONV_SUM_RELU = 3, + FUSION_MAX }; #define USE_IDEEP_DEF_ALIASES() \ + /* the hash key of cahced operator generated by iDEEP */ \ + using ikey = ideep::key_t; \ + /* the tensor type created/handled by iDEEP */ \ using itensor = ideep::tensor; \ + /* the date layout of iDEEP tensor */ \ using iformat = ideep::format; \ + /* the scales for iDEEP tensor with different data type */ \ + using iscale = ideep::scale_t; \ + /* the detial algorithm for iDEEP operators, e.g. winograd */ \ using ialgo = ideep::algorithm; \ + /* the kind of propagation for iDEEP operators, e.g. forward, training */ \ using iprop = ideep::prop_kind; \ + /* the kind of low precision operators, e.g. signed/unsigned activation */ \ + using ilowp_kind = ideep::lowp_kind; \ + /* the kind of padding, usually set as zero padding */ \ using ipadding = ideep::padding_kind; \ + /* the data type of iDEEP tensor, e.g. f32, u8, s8 */ \ + using idtype = ideep::tensor::data_type; \ + /* the descriptor of iDEEP tensor */ \ + using itdesc = ideep::tensor::descriptor; \ + /* the attribute for operator to describe the details of inputs&fusion */ \ using iattr = ideep::descriptor_group::attr_t; \ + /* the detail flags for batch normalization */ \ using ibn_flag = ideep::batch_normalization_flag; } // namespace caffe2 diff --git a/caffe2/ideep/operators/conv_fusion_op.cc b/caffe2/ideep/operators/conv_fusion_op.cc deleted file mode 100644 index ff23991..0000000 --- a/caffe2/ideep/operators/conv_fusion_op.cc +++ /dev/null @@ -1,219 +0,0 @@ -#include - -namespace caffe2 { - -class IDEEPConvFusionOp final : public IDEEPConvPoolOpBase { - public: - USE_IDEEP_DEF_ALIASES(); - USE_IDEEP_CONV_POOL_BASE_FUNCTIONS(); - - enum FusionType { - FUSION_UNKNOWN = 0, - FUSION_CONV_RELU = 1, - FUSION_CONV_SUM = 2, - FUSION_CONV_SUM_RELU = 3, - FUSION_MAX = FUSION_CONV_SUM_RELU + 1, - }; - - IDEEPConvFusionOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPConvPoolOpBase(operator_def, ws), - fusion_type_(static_cast( - OperatorBase::GetSingleArgument("fusion_type", 0))), - training_mode_( - OperatorBase::GetSingleArgument("training_mode", 0)), - conv_algorithm_( - OperatorBase::GetSingleArgument("conv_algorithm", CONV_ALGORITHM_AUTO)) { - OPERATOR_NEEDS_FEATURE( - pad_l() == pad_r() && pad_t() == pad_b(), - "Uneven padding not supported."); - OPERATOR_NEEDS_FEATURE(group_ == 1, "Group not supported."); - OPERATOR_NEEDS_FEATURE( - fusion_type_ > FUSION_UNKNOWN && fusion_type_ < FUSION_MAX, - "Undefined Conv fusion type.", - fusion_type_); - - // Check kernel only if we are doing conv. The reason is that a - // few other ops, like PadImage, are also using this base class. We really - // need to clean this up. - for (int dim = 0; dim < kernel_.size(); ++dim) { - CAFFE_ENFORCE_GE(pads_[dim], 0); - CAFFE_ENFORCE_GE(pads_[kernel_.size() + dim], 0); - CAFFE_ENFORCE( - kernel_[dim], - "If you are doing convolution, you will need to set " - "explicitly the kernel size."); - } - } - ~IDEEPConvFusionOp() override {} - - bool RunOnDeviceWithOrderNCHW() override { - const auto& X = Input(INPUT_X); - const auto& filter = Input(FILTER); - auto* Y = Output(OUTPUT); - auto Y_dims_conv = CalcOutputDims(X, filter.get_dim(0)); - auto attr = [this]() { - return (fusion_type_ == FUSION_CONV_RELU) - ? iattr::fuse_relu() - : ((fusion_type_ == FUSION_CONV_SUM) - ? iattr::fuse_sum() - : ((fusion_type_ == FUSION_CONV_SUM_RELU) ? iattr::residual() - : iattr())); - }; - auto last_input = [this]() { - return (fusion_type_ == FUSION_CONV_RELU) ? BIAS_OR_INPUT_S : INPUT_S; - }; - - CAFFE_ENFORCE(4 == X.ndims()); - CAFFE_ENFORCE(4 == filter.ndims()); - CAFFE_ENFORCE(filter.get_dim(2) == kernel_h()); - CAFFE_ENFORCE(filter.get_dim(3) == kernel_w()); - CAFFE_ENFORCE( - X.get_dim(1) == filter.get_dim(1) * group_, - "Convolution fusion op: input channels does not match: " - "# of input channels ", - X.get_dim(1), - " is not equal to kernel channels * group:", - filter.get_dim(1), - "*", - group_); - - ideep::algorithm aalgorithm = ideep::algorithm::convolution_direct; - if (conv_algorithm_ == CONV_ALGORITHM_WINOGRAD) { - aalgorithm = ideep::algorithm::convolution_winograd; - } - - bool weights_changed = - (cached_weights_descriptor_ != filter.get_descriptor()); - if (weights_changed && !training_mode_) { - cached_weights_descriptor_ = filter.get_descriptor(); - filter_ = filter; - auto expected_descriptor = - ideep::convolution_forward::expected_weights_descriptor( - filter.get_dims()); - if (filter_.get_descriptor() != expected_descriptor) { - filter_.init( - expected_descriptor); - ideep::reorder::compute(filter, filter_); - } - } - - if (InputSize() > last_input()) { - ideep::convolution_forward::compute( - X, - training_mode_ ? filter : filter_, - Input(BIAS_OR_INPUT_S), - Y_dims_conv, - *Y, - stride_, - dilation_, - pad_tl(), - pad_br(), - group_, - attr(), - aalgorithm); - } else { - ideep::convolution_forward::compute( - X, - training_mode_ ? filter : filter_, - Y_dims_conv, - *Y, - stride_, - dilation_, - pad_tl(), - pad_br(), - group_, - attr(), - aalgorithm); - } - - if (fusion_type_ != FUSION_CONV_RELU) { - CAFFE_ENFORCE( - Y == &(Input(InputSize() - 1)), - "Convolution fusion op: InPlace is enforced for sum fusion."); - } - - return true; - } - - private: - FusionType fusion_type_; - bool training_mode_; - int conv_algorithm_; - ideep::tensor filter_; - ideep::tensor::descriptor cached_weights_descriptor_; - - INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S); - OUTPUT_TAGS(OUTPUT); -}; - -REGISTER_IDEEP_OPERATOR(ConvFusion, IDEEPConvFusionOp); - -const char* kConvFusionDoc = R"DOC( -Note that other parameters, such as the stride and -kernel size, or the pads' sizes in each direction are not necessary for input -because they are provided by the ConvPoolOpBase operator. Various dimension -checks are done implicitly, and the sizes are specified in the Input docs for -this operator. As is expected, the filter is convolved with a subset of the -image and the bias is added; this is done throughout the image data and the -output is computed. As a side note on the implementation layout: -conv_op_impl.h is the templated implementation of the conv_op.h file, which is -why they are separate files. -)DOC"; - -std::function ConvFusionDocGenerator(const char* dim) { - return [=](OpSchema& schema) { - string doc = R"DOC( -The convolution fusion operator consumes an input vector, a {dim}filter blob, -a bias blob and another input vector and computes the output. This operator -gives the chance to fuse the ReLU or element-wise Sum with a convolution -operator. {conv_fusion_doc})DOC"; - c10::ReplaceAll(doc, "{dim}", dim); - c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc); - schema.SetDoc(doc); - schema.Input( - 0, - "X", - "Input data blob from previous layer; has size (N x C x H x W), " - "where N is the batch size, C is the number of channels, " - "and H and W are the height and width. Note that this is for the NCHW " - "usage. On the other hand, the NHWC Op has a different set of " - "dimension constraints. "); - schema.Input( - 1, - "filter", - "The filter blob that will be used in the " - "convolutions; has size (M x C x kH x kW), where C is the number of " - "channels, and kH and kW are the height and width of the kernel."); - schema.Input( - 2, - "bias", - "The 1D bias blob that is added through the " - "convolution; has size (M)."); - schema.Input( - 3, - "S", - "Input data blob for element-wise Sum fusion from previous layer; " - "has the same size of convolution output. Its input index should " - "be 2 if no bias for this convolution, and it MUST be inplace with " - "output Y."); - schema.Output( - 0, - "Y", - "Output data blob that contains the result of the " - "convolution fusion. The output dimensions are functions of the kernel " - "size, stride size, and pad lengths." - ""); - }; -} - -OPERATOR_SCHEMA(ConvFusion) - .NumInputs(2, 4) - .NumOutputs(1) - .TensorInferenceFunction(ConvPoolOpBase::TensorInferenceForConv) - .CostInferenceFunction(OpSchema::CostInferenceFunctionType( - ConvPoolOpBase::CostInferenceForConv)) - .Arg("fusion_type", "Which fusion type is used") - .AllowInplace({{2, 0}, {3, 0}}) - .FillUsing(ConvFusionDocGenerator("")); - -} // namespace caffe2 diff --git a/caffe2/ideep/operators/conv_op.cc b/caffe2/ideep/operators/conv_op.cc index 4ecc334..d83c995 100644 --- a/caffe2/ideep/operators/conv_op.cc +++ b/caffe2/ideep/operators/conv_op.cc @@ -2,116 +2,258 @@ namespace caffe2 { -class IDEEPConvOp final : public IDEEPConvPoolOpBase { +class IDEEPConvOp : public IDEEPConvPoolOpBase { public: USE_IDEEP_DEF_ALIASES(); USE_IDEEP_CONV_POOL_BASE_FUNCTIONS(); IDEEPConvOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPConvPoolOpBase(operator_def, ws), - training_mode_( - OperatorBase::GetSingleArgument("training_mode", 0)), - conv_algorithm_( - OperatorBase::GetSingleArgument("conv_algorithm", CONV_ALGORITHM_AUTO)) { + : IDEEPConvPoolOpBase(operator_def, ws) { + OPERATOR_NEEDS_FEATURE( + order_ == StorageOrder::NCHW, "Unsupported storage order."); OPERATOR_NEEDS_FEATURE( pad_l() == pad_r() && pad_t() == pad_b(), "Uneven padding not supported."); + + fusion_type_ = FUSION_UNKNOWN; + last_input_ = BIAS_OR_INPUT_S; + + training_mode_ = OperatorBase::GetSingleArgument("training_mode", 0); + pk_ = training_mode_ ? iprop::forward_training : iprop::forward_inference; + + algo_ = ialgo::convolution_direct; + auto conv_algorithm = OperatorBase::GetSingleArgument( + "conv_algorithm", CONV_ALGORITHM_AUTO); + if (conv_algorithm == CONV_ALGORITHM_WINOGRAD) { + algo_ = ialgo::convolution_winograd; + } } - ~IDEEPConvOp() override {} + virtual ~IDEEPConvOp() {} bool RunOnDeviceWithOrderNCHW() override { - const auto& X = Input(INPUT); + const auto& X = Input(INPUT_X); const auto& filter = Input(FILTER); auto* Y = Output(OUTPUT); - auto Y_dims = CalcOutputDims(X, filter.get_dim(0)); + auto grouped = filter.is_grouped() ? 1 : 0; + auto Y_dims_conv = CalcOutputDims( + X, + grouped ? (filter.get_dim(0) * filter.get_dim(1)) : filter.get_dim(0)); CAFFE_ENFORCE(4 == X.ndims()); - CAFFE_ENFORCE(4 == filter.ndims()); - CAFFE_ENFORCE(filter.get_dim(2) == kernel_h()); - CAFFE_ENFORCE(filter.get_dim(3) == kernel_w()); + CAFFE_ENFORCE(4 == filter.ndims() || (grouped && (group_ > 1))); + CAFFE_ENFORCE_EQ(filter.get_dim(2 + grouped), kernel_h()); + CAFFE_ENFORCE_EQ(filter.get_dim(3 + grouped), kernel_w()); CAFFE_ENFORCE( - X.get_dim(1) == filter.get_dim(1) * group_, + X.get_dim(1) == filter.get_dim(1 + grouped) * group_, "Convolution op: input channels does not match: # of input channels ", X.get_dim(1), " is not equal to kernel channels * group:", - filter.get_dim(1), + filter.get_dim(1 + grouped), "*", group_); - ideep::algorithm aalgorithm = ideep::algorithm::convolution_direct; - if (conv_algorithm_ == CONV_ALGORITHM_WINOGRAD) { - aalgorithm = ideep::algorithm::convolution_winograd; - } - bool weights_changed = (cached_weights_descriptor_ != filter.get_descriptor()); if (weights_changed && !training_mode_) { - cached_weights_descriptor_ = filter.get_descriptor(); - auto filter_in = filter; + op_key_.clear(); + cached_weights_descriptor_ = filter.dup_descriptor(); + auto filter_in = filter.as_weights(); filter_in.make_group(group_); + auto expected_descriptor = ideep::convolution_forward::expected_weights_descriptor( filter_in.get_dims(), - filter_in.get_data_type(), + idtype::f32, stride_, pad_tl(), pad_br(), dilation_, group_, - aalgorithm); - filter_.init( - expected_descriptor); - ideep::reorder::compute(filter_in, filter_); + algo_, + pk_, + idtype::f32, + X.get_dims()); + if (filter_in.get_descriptor() != expected_descriptor) { + filter_.init(expected_descriptor); + filter_.feed_from(filter_in); + } else { + filter_ = filter_in; + } + } + + if (cached_X_descriptor_ != X.get_descriptor()) { + op_key_.clear(); + cached_X_descriptor_ = X.dup_descriptor(); } - // NB: actually, in the case when `group_ > 1`, IDEEP will create - // an itermediate tensor for each run below. However, this tensor is merely - // a view of of the weights and there is no actual data copy, so I'll let it - // go now. If we encounter performance surprise when convoluting with group - // > 1, this is the first place to check and we need to do the same cache - // trick as above - if (InputSize() > BIAS) { + if (InputSize() > last_input_) { ideep::convolution_forward::compute( + op_key_, X, training_mode_ ? filter : filter_, - Input(BIAS), - Y_dims, + Input(BIAS_OR_INPUT_S), + Y_dims_conv, *Y, stride_, dilation_, pad_tl(), pad_br(), group_, - ideep::descriptor_group::attr_t(), - aalgorithm); + dummy_scale_, + dummy_scale_, + dummy_scale_, + attr_, + algo_, + pk_); } else { ideep::convolution_forward::compute( + op_key_, X, training_mode_ ? filter : filter_, - Y_dims, + Y_dims_conv, *Y, stride_, dilation_, pad_tl(), pad_br(), group_, - ideep::descriptor_group::attr_t(), - aalgorithm); + dummy_scale_, + dummy_scale_, + dummy_scale_, + attr_, + algo_, + pk_); + } + + if (fusion_type_ == FUSION_CONV_SUM + && fusion_type_ == FUSION_CONV_SUM_RELU) { + CAFFE_ENFORCE_EQ(Y, &(Input(InputSize() - 1)), + "Convolution fusion op: InPlace is enforced for sum fusion."); } return true; } - private: - INPUT_TAGS(INPUT, FILTER, BIAS); + protected: + iprop pk_; + ialgo algo_; + iattr attr_; + ikey op_key_; + int last_input_; + bool training_mode_; + FusionType fusion_type_; + itensor filter_; + iscale dummy_scale_; + itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_; + + INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S); OUTPUT_TAGS(OUTPUT); +}; - bool training_mode_; - int conv_algorithm_; - ideep::tensor filter_; - ideep::tensor::descriptor cached_weights_descriptor_; +class IDEEPConvFusionOp final : public IDEEPConvOp { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_CONV_POOL_BASE_FUNCTIONS(); + + IDEEPConvFusionOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPConvOp(operator_def, ws) { + CAFFE_ENFORCE(OperatorBase::HasArgument("fusion_type"), + "You should specify the fusion type"); + fusion_type_ = static_cast( + OperatorBase::GetSingleArgument("fusion_type", FUSION_UNKNOWN)); + OPERATOR_NEEDS_FEATURE( + fusion_type_ > FUSION_UNKNOWN && fusion_type_ < FUSION_MAX, + "Undefined Conv fusion type.", + fusion_type_); + + switch (fusion_type_) { + case FUSION_CONV_RELU: + attr_ = iattr::fuse_relu(); + last_input_ = BIAS_OR_INPUT_S; + break; + case FUSION_CONV_SUM: + attr_ = iattr::fuse_sum(); + last_input_ = INPUT_S; + break; + case FUSION_CONV_SUM_RELU: + attr_ = iattr::residual(); + last_input_ = INPUT_S; + break; + default: + CAFFE_THROW("Unsupported conv fusion type!"); + } + } + virtual ~IDEEPConvFusionOp() {} }; +const char* kConvFusionDoc = R"DOC( +Note that other parameters, such as the stride and +kernel size, or the pads' sizes in each direction are not necessary for input +because they are provided by the ConvPoolOpBase operator. Various dimension +checks are done implicitly, and the sizes are specified in the Input docs for +this operator. As is expected, the filter is convolved with a subset of the +image and the bias is added; this is done throughout the image data and the +output is computed. As a side note on the implementation layout: +conv_op_impl.h is the templated implementation of the conv_op.h file, which is +why they are separate files. +)DOC"; + +std::function ConvFusionDocGenerator(const char* dim) { + return [=](OpSchema& schema) { + string doc = R"DOC( +The convolution fusion operator consumes an input vector, a {dim}filter blob, +a bias blob and another input vector and computes the output. This operator +gives the chance to fuse the ReLU or element-wise Sum with a convolution +operator. {conv_fusion_doc})DOC"; + c10::ReplaceAll(doc, "{dim}", dim); + c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc); + schema.SetDoc(doc); + schema.Input( + 0, + "X", + "Input data blob from previous layer; has size (N x C x H x W), " + "where N is the batch size, C is the number of channels, " + "and H and W are the height and width. Note that this is for the NCHW " + "usage. On the other hand, the NHWC Op has a different set of " + "dimension constraints. "); + schema.Input( + 1, + "filter", + "The filter blob that will be used in the " + "convolutions; has size (M x C x kH x kW), where C is the number of " + "channels, and kH and kW are the height and width of the kernel."); + schema.Input( + 2, + "bias", + "The 1D bias blob that is added through the " + "convolution; has size (M)."); + schema.Input( + 3, + "S", + "Input data blob for element-wise Sum fusion from previous layer; " + "has the same size of convolution output. Its input index should " + "be 2 if no bias for this convolution, and it MUST be inplace with " + "output Y."); + schema.Output( + 0, + "Y", + "Output data blob that contains the result of the " + "convolution fusion. The output dimensions are functions of the kernel " + "size, stride size, and pad lengths." + ""); + }; +} + +OPERATOR_SCHEMA(ConvFusion) + .NumInputs(2, 4) + .NumOutputs(1) + .TensorInferenceFunction(ConvPoolOpBase::TensorInferenceForConv) + .CostInferenceFunction(OpSchema::CostInferenceFunctionType( + ConvPoolOpBase::CostInferenceForConv)) + .Arg("fusion_type", "Which fusion type is used") + .AllowInplace({{2, 0}, {3, 0}}) + .FillUsing(ConvFusionDocGenerator("")); + class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase { public: USE_IDEEP_DEF_ALIASES(); @@ -131,7 +273,7 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase { "In order to backward propagate weights correctly, " "please set training_mode=1"); } - ~IDEEPConvGradientOp() override {} + virtual ~IDEEPConvGradientOp() {} bool RunOnDeviceWithOrderNCHW() override { const auto& X = Input(INPUT); @@ -190,6 +332,7 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase { }; REGISTER_IDEEP_OPERATOR(Conv, IDEEPConvOp); +REGISTER_IDEEP_OPERATOR(ConvFusion, IDEEPConvFusionOp); REGISTER_IDEEP_OPERATOR(ConvGradient, IDEEPConvGradientOp); } // namespace caffe2 diff --git a/caffe2/ideep/operators/conv_pool_base_op.h b/caffe2/ideep/operators/conv_pool_base_op.h index a576461..e4170a3 100644 --- a/caffe2/ideep/operators/conv_pool_base_op.h +++ b/caffe2/ideep/operators/conv_pool_base_op.h @@ -11,10 +11,7 @@ namespace caffe2 { class IDEEPConvPoolOpBase : public ConvPoolOpBase { public: IDEEPConvPoolOpBase(const OperatorDef& operator_def, Workspace* ws) - : ConvPoolOpBase(operator_def, ws) { - OPERATOR_NEEDS_FEATURE( - order_ == StorageOrder::NCHW, "Unsupported storage order."); - } + : ConvPoolOpBase(operator_def, ws) {} virtual ~IDEEPConvPoolOpBase() {} inline const ideep::tensor& Input(int index) { @@ -35,7 +32,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase { ideep::tensor::dims CalcOutputDims( const ideep::tensor& input, int output_channel) { - CAFFE_ENFORCE(input.get_descriptor().get_size() > 0); + CAFFE_ENFORCE_GT(input.get_size(), 0); ideep::tensor::dims output_dims; const auto input_dims = input.get_dims(); std::vector input_Tdims( @@ -43,7 +40,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase { InferOutputSize( input_Tdims, output_channel, - order_, + StorageOrder::NCHW, //order_, global_pooling_, legacy_pad_, dilation_, diff --git a/caffe2/ideep/operators/conv_transpose_op.cc b/caffe2/ideep/operators/conv_transpose_op.cc index e05ee71..85981c2 100644 --- a/caffe2/ideep/operators/conv_transpose_op.cc +++ b/caffe2/ideep/operators/conv_transpose_op.cc @@ -77,7 +77,7 @@ class IDEEPConvTransposeOp final : public IDEEPConvTransposeUnpoolBase { // we have to do explicit conversion here. filter_in.set_public_format(ideep::format::iohw); filter_.init(expected_descriptor); - ideep::reorder::compute(filter_in, filter_); + filter_.feed_from(filter_in); } // TODO: The code below works around correctness issues with particular input shapes @@ -178,7 +178,7 @@ class IDEEPConvTransposeGradientOp final : public IDEEPConvTransposeUnpoolBase { // we have to do explicit conversion here. filter_in.set_public_format(ideep::format::iohw); filter_.init(expected_descriptor); - ideep::reorder::compute(filter_in, filter_); + filter_.feed_from(filter_in); // TODO: The code below works around correctness issues with particular input shapes // in MKL-DNN v0.17, will be removed with the fixes in MKL-DNN 0.18. diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index 2dc3612..a022cac 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -82,20 +82,29 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator { bool RunOnDevice() override { for (int i = 0; i < InputSize(); ++i) { - if (InputIsType(i) && - Input(i).get_data_type() == itensor::data_type::f32) { + if (InputIsType(i) + && (Input(i).has_scale() + || Input(i).get_data_type() == idtype::f32)) { auto& input = Input(i); if (input_share_[i]) { local_input_blobs_[i]->Reset(); + input_share_[i] = false; } - input_share_[i] = false; auto dtensor = BlobGetMutableTensor(local_input_blobs_[i], CPU); dtensor->Resize(input.get_dims()); - if (input.is_public_format()) { + // If fallback from INT8, the public format of original input is nhwc. + // While the required format is nchw, need to reorder to nchw. + if (input.get_public_format() == iformat::nhwc) { + itensor temp_ten ({input.get_dims(), idtype::f32, iformat::nchw}, + dtensor->template mutable_data()); + temp_ten.feed_from(input); + } else if (!input.need_reorder()) { + CAFFE_ENFORCE(!input.has_scale(), + "Incorrect invocation of get_data_handle"); dtensor->ShareExternalPointer( static_cast(input.get_data_handle())); } else { - input.reorder_to(dtensor->template mutable_data()); + input.to_public(dtensor->template mutable_data()); } } else { VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy."; @@ -143,12 +152,14 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator { itensor::dims dst_dims (src_dims.begin(), src_dims.end()); auto dtensor = dst->template GetMutable(); if (dtensor->get_dims() != dst_dims) { - dtensor->resize(dst_dims, itensor::data_type::f32); + dtensor->resize(dst_dims, idtype::f32); } if (output_inplace_[i]) { - dtensor->reorder_from(dst_dims, itensor::data_type::f32, - const_cast(src.raw_data())); + dtensor->feed_from(dst_dims, idtype::f32, + const_cast(src.raw_data())); } else { + CAFFE_ENFORCE(!dtensor->has_scale(), + "Incorrect invocation of set_data_handle"); dtensor->set_data_handle(const_cast(src.raw_data())); } } else { diff --git a/caffe2/ideep/operators/pool_op.cc b/caffe2/ideep/operators/pool_op.cc index 45abb37..54baf18 100644 --- a/caffe2/ideep/operators/pool_op.cc +++ b/caffe2/ideep/operators/pool_op.cc @@ -8,9 +8,7 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase { USE_IDEEP_CONV_POOL_BASE_FUNCTIONS(); IDEEPPoolOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPConvPoolOpBase(operator_def, ws), - training_mode_( - OperatorBase::GetSingleArgument("training_mode", 1)) { + : IDEEPConvPoolOpBase(operator_def, ws) { CAFFE_ENFORCE( (dilation_h() == 1) && (dilation_w() == 1), "Pooling op does not support dilation right now."); @@ -20,6 +18,10 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase { pad_l() < kernel_w() && pad_r() < kernel_w(), "Pad should be smaller than kernel."); } + + bool training_mode = OperatorBase::GetSingleArgument("training_mode", 1); + pk_ = training_mode ? iprop::forward_training : iprop::forward_inference; + // Figure out the pooling descriptor. if (operator_def.type().substr(0, 7) == "MaxPool") { algo_ = ialgo::pooling_max; @@ -35,18 +37,23 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase { auto& X = Input(INPUT); auto* Y = Output(OUTPUT); auto Y_dims = CalcOutputDims(X, X.get_dim(1)); - mkldnn::prop_kind pk = training_mode_ ? - mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_inference; - ideep::pooling_forward::compute(X, Y_dims, *Y, - stride_, kernel_, pad_tl(), pad_br(), algo_, pk); + if (cached_X_descriptor_ != X.get_descriptor()) { + op_key_.clear(); + cached_X_descriptor_ = X.dup_descriptor(); + } + + ideep::pooling_forward::compute(op_key_, X, Y_dims, *Y, + stride_, kernel_, pad_tl(), pad_br(), algo_, pk_); return true; } private: + iprop pk_; ialgo algo_; - bool training_mode_; + ikey op_key_; + itensor::descriptor cached_X_descriptor_; INPUT_TAGS(INPUT); OUTPUT_TAGS(OUTPUT); diff --git a/caffe2/ideep/operators/utility_ops.cc b/caffe2/ideep/operators/utility_ops.cc index f312c6f..e1dfece 100644 --- a/caffe2/ideep/operators/utility_ops.cc +++ b/caffe2/ideep/operators/utility_ops.cc @@ -19,7 +19,7 @@ class CopyCPUToIDEEPOp final : public IDEEPOperator { Y->Reset(new itensor()); Y->GetMutable()->resize(src_dims, itensor::data_type::f32); } - Y->GetMutable()->reorder_from( + Y->GetMutable()->feed_from( src_dims, itensor::data_type::f32, X.raw_data()); return true; } @@ -61,7 +61,7 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator { } auto* Y = OperatorBase::OutputTensor(0, dims, at::dtype().device(CPU)); - X.reorder_to(Y->template mutable_data()); + X.to_public(Y->template mutable_data()); } else { CAFFE_THROW("Unsupported ideep type: ", X.get_data_type()); } diff --git a/caffe2/ideep/utils/ideep_operator.h b/caffe2/ideep/utils/ideep_operator.h index e21aa56..efc5a3b 100644 --- a/caffe2/ideep/utils/ideep_operator.h +++ b/caffe2/ideep/utils/ideep_operator.h @@ -16,6 +16,8 @@ C10_DECLARE_REGISTRY( C10_REGISTER_CREATOR(IDEEPOperatorRegistry, key, __VA_ARGS__) #define REGISTER_IDEEP_OPERATOR(name, ...) \ C10_REGISTER_CLASS(IDEEPOperatorRegistry, name, __VA_ARGS__) +#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \ + C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) #define REGISTER_IDEEP_OPERATOR_STR(str_name, ...) \ C10_REGISTER_TYPED_CLASS(IDEEPOperatorRegistry, str_name, __VA_ARGS__) #define REGISTER_IDEEP_COMPARE_OPERATOR(Op) \ @@ -27,8 +29,6 @@ C10_DECLARE_REGISTRY( Op##Functor, \ FixedType>>) -#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) // IDEEPOperator is the base scaffolding of the operators that uses IDEEP. It // provides a few operators that are useful to IDEEP specific implementations. @@ -39,8 +39,6 @@ class IDEEPOperator : public OperatorBase { context_(operator_def.device_option()), order_(StringToStorageOrder( OperatorBase::GetSingleArgument("order", "NCHW"))) { - OPERATOR_NEEDS_FEATURE( - order_ == StorageOrder::NCHW, "Unsupported storage order."); } virtual ~IDEEPOperator() {} @@ -119,4 +117,19 @@ class IDEEPOperator : public OperatorBase { : IDEEPOperator(operator_def, ws) {} \ virtual ~name() {} +// Convert zero_point scales to min_max scales +// NOTE: +// The scales in operator is saved in FBGEMM format, +// while FBGEMM scales are the reciprocals of MKL-DNN scales. +// This function is provided to convert scales from FBGEMM to MKL-DNN +inline ideep::scale_t ConvertScales( + const std::vector scales_z) { + ideep::scale_t scales (scales_z); + for (auto it = scales.begin(); it != scales.end(); it++) { + *it = 1.0f / *it; + } + return scales; +} + + } // namespace caffe2 diff --git a/caffe2/opt/optimize_ideep.cc b/caffe2/opt/optimize_ideep.cc index 05bce30..0d6312f 100644 --- a/caffe2/opt/optimize_ideep.cc +++ b/caffe2/opt/optimize_ideep.cc @@ -3,6 +3,7 @@ #include "caffe2/opt/fusion.h" #ifdef CAFFE2_USE_MKLDNN +#include #include "caffe2/ideep/ideep_utils.h" #endif @@ -78,7 +79,7 @@ bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) { } bool shouldFuseConv(const repr::Conv& conv) { - return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false; + return isOnIdeepDevice(conv); } void removeStopGradientForInference(repr::NNModule* nn) { @@ -110,10 +111,6 @@ void removeStopGradientForInference(repr::NNModule* nn) { } void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) { - // Fusion types: - // FUSION_CONV_RELU = 1 - // FUSION_CONV_SUM = 2 - // FUSION_CONV_SUM_RELU = 3 auto conv = repr::nn::get(convNode); auto annotation = conv->getMutableAnnotation(); if (!annotation || !isa(annotation)) { @@ -126,19 +123,18 @@ void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) { } if (op->type() == "ConvFusion") { - CAFFE_ENFORCE(fusion_type == 1, "Invalid nest fusion"); + CAFFE_ENFORCE(fusion_type == FUSION_CONV_RELU, "Invalid nest fusion"); for (auto& arg : *op->mutable_arg()) { if (arg.name() == "fusion_type") { - // Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU - CAFFE_ENFORCE(arg.i() == 2, "Invalid nest fusion"); - arg.set_i(3); + CAFFE_ENFORCE(arg.i() == FUSION_CONV_SUM, "Invalid nest fusion"); + arg.set_i(FUSION_CONV_SUM_RELU); return; } } return; } - CAFFE_ENFORCE(fusion_type < 3, "Invalid fusion type"); + CAFFE_ENFORCE_LT(fusion_type, FUSION_CONV_SUM_RELU, "Invalid fusion type"); op->set_type("ConvFusion"); auto* arg = op->add_arg(); arg->set_name("fusion_type"); @@ -224,7 +220,7 @@ bool fuseConvBNAndAffChHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) continue; \ } \ name##Tensor.resize(name->get_dims(), name->get_data_type()); \ - name##Tensor.reorder_from(*name); \ + name##Tensor.feed_from(*name); \ CAFFE_ENFORCE( \ name##Tensor.is_public_format(), #name " not with public format"); \ name##Data = static_cast(name##Tensor.get_data_handle()); \ @@ -263,8 +259,8 @@ bool fuseConvBNAndAffChHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) } } - filter->reorder_from(filterTensor); - biasConv->reorder_from(biasConvTensor); + filter->feed_from(filterTensor); + biasConv->feed_from(biasConvTensor); nn->dataFlow.replaceNode(convOutput, bnOrAffChOutput); nn->dataFlow.deleteNode(bnOrAffChNode); @@ -282,6 +278,7 @@ void fuseConvBNAndAffChForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { } void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { + CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo"); // Assume the order of nodes from getMutableNodes conforms to // the original topo order of operators auto allNodes = nn->dataFlow.getMutableNodes(); @@ -342,11 +339,16 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { } auto conv = repr::nn::get(convNode); - if (!shouldFuseConv(*conv)) { + if (!isOnIdeepDevice(*conv)) { LOG(WARNING) << "Not a IDEEP operator"; continue; } + if (conv->getGroup() > 1 && !cpuinfo_has_x86_avx512f()) { + LOG(WARNING) << "Not support conv sum fusion with grouped filter"; + continue; + } + auto convOutput = repr::nn::getOutputs(convNode).front(); repr::NNGraph::NodeRef sumInputX = (sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]); @@ -366,8 +368,7 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { auto sumOutput = repr::nn::getOutputs(sumNode).front(); nn->dataFlow.replaceNode(sumOutput, newOutput); - // 2 means FUSION_CONV_SUM - resetConvForFusion(convNode, 2); + resetConvForFusion(convNode, FUSION_CONV_SUM); nn->dataFlow.createEdge(sumInputX, convNode); nn->dataFlow.createEdge(convNode, newOutput); @@ -405,8 +406,8 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) { bool enforce_inplace = false; for (const auto& arg : op.arg()) { - // Only check FUSION_SUM & FUSION_SUM_RELU - if (arg.name() == "fusion_type" && (arg.i() == 2 || arg.i() == 3)) { + if (arg.name() == "fusion_type" + && (arg.i() == FUSION_CONV_SUM || arg.i() == FUSION_CONV_SUM_RELU)) { enforce_inplace = true; break; } diff --git a/caffe2/python/ideep/convfusion_op_test.py b/caffe2/python/ideep/convfusion_op_test.py index 8c40be8..1e1a3ce 100644 --- a/caffe2/python/ideep/convfusion_op_test.py +++ b/caffe2/python/ideep/convfusion_op_test.py @@ -256,6 +256,7 @@ class ConvFusionTest(hu.HypothesisTestCase): workspace.SwitchWorkspace(old_ws_name) + @given(stride=st.integers(1, 3), pad=st.integers(0, 3), kernel=st.integers(3, 5), @@ -410,6 +411,113 @@ class ConvFusionTest(hu.HypothesisTestCase): pad=st.integers(0, 3), kernel=st.integers(3, 5), size=st.integers(8, 20), + input_channels=st.integers(7, 17), + output_channels=st.integers(5, 15), + batch_size=st.integers(1, 3), + use_bias=st.booleans(), + group=st.integers(2, 5), + **mu.gcs) + def test_convolution_grouped_sum_relu_fusion(self, stride, pad, kernel, size, + input_channels, output_channels, + batch_size, use_bias, group, gc, dc): + conv_S0 = core.CreateOperator( + "Conv", + ["SX0", "Sw0", "Sb0"] if use_bias else ["SX0", "Sw0"], + ["S0"], + stride=stride, + pad=pad, + kernel=kernel, + group=group, + device_option=dc[0] + ) + conv = core.CreateOperator( + "Conv", + ["X0", "w0", "b0"] if use_bias else ["X0", "w0"], + ["Y0"], + stride=stride, + pad=pad, + kernel=kernel, + group=group, + device_option=dc[0] + ) + sum = core.CreateOperator( + "Sum", + ["S0", "Y0"], + ["S0"], + device_option=dc[0] + ) + relu = core.CreateOperator( + "Relu", + ["S0"], + ["S0"], + device_option=dc[0] + ) + + SX = np.random.rand( + batch_size, input_channels * group, size, size).astype(np.float32) - 0.5 + Sw = np.random.rand( + output_channels * group, input_channels, kernel, kernel) \ + .astype(np.float32) - 0.5 + Sb = np.random.rand(output_channels * group).astype(np.float32) - 0.5 + X = np.random.rand( + batch_size, input_channels * group, size, size).astype(np.float32) - 0.5 + w = np.random.rand( + output_channels * group, input_channels, kernel, kernel) \ + .astype(np.float32) - 0.5 + b = np.random.rand(output_channels * group).astype(np.float32) - 0.5 + + old_ws_name = workspace.CurrentWorkspace() + workspace.SwitchWorkspace("_device_check_", True) + workspace.FeedBlob('SX0', SX, dc[0]) + workspace.FeedBlob('Sw0', Sw, dc[0]) + workspace.FeedBlob('Sb0', Sb, dc[0]) + workspace.FeedBlob('X0', X, dc[0]) + workspace.FeedBlob('w0', w, dc[0]) + workspace.FeedBlob('b0', b, dc[0]) + workspace.RunOperatorOnce(conv_S0) + workspace.RunOperatorOnce(conv) + workspace.RunOperatorOnce(sum) + workspace.RunOperatorOnce(relu) + S0 = workspace.FetchBlob('S0') + + workspace.ResetWorkspace() + old_net = caffe2_pb2.NetDef() + conv_S0_old = caffe2_pb2.OperatorDef() + conv_S0_old.CopyFrom(conv_S0) + conv_S0_old.device_option.CopyFrom(dc[1]) + conv_old = caffe2_pb2.OperatorDef() + conv_old.CopyFrom(conv) + conv_old.device_option.CopyFrom(dc[1]) + sum_old = caffe2_pb2.OperatorDef() + sum_old.CopyFrom(sum) + sum_old.device_option.CopyFrom(dc[1]) + relu_old = caffe2_pb2.OperatorDef() + relu_old.CopyFrom(relu) + relu_old.device_option.CopyFrom(dc[1]) + old_net.op.extend([conv_S0_old, conv_old, sum_old, relu_old]) + workspace.FeedBlob('SX0', SX, dc[1]) + workspace.FeedBlob('Sw0', Sw, dc[1]) + workspace.FeedBlob('Sb0', Sb, dc[1]) + workspace.FeedBlob('X0', X, dc[1]) + workspace.FeedBlob('w0', w, dc[1]) + workspace.FeedBlob('b0', b, dc[1]) + net = core.Net("net") + net.Proto().CopyFrom(old_net) + optimizeForIDEEP(net) + workspace.RunNetOnce(net.Proto()) + S2 = workspace.FetchBlob('S0') + if not np.allclose(S0, S2, atol=0.01, rtol=0.01): + print(S2.flatten()) + print(S0.flatten()) + print(np.max(np.abs(S2 - S0))) + self.assertTrue(False) + + workspace.SwitchWorkspace(old_ws_name) + + @given(stride=st.integers(1, 3), + pad=st.integers(0, 3), + kernel=st.integers(3, 5), + size=st.integers(8, 20), input_channels=st.integers(1, 16), output_channels=st.integers(1, 16), batch_size=st.integers(1, 3), diff --git a/caffe2/python/pybind_state_ideep.cc b/caffe2/python/pybind_state_ideep.cc index ff4971e..4460e0d 100644 --- a/caffe2/python/pybind_state_ideep.cc +++ b/caffe2/python/pybind_state_ideep.cc @@ -59,13 +59,13 @@ public: (atensor.get_nelems() == 0 || atensor.get_data_handle() != nullptr), "Trying to fetch uninitialized tensor"); - const int numpy_type = CaffeToNumpyType(type_transform(atensor)); + // NOTE: Only support float so far. + const int numpy_type = NPY_FLOAT; CAFFE_ENFORCE( numpy_type != -1, "Unsupported ideep memory data type? This usually should not happen " "since ideep memory usually only do float and double."); itensor::dims dims = atensor.get_public_format_dims(); - std::vector npy_dims(dims.begin(), dims.end()); result.copied = force_copy || atensor.need_reorder(); @@ -86,7 +86,7 @@ public: } if (result.copied) { - atensor.reorder_to(outPtr); + atensor.to_public(outPtr); } return result; @@ -144,7 +144,7 @@ public: if (tensor->get_dims() != adims || type != tensor->get_data_type()) { tensor->resize(adims, type); } - tensor->reorder_from(adims, type, + tensor->feed_from(adims, type, static_cast(PyArray_DATA(array))); } #else -- 2.7.4