[RELAY][PASS] detect depthwise conv2d in mac_count pass (#3083)
authoreqy <eddieyan101@gmail.com>
Tue, 14 May 2019 12:34:16 +0000 (05:34 -0700)
committerLianmin Zheng <lianminzheng@gmail.com>
Tue, 14 May 2019 12:34:16 +0000 (20:34 +0800)
* check in

* use groups

* CHECK_EQ

* trigger CI

* Update mac_count.cc

* trigger CI

* trigger CI

src/relay/pass/mac_count.cc
tests/python/relay/test_pass_mac_count.py

index c9ee4ee..3d77fab 100644 (file)
@@ -30,7 +30,9 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pass.h>
 #include <tvm/data_layout.h>
+#include "pattern_util.h"
 
 namespace tvm {
 namespace relay {
@@ -65,7 +67,7 @@ int64_t ConvMacCount(const Call& call_node) {
   }
   Array<Expr> args = call_node->args;
   CHECK(args.size() == 2)
-      << "The number of input arguments of a CONV 2D node should be 2.";
+    << "The number of input arguments of a CONV 2D node should be 2.";
   const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
   Array<IndexExpr> data_shape = data_type->shape;
@@ -73,18 +75,21 @@ int64_t ConvMacCount(const Call& call_node) {
   int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
   CHECK(C_ind != -1)
-      << "There is no input channel dimension.";
+    << "There is no input channel dimension.";
   int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
   if (c_ind != -1)
     input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
   Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
   CHECK(kernel_size.size() == 2)
-      << "The dimension of the kernel size in Conv 2D should be 2.";
+    << "The dimension of the kernel in Conv 2D should be 2.";
   const auto* expr = call_node->checked_type().as<TensorTypeNode>();
   Array<IndexExpr> output_tensor = expr->shape;
   CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
-      << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
-  int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
+    << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
+  int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
+  CHECK_EQ(input_channel % conv_2d_attr->groups, 0)
+  << "The number of input channels is not divisble by groups.";
+  count *= input_channel/conv_2d_attr->groups;
   return count;
 }
 
index 5a975fd..98ba1ad 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for MAC counter."""
+import numpy as np
 import tvm
 from tvm import relay
 
@@ -99,7 +100,35 @@ def test_simple_network():
     expect_count = 231411712
     assert compute_count == expect_count
 
+def test_depthwise_conv2d():
+    batch_size = 1
+    dshape = (batch_size, 64, 56, 56)
+    weight_conv = relay.var("weight_depthwiseconv", shape=(64, 1, 3, 3))
+    data1 = relay.var("data1", shape=dshape)
+    data2 = relay.var("data2", shape=dshape)
+    depthwise_conv2d_1 = relay.nn.conv2d(
+        data1,
+        weight_conv,
+        kernel_size=(3, 3),
+        padding=(1, 1),
+        groups=64)
+    depthwise_conv2d_2 = relay.nn.conv2d(
+        data2,
+        weight_conv,
+        kernel_size=(3, 3),
+        padding=(1, 1),
+        groups=64)
+    add = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+    func = relay.Function([data1, data2, weight_conv],
+                            relay.Tuple(tvm.convert([depthwise_conv2d_1,
+                                                    depthwise_conv2d_2,
+                                                    add])))
+    func = relay.ir_pass.infer_type(func)
+    compute_count = relay.ir_pass.get_total_mac_number(func)
+    assert compute_count == 2 * np.prod(dshape) * 3*3
+
 if __name__ == "__main__":
     test_conv()
     test_gemm()
     test_simple_network()
+    test_depthwise_conv2d()