[Relay][Pass] Count MAC for BatchMatMul (#4157)
authorHaichen Shen <shenhaichen@gmail.com>
Mon, 21 Oct 2019 20:40:55 +0000 (13:40 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Mon, 21 Oct 2019 20:40:55 +0000 (13:40 -0700)
* count MAC for BatchMatMul

* update doc

src/relay/pass/mac_count.cc

index 48a0dfb..000783c 100644 (file)
@@ -66,7 +66,7 @@ int64_t ConvMacCount(const Call& call_node) {
     return 0;
   }
   Array<Expr> args = call_node->args;
-  CHECK(args.size() == 2)
+  CHECK_EQ(args.size(), 2)
     << "The number of input arguments of a CONV 2D node should be 2.";
   const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
@@ -74,13 +74,13 @@ int64_t ConvMacCount(const Call& call_node) {
   std::string data_layout = conv_2d_attr->data_layout;
   int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
-  CHECK(C_ind != -1)
+  CHECK_NE(C_ind, -1)
     << "There is no input channel dimension.";
   int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
   if (c_ind != -1)
     input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
   Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
-  CHECK(kernel_size.size() == 2)
+  CHECK_EQ(kernel_size.size(), 2)
     << "The dimension of the kernel in Conv 2D should be 2.";
   const auto* expr = call_node->checked_type().as<TensorTypeNode>();
   Array<IndexExpr> output_tensor = expr->shape;
@@ -99,7 +99,7 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
     return 0;
   }
   Array<Expr> args = call_node->args;
-  CHECK(args.size() == 2)
+  CHECK_EQ(args.size(), 2)
     << "The number of input arguments of a CONV 2D Transpose node should be 2.";
   const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
@@ -107,13 +107,13 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
   std::string data_layout = conv_2d_transpose_attr->data_layout;
   int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
-  CHECK(C_ind != -1)
+  CHECK_NE(C_ind, -1)
     << "There is no input channel dimension.";
   int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
   if (c_ind != -1)
     input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
   Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
-  CHECK(kernel_size.size() == 2)
+  CHECK_EQ(kernel_size.size(), 2)
     << "The dimension of the kernel in Conv 2D Transpose should be 2.";
   const auto* expr = call_node->checked_type().as<TensorTypeNode>();
   Array<IndexExpr> output_tensor = expr->shape;
@@ -132,7 +132,7 @@ int64_t DenseMacCount(const Call& call_node) {
     return 0;
   }
   Array<Expr> args = call_node->args;
-  CHECK(args.size() == 2)
+  CHECK_EQ(args.size(), 2)
       << "The number of input arguments of a Dense node should be 2.";
   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
   const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
@@ -144,12 +144,28 @@ int64_t DenseMacCount(const Call& call_node) {
   int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
   int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
   int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
-  CHECK(d2 == d4)
+  CHECK_EQ(d2, d4)
     << "The dimensions of input arguments do not match.";
   int64_t count = d1 * d2 * d3;
   return count;
 }
 
+int64_t BatchMatmulMacCount(const Call& call_node) {
+  if (!call_node->checked_type_.defined()) {
+    LOG(WARNING) << "The infer type pass should be called before the mac count pass";
+    return 0;
+  }
+  Array<Expr> args = call_node->args;
+  CHECK_EQ(args.size(), 2);
+  Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
+  Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
+  int64_t batch = x_shape[0].as<IntImm>()->value;
+  int64_t m = x_shape[1].as<IntImm>()->value;
+  int64_t k = x_shape[2].as<IntImm>()->value;
+  int64_t n = y_shape[1].as<IntImm>()->value;
+  return batch * m * k * n;
+}
+
 RELAY_REGISTER_OP("nn.conv2d")
 .set_attr<FMacCount>("FMacCount", ConvMacCount);
 
@@ -159,14 +175,17 @@ RELAY_REGISTER_OP("nn.conv2d_transpose")
 RELAY_REGISTER_OP("nn.dense")
 .set_attr<FMacCount>("FMacCount", DenseMacCount);
 
+RELAY_REGISTER_OP("nn.batch_matmul")
+.set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
+
 class MacCounter : private ExprVisitor {
  public:
   MacCounter() {
     count_ = 0;
   }
   static int64_t GetTotalMacNumber(const Expr& expr) {
-    LOG(INFO) << "This pass only counts MACs in direct CONV 2D, "
-              << "CONV 2D Transpose and Dense ops";
+    LOG(INFO) << "This pass only counts MACs in direct conv2d, "
+              << "conv2d_transpose, dense, and batch_matmul ops";
     MacCounter counter;
     counter(expr);
     return counter.count_;