Upgrade mkldnn-bridge for dnnlowp support (#16308)
authorGu, Jinghui <jinghui.gu@intel.com>
Wed, 3 Apr 2019 17:29:19 +0000 (10:29 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 19:47:17 +0000 (12:47 -0700)
Summary:
The mkldnn-bridge is upgraded in this PR to support DNNLOWP operators.
Meanwhile, APIs have been updated in caffe2 to use latest version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16308

Differential Revision: D14697018

Pulled By: yinghai

fbshipit-source-id: ca952589098accb08295fd5aa92924c61e74d69c

12 files changed:
caffe2/ideep/ideep_utils.h
caffe2/ideep/operators/conv_fusion_op.cc [deleted file]
caffe2/ideep/operators/conv_op.cc
caffe2/ideep/operators/conv_pool_base_op.h
caffe2/ideep/operators/conv_transpose_op.cc
caffe2/ideep/operators/operator_fallback_ideep.h
caffe2/ideep/operators/pool_op.cc
caffe2/ideep/operators/utility_ops.cc
caffe2/ideep/utils/ideep_operator.h
caffe2/opt/optimize_ideep.cc
caffe2/python/ideep/convfusion_op_test.py
caffe2/python/pybind_state_ideep.cc

index db4195c..11adf8c 100644 (file)
@@ -12,16 +12,41 @@ namespace caffe2 {
 enum ConvAlgorithm {
   CONV_ALGORITHM_AUTO = 0,
   CONV_ALGORITHM_WINOGRAD = 1,
-  CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1
+  CONV_ALGORITHM_MAX
+};
+
+enum FusionType {
+  FUSION_UNKNOWN = 0,
+  FUSION_CONV_RELU = 1,
+  FUSION_CONV_SUM = 2,
+  FUSION_CONV_SUM_RELU = 3,
+  FUSION_MAX
 };
 
 #define USE_IDEEP_DEF_ALIASES()                                                \
+  /* the hash key of cahced operator generated by iDEEP  */                    \
+  using ikey = ideep::key_t;                                                   \
+  /* the tensor type created/handled by iDEEP  */                              \
   using itensor = ideep::tensor;                                               \
+  /* the date layout of iDEEP tensor */                                        \
   using iformat = ideep::format;                                               \
+  /* the scales for iDEEP tensor with different data type */                   \
+  using iscale = ideep::scale_t;                                               \
+  /* the detial algorithm for iDEEP operators, e.g. winograd */                \
   using ialgo = ideep::algorithm;                                              \
+  /* the kind of propagation for iDEEP operators, e.g. forward, training */    \
   using iprop = ideep::prop_kind;                                              \
+  /* the kind of low precision operators, e.g. signed/unsigned activation */   \
+  using ilowp_kind = ideep::lowp_kind;                                         \
+  /* the kind of padding, usually set as zero padding */                       \
   using ipadding = ideep::padding_kind;                                        \
+  /* the data type of iDEEP tensor, e.g. f32, u8, s8 */                        \
+  using idtype = ideep::tensor::data_type;                                     \
+  /* the descriptor of iDEEP tensor */                                         \
+  using itdesc = ideep::tensor::descriptor;                                    \
+  /* the attribute for operator to describe the details of inputs&fusion */    \
   using iattr = ideep::descriptor_group::attr_t;                               \
+  /* the detail flags for batch normalization */                               \
   using ibn_flag = ideep::batch_normalization_flag;
 
 } // namespace caffe2
diff --git a/caffe2/ideep/operators/conv_fusion_op.cc b/caffe2/ideep/operators/conv_fusion_op.cc
deleted file mode 100644 (file)
index ff23991..0000000
+++ /dev/null
@@ -1,219 +0,0 @@
-#include <caffe2/ideep/operators/conv_pool_base_op.h>
-
-namespace caffe2 {
-
-class IDEEPConvFusionOp final : public IDEEPConvPoolOpBase {
- public:
-  USE_IDEEP_DEF_ALIASES();
-  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
-
-  enum FusionType {
-    FUSION_UNKNOWN = 0,
-    FUSION_CONV_RELU = 1,
-    FUSION_CONV_SUM = 2,
-    FUSION_CONV_SUM_RELU = 3,
-    FUSION_MAX = FUSION_CONV_SUM_RELU + 1,
-  };
-
-  IDEEPConvFusionOp(const OperatorDef& operator_def, Workspace* ws)
-      : IDEEPConvPoolOpBase(operator_def, ws),
-        fusion_type_(static_cast<FusionType>(
-            OperatorBase::GetSingleArgument<int>("fusion_type", 0))),
-        training_mode_(
-            OperatorBase::GetSingleArgument<int>("training_mode", 0)),
-        conv_algorithm_(
-            OperatorBase::GetSingleArgument<int>("conv_algorithm", CONV_ALGORITHM_AUTO)) {
-    OPERATOR_NEEDS_FEATURE(
-        pad_l() == pad_r() && pad_t() == pad_b(),
-        "Uneven padding not supported.");
-    OPERATOR_NEEDS_FEATURE(group_ == 1, "Group not supported.");
-    OPERATOR_NEEDS_FEATURE(
-        fusion_type_ > FUSION_UNKNOWN && fusion_type_ < FUSION_MAX,
-        "Undefined Conv fusion type.",
-        fusion_type_);
-
-    // Check kernel only if we are doing conv. The reason is that a
-    // few other ops, like PadImage, are also using this base class. We really
-    // need to clean this up.
-    for (int dim = 0; dim < kernel_.size(); ++dim) {
-      CAFFE_ENFORCE_GE(pads_[dim], 0);
-      CAFFE_ENFORCE_GE(pads_[kernel_.size() + dim], 0);
-      CAFFE_ENFORCE(
-          kernel_[dim],
-          "If you are doing convolution, you will need to set "
-          "explicitly the kernel size.");
-    }
-  }
-  ~IDEEPConvFusionOp() override {}
-
-  bool RunOnDeviceWithOrderNCHW() override {
-    const auto& X = Input(INPUT_X);
-    const auto& filter = Input(FILTER);
-    auto* Y = Output(OUTPUT);
-    auto Y_dims_conv = CalcOutputDims(X, filter.get_dim(0));
-    auto attr = [this]() {
-      return (fusion_type_ == FUSION_CONV_RELU)
-          ? iattr::fuse_relu()
-          : ((fusion_type_ == FUSION_CONV_SUM)
-                 ? iattr::fuse_sum()
-                 : ((fusion_type_ == FUSION_CONV_SUM_RELU) ? iattr::residual()
-                                                           : iattr()));
-    };
-    auto last_input = [this]() {
-      return (fusion_type_ == FUSION_CONV_RELU) ? BIAS_OR_INPUT_S : INPUT_S;
-    };
-
-    CAFFE_ENFORCE(4 == X.ndims());
-    CAFFE_ENFORCE(4 == filter.ndims());
-    CAFFE_ENFORCE(filter.get_dim(2) == kernel_h());
-    CAFFE_ENFORCE(filter.get_dim(3) == kernel_w());
-    CAFFE_ENFORCE(
-        X.get_dim(1) == filter.get_dim(1) * group_,
-        "Convolution fusion op: input channels does not match: "
-        "# of input channels ",
-        X.get_dim(1),
-        " is not equal to kernel channels * group:",
-        filter.get_dim(1),
-        "*",
-        group_);
-
-    ideep::algorithm aalgorithm = ideep::algorithm::convolution_direct;
-    if (conv_algorithm_ == CONV_ALGORITHM_WINOGRAD) {
-      aalgorithm = ideep::algorithm::convolution_winograd;
-    }
-
-    bool weights_changed =
-        (cached_weights_descriptor_ != filter.get_descriptor());
-    if (weights_changed && !training_mode_) {
-      cached_weights_descriptor_ = filter.get_descriptor();
-      filter_ = filter;
-      auto expected_descriptor =
-          ideep::convolution_forward::expected_weights_descriptor(
-              filter.get_dims());
-      if (filter_.get_descriptor() != expected_descriptor) {
-        filter_.init<ideep::utils::allocator, ideep::convolution_forward>(
-            expected_descriptor);
-        ideep::reorder::compute(filter, filter_);
-      }
-    }
-
-    if (InputSize() > last_input()) {
-      ideep::convolution_forward::compute(
-          X,
-          training_mode_ ? filter : filter_,
-          Input(BIAS_OR_INPUT_S),
-          Y_dims_conv,
-          *Y,
-          stride_,
-          dilation_,
-          pad_tl(),
-          pad_br(),
-          group_,
-          attr(),
-          aalgorithm);
-    } else {
-      ideep::convolution_forward::compute(
-          X,
-          training_mode_ ? filter : filter_,
-          Y_dims_conv,
-          *Y,
-          stride_,
-          dilation_,
-          pad_tl(),
-          pad_br(),
-          group_,
-          attr(),
-          aalgorithm);
-    }
-
-    if (fusion_type_ != FUSION_CONV_RELU) {
-      CAFFE_ENFORCE(
-          Y == &(Input(InputSize() - 1)),
-          "Convolution fusion op: InPlace is enforced for sum fusion.");
-    }
-
-    return true;
-  }
-
- private:
-  FusionType fusion_type_;
-  bool training_mode_;
-  int conv_algorithm_;
-  ideep::tensor filter_;
-  ideep::tensor::descriptor cached_weights_descriptor_;
-
-  INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
-  OUTPUT_TAGS(OUTPUT);
-};
-
-REGISTER_IDEEP_OPERATOR(ConvFusion, IDEEPConvFusionOp);
-
-const char* kConvFusionDoc = R"DOC(
-Note that other parameters, such as the stride and
-kernel size, or the pads' sizes in each direction are not necessary for input
-because they are provided by the ConvPoolOpBase operator. Various dimension
-checks are done implicitly, and the sizes are specified in the Input docs for
-this operator. As is expected, the filter is convolved with a subset of the
-image and the bias is added; this is done throughout the image data and the
-output is computed. As a side note on the implementation layout:
-conv_op_impl.h is the templated implementation of the conv_op.h file, which is
-why they are separate files.
-)DOC";
-
-std::function<void(OpSchema&)> ConvFusionDocGenerator(const char* dim) {
-  return [=](OpSchema& schema) {
-    string doc = R"DOC(
-The convolution fusion operator consumes an input vector, a {dim}filter blob,
-a bias blob and another input vector and computes the output. This operator
-gives the chance to fuse the ReLU or element-wise Sum with a convolution
-operator. {conv_fusion_doc})DOC";
-    c10::ReplaceAll(doc, "{dim}", dim);
-    c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc);
-    schema.SetDoc(doc);
-    schema.Input(
-        0,
-        "X",
-        "Input data blob from previous layer; has size (N x C x H x W), "
-        "where N is the batch size, C is the number of channels, "
-        "and H and W are the height and width. Note that this is for the NCHW "
-        "usage. On the other hand, the NHWC Op has a different set of "
-        "dimension constraints. ");
-    schema.Input(
-        1,
-        "filter",
-        "The filter blob that will be used in the "
-        "convolutions; has size (M x C x kH x kW), where C is the number of "
-        "channels, and kH and kW are the height and width of the kernel.");
-    schema.Input(
-        2,
-        "bias",
-        "The 1D bias blob that is added through the "
-        "convolution; has size (M).");
-    schema.Input(
-        3,
-        "S",
-        "Input data blob for element-wise Sum fusion from previous layer; "
-        "has the same size of convolution output. Its input index should "
-        "be 2 if no bias for this convolution, and it MUST be inplace with "
-        "output Y.");
-    schema.Output(
-        0,
-        "Y",
-        "Output data blob that contains the result of the "
-        "convolution fusion. The output dimensions are functions of the kernel "
-        "size, stride size, and pad lengths."
-        "");
-  };
-}
-
-OPERATOR_SCHEMA(ConvFusion)
-    .NumInputs(2, 4)
-    .NumOutputs(1)
-    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
-    .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
-        ConvPoolOpBase<CPUContext>::CostInferenceForConv))
-    .Arg("fusion_type", "Which fusion type is used")
-    .AllowInplace({{2, 0}, {3, 0}})
-    .FillUsing(ConvFusionDocGenerator(""));
-
-} // namespace caffe2
index 4ecc334..d83c995 100644 (file)
 
 namespace caffe2 {
 
-class IDEEPConvOp final : public IDEEPConvPoolOpBase {
+class IDEEPConvOp : public IDEEPConvPoolOpBase {
  public:
   USE_IDEEP_DEF_ALIASES();
   USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
 
   IDEEPConvOp(const OperatorDef& operator_def, Workspace* ws)
-      : IDEEPConvPoolOpBase(operator_def, ws),
-        training_mode_(
-            OperatorBase::GetSingleArgument<int>("training_mode", 0)),
-        conv_algorithm_(
-            OperatorBase::GetSingleArgument<int>("conv_algorithm", CONV_ALGORITHM_AUTO)) {
+      : IDEEPConvPoolOpBase(operator_def, ws) {
+    OPERATOR_NEEDS_FEATURE(
+        order_ == StorageOrder::NCHW, "Unsupported storage order.");
     OPERATOR_NEEDS_FEATURE(
         pad_l() == pad_r() && pad_t() == pad_b(),
         "Uneven padding not supported.");
+
+    fusion_type_ = FUSION_UNKNOWN;
+    last_input_ = BIAS_OR_INPUT_S;
+
+    training_mode_ = OperatorBase::GetSingleArgument<int>("training_mode", 0);
+    pk_ = training_mode_ ? iprop::forward_training : iprop::forward_inference;
+
+    algo_ = ialgo::convolution_direct;
+    auto conv_algorithm = OperatorBase::GetSingleArgument<int>(
+        "conv_algorithm", CONV_ALGORITHM_AUTO);
+    if (conv_algorithm == CONV_ALGORITHM_WINOGRAD) {
+      algo_ = ialgo::convolution_winograd;
+    }
   }
-  ~IDEEPConvOp() override {}
+  virtual ~IDEEPConvOp() {}
 
   bool RunOnDeviceWithOrderNCHW() override {
-    const auto& X = Input(INPUT);
+    const auto& X = Input(INPUT_X);
     const auto& filter = Input(FILTER);
     auto* Y = Output(OUTPUT);
-    auto Y_dims = CalcOutputDims(X, filter.get_dim(0));
+    auto grouped = filter.is_grouped() ? 1 : 0;
+    auto Y_dims_conv = CalcOutputDims(
+        X,
+        grouped ? (filter.get_dim(0) * filter.get_dim(1)) : filter.get_dim(0));
 
     CAFFE_ENFORCE(4 == X.ndims());
-    CAFFE_ENFORCE(4 == filter.ndims());
-    CAFFE_ENFORCE(filter.get_dim(2) == kernel_h());
-    CAFFE_ENFORCE(filter.get_dim(3) == kernel_w());
+    CAFFE_ENFORCE(4 == filter.ndims() || (grouped && (group_ > 1)));
+    CAFFE_ENFORCE_EQ(filter.get_dim(2 + grouped), kernel_h());
+    CAFFE_ENFORCE_EQ(filter.get_dim(3 + grouped), kernel_w());
     CAFFE_ENFORCE(
-        X.get_dim(1) == filter.get_dim(1) * group_,
+        X.get_dim(1) == filter.get_dim(1 + grouped) * group_,
         "Convolution op: input channels does not match: # of input channels ",
         X.get_dim(1),
         " is not equal to kernel channels * group:",
-        filter.get_dim(1),
+        filter.get_dim(1 + grouped),
         "*",
         group_);
 
-    ideep::algorithm aalgorithm = ideep::algorithm::convolution_direct;
-    if (conv_algorithm_ == CONV_ALGORITHM_WINOGRAD) {
-      aalgorithm = ideep::algorithm::convolution_winograd;
-    }
-
     bool weights_changed =
         (cached_weights_descriptor_ != filter.get_descriptor());
     if (weights_changed && !training_mode_) {
-      cached_weights_descriptor_ = filter.get_descriptor();
-      auto filter_in = filter;
+      op_key_.clear();
+      cached_weights_descriptor_ = filter.dup_descriptor();
+      auto filter_in = filter.as_weights();
       filter_in.make_group(group_);
+
       auto expected_descriptor =
           ideep::convolution_forward::expected_weights_descriptor(
               filter_in.get_dims(),
-              filter_in.get_data_type(),
+              idtype::f32,
               stride_,
               pad_tl(),
               pad_br(),
               dilation_,
               group_,
-              aalgorithm);
-      filter_.init<ideep::utils::allocator, ideep::convolution_forward>(
-          expected_descriptor);
-      ideep::reorder::compute(filter_in, filter_);
+              algo_,
+              pk_,
+              idtype::f32,
+              X.get_dims());
+      if (filter_in.get_descriptor() != expected_descriptor) {
+        filter_.init(expected_descriptor);
+        filter_.feed_from(filter_in);
+      } else {
+        filter_ = filter_in;
+      }
+    }
+
+    if (cached_X_descriptor_ != X.get_descriptor()) {
+      op_key_.clear();
+      cached_X_descriptor_ = X.dup_descriptor();
     }
 
-    // NB: actually, in the case when `group_ > 1`, IDEEP will create
-    // an itermediate tensor for each run below. However, this tensor is merely
-    // a view of of the weights and there is no actual data copy, so I'll let it
-    // go now. If we encounter performance surprise when convoluting with group
-    // > 1, this is the first place to check and we need to do the same cache
-    // trick as above
-    if (InputSize() > BIAS) {
+    if (InputSize() > last_input_) {
       ideep::convolution_forward::compute(
+          op_key_,
           X,
           training_mode_ ? filter : filter_,
-          Input(BIAS),
-          Y_dims,
+          Input(BIAS_OR_INPUT_S),
+          Y_dims_conv,
           *Y,
           stride_,
           dilation_,
           pad_tl(),
           pad_br(),
           group_,
-          ideep::descriptor_group::attr_t(),
-          aalgorithm);
+          dummy_scale_,
+          dummy_scale_,
+          dummy_scale_,
+          attr_,
+          algo_,
+          pk_);
     } else {
       ideep::convolution_forward::compute(
+          op_key_,
           X,
           training_mode_ ? filter : filter_,
-          Y_dims,
+          Y_dims_conv,
           *Y,
           stride_,
           dilation_,
           pad_tl(),
           pad_br(),
           group_,
-          ideep::descriptor_group::attr_t(),
-          aalgorithm);
+          dummy_scale_,
+          dummy_scale_,
+          dummy_scale_,
+          attr_,
+          algo_,
+          pk_);
+    }
+
+    if (fusion_type_ == FUSION_CONV_SUM
+        && fusion_type_ == FUSION_CONV_SUM_RELU) {
+      CAFFE_ENFORCE_EQ(Y,  &(Input(InputSize() - 1)),
+          "Convolution fusion op: InPlace is enforced for sum fusion.");
     }
 
     return true;
   }
 
- private:
-  INPUT_TAGS(INPUT, FILTER, BIAS);
+ protected:
+  iprop pk_;
+  ialgo algo_;
+  iattr attr_;
+  ikey op_key_;
+  int last_input_;
+  bool training_mode_;
+  FusionType fusion_type_;
+  itensor filter_;
+  iscale dummy_scale_;
+  itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
+
+  INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
   OUTPUT_TAGS(OUTPUT);
+};
 
-  bool training_mode_;
-  int conv_algorithm_;
-  ideep::tensor filter_;
-  ideep::tensor::descriptor cached_weights_descriptor_;
+class IDEEPConvFusionOp final : public IDEEPConvOp {
+ public:
+  USE_IDEEP_DEF_ALIASES();
+  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
+
+  IDEEPConvFusionOp(const OperatorDef& operator_def, Workspace* ws)
+      : IDEEPConvOp(operator_def, ws) {
+    CAFFE_ENFORCE(OperatorBase::HasArgument("fusion_type"),
+          "You should specify the fusion type");
+    fusion_type_ = static_cast<FusionType>(
+        OperatorBase::GetSingleArgument<int>("fusion_type", FUSION_UNKNOWN));
+    OPERATOR_NEEDS_FEATURE(
+        fusion_type_ > FUSION_UNKNOWN && fusion_type_ < FUSION_MAX,
+        "Undefined Conv fusion type.",
+        fusion_type_);
+
+    switch (fusion_type_) {
+      case FUSION_CONV_RELU:
+        attr_ = iattr::fuse_relu();
+        last_input_ = BIAS_OR_INPUT_S;
+        break;
+      case FUSION_CONV_SUM:
+        attr_ = iattr::fuse_sum();
+        last_input_ = INPUT_S;
+        break;
+      case FUSION_CONV_SUM_RELU:
+        attr_ = iattr::residual();
+        last_input_ = INPUT_S;
+        break;
+      default:
+        CAFFE_THROW("Unsupported conv fusion type!");
+    }
+  }
+  virtual ~IDEEPConvFusionOp() {}
 };
 
+const char* kConvFusionDoc = R"DOC(
+Note that other parameters, such as the stride and
+kernel size, or the pads' sizes in each direction are not necessary for input
+because they are provided by the ConvPoolOpBase operator. Various dimension
+checks are done implicitly, and the sizes are specified in the Input docs for
+this operator. As is expected, the filter is convolved with a subset of the
+image and the bias is added; this is done throughout the image data and the
+output is computed. As a side note on the implementation layout:
+conv_op_impl.h is the templated implementation of the conv_op.h file, which is
+why they are separate files.
+)DOC";
+
+std::function<void(OpSchema&)> ConvFusionDocGenerator(const char* dim) {
+  return [=](OpSchema& schema) {
+    string doc = R"DOC(
+The convolution fusion operator consumes an input vector, a {dim}filter blob,
+a bias blob and another input vector and computes the output. This operator
+gives the chance to fuse the ReLU or element-wise Sum with a convolution
+operator. {conv_fusion_doc})DOC";
+    c10::ReplaceAll(doc, "{dim}", dim);
+    c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc);
+    schema.SetDoc(doc);
+    schema.Input(
+        0,
+        "X",
+        "Input data blob from previous layer; has size (N x C x H x W), "
+        "where N is the batch size, C is the number of channels, "
+        "and H and W are the height and width. Note that this is for the NCHW "
+        "usage. On the other hand, the NHWC Op has a different set of "
+        "dimension constraints. ");
+    schema.Input(
+        1,
+        "filter",
+        "The filter blob that will be used in the "
+        "convolutions; has size (M x C x kH x kW), where C is the number of "
+        "channels, and kH and kW are the height and width of the kernel.");
+    schema.Input(
+        2,
+        "bias",
+        "The 1D bias blob that is added through the "
+        "convolution; has size (M).");
+    schema.Input(
+        3,
+        "S",
+        "Input data blob for element-wise Sum fusion from previous layer; "
+        "has the same size of convolution output. Its input index should "
+        "be 2 if no bias for this convolution, and it MUST be inplace with "
+        "output Y.");
+    schema.Output(
+        0,
+        "Y",
+        "Output data blob that contains the result of the "
+        "convolution fusion. The output dimensions are functions of the kernel "
+        "size, stride size, and pad lengths."
+        "");
+  };
+}
+
+OPERATOR_SCHEMA(ConvFusion)
+    .NumInputs(2, 4)
+    .NumOutputs(1)
+    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
+    .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
+        ConvPoolOpBase<CPUContext>::CostInferenceForConv))
+    .Arg("fusion_type", "Which fusion type is used")
+    .AllowInplace({{2, 0}, {3, 0}})
+    .FillUsing(ConvFusionDocGenerator(""));
+
 class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
  public:
   USE_IDEEP_DEF_ALIASES();
