From e0d286a11267a9431263f76164e6e9099bdb3558 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 21 Oct 2019 13:40:55 -0700 Subject: [PATCH] [Relay][Pass] Count MAC for BatchMatMul (#4157) * count MAC for BatchMatMul * update doc --- src/relay/pass/mac_count.cc | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 48a0dfb..000783c 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -66,7 +66,7 @@ int64_t ConvMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK(args.size() == 2) + CHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); @@ -74,13 +74,13 @@ int64_t ConvMacCount(const Call& call_node) { std::string data_layout = conv_2d_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK(C_ind != -1) + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; - CHECK(kernel_size.size() == 2) + CHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; @@ -99,7 +99,7 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK(args.size() == 2) + CHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D Transpose node should be 2."; const auto* conv_2d_transpose_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); @@ -107,13 +107,13 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) { std::string data_layout = conv_2d_transpose_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK(C_ind != -1) + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_transpose_attr->kernel_size; - CHECK(kernel_size.size() == 2) + CHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D Transpose should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; @@ -132,7 +132,7 @@ int64_t DenseMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK(args.size() == 2) + CHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2."; const auto* data_type = args[0]->checked_type().as(); const auto* weight_type = args[1]->checked_type().as(); @@ -144,12 +144,28 @@ int64_t DenseMacCount(const Call& call_node) { int64_t d2 = static_cast(data_shape[1].as()->value); int64_t d3 = static_cast(weight_shape[0].as()->value); int64_t d4 = static_cast(weight_shape[1].as()->value); - CHECK(d2 == d4) + CHECK_EQ(d2, d4) << "The dimensions of input arguments do not match."; int64_t count = d1 * d2 * d3; return count; } +int64_t BatchMatmulMacCount(const Call& call_node) { + if (!call_node->checked_type_.defined()) { + LOG(WARNING) << "The infer type pass should be called before the mac count pass"; + return 0; + } + Array args = call_node->args; + CHECK_EQ(args.size(), 2); + Array x_shape = args[0]->checked_type().as()->shape; + Array y_shape = args[1]->checked_type().as()->shape; + int64_t batch = x_shape[0].as()->value; + int64_t m = x_shape[1].as()->value; + int64_t k = x_shape[2].as()->value; + int64_t n = y_shape[1].as()->value; + return batch * m * k * n; +} + RELAY_REGISTER_OP("nn.conv2d") .set_attr("FMacCount", ConvMacCount); @@ -159,14 +175,17 @@ RELAY_REGISTER_OP("nn.conv2d_transpose") RELAY_REGISTER_OP("nn.dense") .set_attr("FMacCount", DenseMacCount); +RELAY_REGISTER_OP("nn.batch_matmul") +.set_attr("FMacCount", BatchMatmulMacCount); + class MacCounter : private ExprVisitor { public: MacCounter() { count_ = 0; } static int64_t GetTotalMacNumber(const Expr& expr) { - LOG(INFO) << "This pass only counts MACs in direct CONV 2D, " - << "CONV 2D Transpose and Dense ops"; + LOG(INFO) << "This pass only counts MACs in direct conv2d, " + << "conv2d_transpose, dense, and batch_matmul ops"; MacCounter counter; counter(expr); return counter.count_; -- 2.7.4