Enable conv+add fusion, same as conv+sum (#15268)
author: Gu, Jinghui <jinghui.gu@intel.com>
Mon, 7 Jan 2019 22:10:27 +0000 (14:10 -0800)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 7 Jan 2019 22:42:45 +0000 (14:42 -0800)
Summary:
Enable conv+add fusion, same as conv+sum

Caution: IDEEP supports only element-wise Add, with no scalar
broadcasting. Fusing an Add that broadcasts would therefore be illegal.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15268

Differential Revision: D13577375

Pulled By: yinghai

fbshipit-source-id: 92c9c4b667c5ca5f7a262a5bffaa8aa68eeff3bd

caffe2/opt/optimize_ideep.cc
caffe2/python/ideep/convfusion_op_test.py

index 14360f8..e4a9feb 100644 (file)
@@ -53,6 +53,15 @@ caffe2::OperatorDef* getMutableOpDef(repr::NeuralNetOperator& nnOp) {
   return dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
 }
 
+bool isOpType(const repr::NNGraph::NodeRef& nodeRef, string typeName) {
+  if (!repr::nn::is<repr::NeuralNetOperator>(nodeRef)) {
+    return false;
+  }
+  auto op = repr::nn::get<repr::NeuralNetOperator>(nodeRef);
+  auto opDef = getOpDef(*op);
+  return opDef.type() == typeName;
+}
+
 bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) {
   // We only want to fuse for IDEEP convs
   const auto& op = getOpDef(nnOp);
@@ -245,11 +254,13 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
       continue;
     }
 
-    if (!repr::nn::is<repr::Sum>(sumNode)) {
+    // CAUTION: On IDEEP device, only element-wise Add operator is
+    // supported yet. It totally works as element-wise sum without scalar broadcast.
+    if (!repr::nn::is<repr::Sum>(sumNode) && !isOpType(sumNode, "Add")) {
       continue;
     }
 
-    auto sum = repr::nn::get<repr::Sum>(sumNode);
+    auto sum = repr::nn::get<repr::NeuralNetOperator>(sumNode);
     if (!isOnIdeepDevice(*sum)) {
       LOG(WARNING) << "Not a IDEEP operator";
       continue;
index 6445655..de66a4d 100644 (file)
@@ -125,10 +125,11 @@ class ConvFusionTest(hu.HypothesisTestCase):
            batch_size=st.integers(1, 3),
            use_bias=st.booleans(),
            group=st.integers(1, 1),
+           sum_add=st.sampled_from(["Sum", "Add"]),
            **mu.gcs)
     def test_convolution_sum_fusion(self, stride, pad, kernel, size,
                              input_channels, output_channels,
-                             batch_size, use_bias, group, gc, dc):
+                             batch_size, use_bias, group, sum_add, gc, dc):
         conv_S0 = core.CreateOperator(
             "Conv",
             ["SX0", "Sw0", "Sb0"] if use_bias else ["SX0", "Sw0"],
@@ -150,7 +151,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
             device_option=dc[0]
         )
         sum = core.CreateOperator(
-            "Sum",
+            sum_add,
             ["S0", "Y0"],
             ["S0"],
             device_option=dc[0]
@@ -264,10 +265,11 @@ class ConvFusionTest(hu.HypothesisTestCase):
            batch_size=st.integers(1, 3),
            use_bias=st.booleans(),
            group=st.integers(1, 1),
+           sum_add=st.sampled_from(["Sum", "Add"]),
            **mu.gcs)
     def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size,
                              input_channels, output_channels,
-                             batch_size, use_bias, group, gc, dc):
+                             batch_size, use_bias, group, sum_add, gc, dc):
         conv_S0 = core.CreateOperator(
             "Conv",
             ["SX0", "Sw0", "Sb0"] if use_bias else ["SX0", "Sw0"],
@@ -289,7 +291,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
             device_option=dc[0]
         )
         sum = core.CreateOperator(
-            "Sum",
+            sum_add,
             ["S0", "Y0"],
             ["S0"],
             device_option=dc[0]