@@ -131,7 +273,7 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
         "In order to backward propagate weights correctly, "
         "please set training_mode=1");
   }
-  ~IDEEPConvGradientOp() override {}
+  virtual ~IDEEPConvGradientOp() {}
 
   bool RunOnDeviceWithOrderNCHW() override {
     const auto& X = Input(INPUT);
@@ -190,6 +332,7 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
 };
 
 REGISTER_IDEEP_OPERATOR(Conv, IDEEPConvOp);
+REGISTER_IDEEP_OPERATOR(ConvFusion, IDEEPConvFusionOp);
 REGISTER_IDEEP_OPERATOR(ConvGradient, IDEEPConvGradientOp);
 
 } // namespace caffe2
index a576461..e4170a3 100644 (file)
@@ -11,10 +11,7 @@ namespace caffe2 {
 class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
  public:
   IDEEPConvPoolOpBase(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<IDEEPContext>(operator_def, ws) {
-    OPERATOR_NEEDS_FEATURE(
-        order_ == StorageOrder::NCHW, "Unsupported storage order.");
-  }
+     : ConvPoolOpBase<IDEEPContext>(operator_def, ws) {}
   virtual ~IDEEPConvPoolOpBase() {}
 
   inline const ideep::tensor& Input(int index) {
@@ -35,7 +32,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
   ideep::tensor::dims CalcOutputDims(
       const ideep::tensor& input,
       int output_channel) {
-    CAFFE_ENFORCE(input.get_descriptor().get_size() > 0);
+    CAFFE_ENFORCE_GT(input.get_size(), 0);
     ideep::tensor::dims output_dims;
     const auto input_dims = input.get_dims();
     std::vector<std::int64_t> input_Tdims(
@@ -43,7 +40,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
     InferOutputSize(
         input_Tdims,
         output_channel,
-        order_,
+        StorageOrder::NCHW, //order_,
         global_pooling_,
         legacy_pad_,
         dilation_,
index e05ee71..85981c2 100644 (file)
@@ -77,7 +77,7 @@ class IDEEPConvTransposeOp final : public IDEEPConvTransposeUnpoolBase {
         // we have to do explicit conversion here.
         filter_in.set_public_format(ideep::format::iohw);
         filter_.init(expected_descriptor);
-        ideep::reorder::compute(filter_in, filter_);
+        filter_.feed_from(filter_in);
       }
 
       // TODO: The code below works around correctness issues with particular input shapes
@@ -178,7 +178,7 @@ class IDEEPConvTransposeGradientOp final : public IDEEPConvTransposeUnpoolBase {
       // we have to do explicit conversion here.
       filter_in.set_public_format(ideep::format::iohw);
       filter_.init(expected_descriptor);
-      ideep::reorder::compute(filter_in, filter_);
+      filter_.feed_from(filter_in);
 
       // TODO: The code below works around correctness issues with particular input shapes
       // in MKL-DNN v0.17, will be removed with the fixes in MKL-DNN 0.18.
index 2dc3612..a022cac 100644 (file)
@@ -82,20 +82,29 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
 
   bool RunOnDevice() override {
     for (int i = 0; i < InputSize(); ++i) {
-      if (InputIsType<itensor>(i) &&
-          Input(i).get_data_type() == itensor::data_type::f32) {
+      if (InputIsType<itensor>(i)
+          && (Input(i).has_scale()
+            || Input(i).get_data_type() == idtype::f32)) {
         auto& input = Input(i);
         if (input_share_[i]) {
           local_input_blobs_[i]->Reset();
+          input_share_[i] = false;
         }
-        input_share_[i] = false;
         auto dtensor = BlobGetMutableTensor(local_input_blobs_[i], CPU);
         dtensor->Resize(input.get_dims());
-        if (input.is_public_format()) {
+        // If fallback from INT8, the public format of original input is nhwc.
+        // While the required format is nchw, need to reorder to nchw.
+        if (input.get_public_format() == iformat::nhwc) {
+          itensor temp_ten ({input.get_dims(), idtype::f32, iformat::nchw},
+              dtensor->template mutable_data<float>());
+          temp_ten.feed_from(input);
+        } else if (!input.need_reorder()) {
+          CAFFE_ENFORCE(!input.has_scale(),
+              "Incorrect invocation of get_data_handle");
           dtensor->ShareExternalPointer(
               static_cast<float*>(input.get_data_handle()));
         } else {
-          input.reorder_to(dtensor->template mutable_data<float>());
+          input.to_public(dtensor->template mutable_data<float>());
         }
       } else {
         VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
@@ -143,12 +152,14 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
         itensor::dims dst_dims (src_dims.begin(), src_dims.end());
         auto dtensor = dst->template GetMutable<itensor>();
         if (dtensor->get_dims() != dst_dims) {
-          dtensor->resize(dst_dims, itensor::data_type::f32);
+          dtensor->resize(dst_dims, idtype::f32);
         }
         if (output_inplace_[i]) {
-          dtensor->reorder_from(dst_dims, itensor::data_type::f32,
-                                const_cast<void*>(src.raw_data()));
+          dtensor->feed_from(dst_dims, idtype::f32,
+              const_cast<void*>(src.raw_data()));
         } else {
+          CAFFE_ENFORCE(!dtensor->has_scale(),
+              "Incorrect invocation of set_data_handle");
           dtensor->set_data_handle(const_cast<void *>(src.raw_data()));
         }
       } else {
index 45abb37..54baf18 100644 (file)
@@ -8,9 +8,7 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
   USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
 
   IDEEPPoolOp(const OperatorDef& operator_def, Workspace* ws)
-      : IDEEPConvPoolOpBase(operator_def, ws),
-        training_mode_(
-            OperatorBase::GetSingleArgument<int>("training_mode", 1)) {
+      : IDEEPConvPoolOpBase(operator_def, ws) {
     CAFFE_ENFORCE(
         (dilation_h() == 1) && (dilation_w() == 1),
         "Pooling op does not support dilation right now.");
@@ -20,6 +18,10 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
               pad_l() < kernel_w() && pad_r() < kernel_w(),
           "Pad should be smaller than kernel.");
     }
+
+    bool training_mode = OperatorBase::GetSingleArgument<int>("training_mode", 1);
+    pk_ = training_mode ? iprop::forward_training : iprop::forward_inference;
+
     // Figure out the pooling descriptor.
     if (operator_def.type().substr(0, 7) == "MaxPool") {
       algo_ = ialgo::pooling_max;
@@ -35,18 +37,23 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
     auto& X = Input(INPUT);
     auto* Y = Output(OUTPUT);
     auto Y_dims = CalcOutputDims(X, X.get_dim(1));
-    mkldnn::prop_kind pk = training_mode_ ?
-      mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_inference;
 
-    ideep::pooling_forward::compute(X, Y_dims, *Y,
-        stride_, kernel_, pad_tl(), pad_br(), algo_, pk);
+    if (cached_X_descriptor_ != X.get_descriptor()) {
+      op_key_.clear();
+      cached_X_descriptor_ = X.dup_descriptor();
+    }
+
+    ideep::pooling_forward::compute(op_key_, X, Y_dims, *Y,
+        stride_, kernel_, pad_tl(), pad_br(), algo_, pk_);
 
     return true;
   }
 
  private:
+  iprop pk_;
   ialgo algo_;
-  bool training_mode_;
+  ikey op_key_;
+  itensor::descriptor cached_X_descriptor_;
 
   INPUT_TAGS(INPUT);
   OUTPUT_TAGS(OUTPUT);
index f312c6f..e1dfece 100644 (file)
@@ -19,7 +19,7 @@ class CopyCPUToIDEEPOp final : public IDEEPOperator {
       Y->Reset(new itensor());
       Y->GetMutable<itensor>()->resize(src_dims, itensor::data_type::f32);
     }
-    Y->GetMutable<itensor>()->reorder_from(
+    Y->GetMutable<itensor>()->feed_from(
         src_dims, itensor::data_type::f32, X.raw_data());
     return true;
   }
@@ -61,7 +61,7 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator {
         }
         auto* Y =
             OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(CPU));
-        X.reorder_to(Y->template mutable_data<float>());
+        X.to_public(Y->template mutable_data<float>());
       } else {
         CAFFE_THROW("Unsupported ideep type: ", X.get_data_type());
       }
index e21aa56..efc5a3b 100644 (file)
@@ -16,6 +16,8 @@ C10_DECLARE_REGISTRY(
   C10_REGISTER_CREATOR(IDEEPOperatorRegistry, key, __VA_ARGS__)
 #define REGISTER_IDEEP_OPERATOR(name, ...) \
   C10_REGISTER_CLASS(IDEEPOperatorRegistry, name, __VA_ARGS__)
+#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \
+  C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
 #define REGISTER_IDEEP_OPERATOR_STR(str_name, ...) \
   C10_REGISTER_TYPED_CLASS(IDEEPOperatorRegistry, str_name, __VA_ARGS__)
 #define REGISTER_IDEEP_COMPARE_OPERATOR(Op)                    \
@@ -27,8 +29,6 @@ C10_DECLARE_REGISTRY(
           Op##Functor<CPUContext>,                             \
           FixedType<bool>>>)
 
-#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \
-  C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
 
 // IDEEPOperator is the base scaffolding of the operators that uses IDEEP. It
 // provides a few operators that are useful to IDEEP specific implementations.
@@ -39,8 +39,6 @@ class IDEEPOperator : public OperatorBase {
         context_(operator_def.device_option()),
         order_(StringToStorageOrder(
             OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
-    OPERATOR_NEEDS_FEATURE(
-        order_ == StorageOrder::NCHW, "Unsupported storage order.");
   }
   virtual ~IDEEPOperator() {}
 
@@ -119,4 +117,19 @@ class IDEEPOperator : public OperatorBase {
       : IDEEPOperator(operator_def, ws) {}                                     \
   virtual ~name() {}
 
+// Convert zero_point scales to min_max scales
+// NOTE:
+//  The scales in operator is saved in FBGEMM format,
+//  while FBGEMM scales are the reciprocals of MKL-DNN scales.
+//  This function is provided to convert scales from FBGEMM to MKL-DNN
+inline ideep::scale_t ConvertScales(
+    const std::vector<float> scales_z) {
+  ideep::scale_t scales (scales_z);
+  for (auto it = scales.begin(); it != scales.end(); it++) {
+    *it = 1.0f / *it;
+  }
+  return scales;
+}
+
+
 } // namespace caffe2
index 05bce30..0d6312f 100644 (file)
@@ -3,6 +3,7 @@
 #include "caffe2/opt/fusion.h"
 
 #ifdef CAFFE2_USE_MKLDNN
+#include <cpuinfo.h>
 #include "caffe2/ideep/ideep_utils.h"
 #endif
 
@@ -78,7 +79,7 @@ bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) {
 }
 
 bool shouldFuseConv(const repr::Conv& conv) {
-  return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
+  return isOnIdeepDevice(conv);
 }
 
 void removeStopGradientForInference(repr::NNModule* nn) {
@@ -110,10 +111,6 @@ void removeStopGradientForInference(repr::NNModule* nn) {
 }
 
 void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
-  // Fusion types:
-  // FUSION_CONV_RELU = 1
-  // FUSION_CONV_SUM = 2
-  // FUSION_CONV_SUM_RELU = 3
   auto conv = repr::nn::get<repr::Conv>(convNode);
   auto annotation = conv->getMutableAnnotation();
   if (!annotation || !isa<Caffe2Annotation>(annotation)) {
@@ -126,19 +123,18 @@ void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
   }
 
   if (op->type() == "ConvFusion") {
-    CAFFE_ENFORCE(fusion_type == 1, "Invalid nest fusion");
+    CAFFE_ENFORCE(fusion_type == FUSION_CONV_RELU, "Invalid nest fusion");
     for (auto& arg : *op->mutable_arg()) {
       if (arg.name() == "fusion_type") {
-        // Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU
-        CAFFE_ENFORCE(arg.i() == 2, "Invalid nest fusion");
-        arg.set_i(3);
+        CAFFE_ENFORCE(arg.i() == FUSION_CONV_SUM, "Invalid nest fusion");
+        arg.set_i(FUSION_CONV_SUM_RELU);
         return;
       }
     }
     return;
   }
 
-  CAFFE_ENFORCE(fusion_type < 3, "Invalid fusion type");
+  CAFFE_ENFORCE_LT(fusion_type, FUSION_CONV_SUM_RELU, "Invalid fusion type");
   op->set_type("ConvFusion");
   auto* arg = op->add_arg();
   arg->set_name("fusion_type");
@@ -224,7 +220,7 @@ bool fuseConvBNAndAffChHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws)
       continue;                                                          \
     }                                                                    \
     name##Tensor.resize(name->get_dims(), name->get_data_type());        \
-    name##Tensor.reorder_from(*name);                                    \
+    name##Tensor.feed_from(*name);                                       \
     CAFFE_ENFORCE(                                                       \
       name##Tensor.is_public_format(), #name " not with public format"); \
     name##Data = static_cast<float*>(name##Tensor.get_data_handle());    \
@@ -263,8 +259,8 @@ bool fuseConvBNAndAffChHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws)
       }
     }
 
-    filter->reorder_from(filterTensor);
-    biasConv->reorder_from(biasConvTensor);
+    filter->feed_from(filterTensor);
+    biasConv->feed_from(biasConvTensor);
     nn->dataFlow.replaceNode(convOutput, bnOrAffChOutput);
 
     nn->dataFlow.deleteNode(bnOrAffChNode);
@@ -282,6 +278,7 @@ void fuseConvBNAndAffChForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
 }
 
 void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
+  CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo");
   // Assume the order of nodes from getMutableNodes conforms to
   // the original topo order of operators
   auto allNodes = nn->dataFlow.getMutableNodes();
@@ -342,11 +339,16 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
     }
 
     auto conv = repr::nn::get<repr::Conv>(convNode);
-    if (!shouldFuseConv(*conv)) {
+    if (!isOnIdeepDevice(*conv)) {
       LOG(WARNING) << "Not a IDEEP operator";
       continue;
     }
 
+    if (conv->getGroup() > 1 && !cpuinfo_has_x86_avx512f()) {
+      LOG(WARNING) << "Not support conv sum fusion with grouped filter";
+      continue;
+    }
+
     auto convOutput = repr::nn::getOutputs(convNode).front();
     repr::NNGraph::NodeRef sumInputX =
         (sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
@@ -366,8 +368,7 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
     auto sumOutput = repr::nn::getOutputs(sumNode).front();
     nn->dataFlow.replaceNode(sumOutput, newOutput);
 
-    // 2 means FUSION_CONV_SUM
-    resetConvForFusion(convNode, 2);
+    resetConvForFusion(convNode, FUSION_CONV_SUM);
     nn->dataFlow.createEdge(sumInputX, convNode);
     nn->dataFlow.createEdge(convNode, newOutput);
 
@@ -405,8 +406,8 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
 
     bool enforce_inplace = false;
     for (const auto& arg : op.arg()) {
-      // Only check FUSION_SUM & FUSION_SUM_RELU
-      if (arg.name() == "fusion_type" && (arg.i() == 2 || arg.i() == 3)) {
+      if (arg.name() == "fusion_type"
+          && (arg.i() == FUSION_CONV_SUM || arg.i() == FUSION_CONV_SUM_RELU)) {
         enforce_inplace = true;
         break;
       }
index 8c40be8..1e1a3ce 100644 (file)
@@ -256,6 +256,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
 
         workspace.SwitchWorkspace(old_ws_name)
 
+
     @given(stride=st.integers(1, 3),
            pad=st.integers(0, 3),
            kernel=st.integers(3, 5),
@@ -410,6 +411,113 @@ class ConvFusionTest(hu.HypothesisTestCase):
            pad=st.integers(0, 3),
            kernel=st.integers(3, 5),
            size=st.integers(8, 20),
+           input_channels=st.integers(7, 17),
+           output_channels=st.integers(5, 15),
+           batch_size=st.integers(1, 3),
+           use_bias=st.booleans(),
+           group=st.integers(2, 5),
+           **mu.gcs)
+    def test_convolution_grouped_sum_relu_fusion(self, stride, pad, kernel, size,
+                             input_channels, output_channels,
+                             batch_size, use_bias, group, gc, dc):
+        conv_S0 = core.CreateOperator(
+            "Conv",
+            ["SX0", "Sw0", "Sb0"] if use_bias else ["SX0", "Sw0"],
+            ["S0"],
+            stride=stride,
+            pad=pad,
+            kernel=kernel,
+            group=group,
+            device_option=dc[0]
+        )
+        conv = core.CreateOperator(
+            "Conv",
+            ["X0", "w0", "b0"] if use_bias else ["X0", "w0"],
+            ["Y0"],
+            stride=stride,
+            pad=pad,
+            kernel=kernel,
+            group=group,
+            device_option=dc[0]
+        )
+        sum = core.CreateOperator(
+            "Sum",
+            ["S0", "Y0"],
+            ["S0"],
+            device_option=dc[0]
+        )
+        relu = core.CreateOperator(
+            "Relu",
+            ["S0"],
+            ["S0"],
+            device_option=dc[0]
+        )
+
+        SX = np.random.rand(
+            batch_size, input_channels * group, size, size).astype(np.float32) - 0.5
+        Sw = np.random.rand(
+                output_channels * group, input_channels, kernel, kernel) \
+            .astype(np.float32) - 0.5
+        Sb = np.random.rand(output_channels * group).astype(np.float32) - 0.5
+        X = np.random.rand(
+            batch_size, input_channels * group, size, size).astype(np.float32) - 0.5
+        w = np.random.rand(
+                output_channels * group, input_channels, kernel, kernel) \
+            .astype(np.float32) - 0.5
+        b = np.random.rand(output_channels * group).astype(np.float32) - 0.5
+
+        old_ws_name = workspace.CurrentWorkspace()
+        workspace.SwitchWorkspace("_device_check_", True)
+        workspace.FeedBlob('SX0', SX, dc[0])
+        workspace.FeedBlob('Sw0', Sw, dc[0])
+        workspace.FeedBlob('Sb0', Sb, dc[0])
+        workspace.FeedBlob('X0', X, dc[0])
+        workspace.FeedBlob('w0', w, dc[0])
+        workspace.FeedBlob('b0', b, dc[0])
+        workspace.RunOperatorOnce(conv_S0)
+        workspace.RunOperatorOnce(conv)
+        workspace.RunOperatorOnce(sum)
+        workspace.RunOperatorOnce(relu)
+        S0 = workspace.FetchBlob('S0')
+
+        workspace.ResetWorkspace()
+        old_net = caffe2_pb2.NetDef()
+        conv_S0_old = caffe2_pb2.OperatorDef()
+        conv_S0_old.CopyFrom(conv_S0)
+        conv_S0_old.device_option.CopyFrom(dc[1])
+        conv_old = caffe2_pb2.OperatorDef()
+        conv_old.CopyFrom(conv)
+        conv_old.device_option.CopyFrom(dc[1])
+        sum_old = caffe2_pb2.OperatorDef()
+        sum_old.CopyFrom(sum)
+        sum_old.device_option.CopyFrom(dc[1])
+        relu_old = caffe2_pb2.OperatorDef()
+        relu_old.CopyFrom(relu)
+        relu_old.device_option.CopyFrom(dc[1])
+        old_net.op.extend([conv_S0_old, conv_old, sum_old, relu_old])
+        workspace.FeedBlob('SX0', SX, dc[1])
+        workspace.FeedBlob('Sw0', Sw, dc[1])
+        workspace.FeedBlob('Sb0', Sb, dc[1])
+        workspace.FeedBlob('X0', X, dc[1])
+        workspace.FeedBlob('w0', w, dc[1])
+        workspace.FeedBlob('b0', b, dc[1])
+        net = core.Net("net")
+        net.Proto().CopyFrom(old_net)
+        optimizeForIDEEP(net)
+        workspace.RunNetOnce(net.Proto())
+        S2 = workspace.FetchBlob('S0')
+        if not np.allclose(S0, S2, atol=0.01, rtol=0.01):
+            print(S2.flatten())
+            print(S0.flatten())
+            print(np.max(np.abs(S2 - S0)))
+            self.assertTrue(False)
+
+        workspace.SwitchWorkspace(old_ws_name)
+
+    @given(stride=st.integers(1, 3),
+           pad=st.integers(0, 3),
+           kernel=st.integers(3, 5),
+           size=st.integers(8, 20),
            input_channels=st.integers(1, 16),
            output_channels=st.integers(1, 16),
            batch_size=st.integers(1, 3),
index ff4971e..4460e0d 100644 (file)
@@ -59,13 +59,13 @@ public:
                   (atensor.get_nelems() == 0 ||
                    atensor.get_data_handle() != nullptr),
                   "Trying to fetch uninitialized tensor");
-    const int numpy_type = CaffeToNumpyType(type_transform(atensor));
+    // NOTE: Only support float so far.
+    const int numpy_type = NPY_FLOAT;
     CAFFE_ENFORCE(
         numpy_type != -1,
         "Unsupported ideep memory data type? This usually should not happen "
         "since ideep memory usually only do float and double.");
     itensor::dims dims = atensor.get_public_format_dims();
-
     std::vector<npy_intp> npy_dims(dims.begin(), dims.end());
 
     result.copied = force_copy || atensor.need_reorder();
@@ -86,7 +86,7 @@ public:
     }
 
     if (result.copied) {
-      atensor.reorder_to(outPtr);
+      atensor.to_public(outPtr);
     }
 
     return result;
@@ -144,7 +144,7 @@ public:
         if (tensor->get_dims() != adims || type != tensor->get_data_type()) {
           tensor->resize(adims, type);
         }
-        tensor->reorder_from(adims, type,
+        tensor->feed_from(adims, type,
                              static_cast<void *>(PyArray_DATA(array)));
     }
 #else