Implement relay nn.bias_add compute in C++ (#3027)
authorYinghai Lu <yinghai@fb.com>
Wed, 17 Apr 2019 20:06:30 +0000 (13:06 -0700)
committerLeyuan Wang <laurawly@gmail.com>
Wed, 17 Apr 2019 20:06:30 +0000 (13:06 -0700)
* Implement nn.bias_add compute in C++

* Address comments

* Remove unnecessary check

python/tvm/relay/op/nn/_nn.py
src/relay/op/nn/nn.cc
topi/include/topi/nn/bias_add.h [new file with mode: 0644]
topi/src/topi.cc

index 5a47b1d..e60c01c 100644 (file)
@@ -182,20 +182,6 @@ def schedule_conv2d_transpose(attrs, outs, target):
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # bias_add
-@reg.register_compute("nn.bias_add")
-def compute_bias_add(attrs, inputs, out_dtype, target):
-    """Compute definition of conv2d_transpose"""
-    axis = attrs.axis
-    bias = inputs[1]
-    data_ndim = len(inputs[0].shape)
-    if axis < 0:
-        axis = axis + data_ndim
-    num_newaxis = data_ndim - axis - 1
-
-    if num_newaxis:
-        bias = topi.expand_dims(bias, axis=1, num_newaxis=num_newaxis)
-    return [topi.add(inputs[0], bias)]
-
 reg.register_schedule("nn.bias_add", schedule_injective)
 reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
 
index ae25662..4141e60 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,6 +28,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/image.h>
 #include <topi/nn.h>
+#include <topi/nn/bias_add.h>
 #include <topi/nn/softmax.h>
 #include <topi/nn/flatten.h>
 #include <vector>
@@ -90,7 +91,12 @@ RELAY_REGISTER_OP("nn.bias_add")
 .add_argument("data", "nD Tensor", "Input data.")
 .add_argument("bias", "1D Tensor", "Bias.")
 .set_support_level(1)
-.add_type_rel("BiasAdd", BiasAddRel);
+.add_type_rel("BiasAdd", BiasAddRel)
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<Tensor>& inputs,
+                                        const Type& out_type, const Target& target) {
+    const auto* param = attrs.as<BiasAddAttrs>();
+    return tvm::Array<tvm::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
+});
 
 
 // relay.nn.dense
diff --git a/topi/include/topi/nn/bias_add.h b/topi/include/topi/nn/bias_add.h
new file mode 100644 (file)
index 0000000..fb4ae30
--- /dev/null
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief bias_add op constructions
+ * \file nn/bias_add.h
+ */
+#ifndef TOPI_NN_BIAS_ADD_H_
+#define TOPI_NN_BIAS_ADD_H_
+
+#include <string>
+
+#include "topi/tags.h"
+#include "topi/broadcast.h"
+#include "topi/transform.h"
+#include "tvm/tvm.h"
+
+namespace topi {
+namespace nn {
+
+/*!
+* \brief Creates an operation that calculates data + bias
+*
+* \param data Tensor with shape [batch, in_dim]
+* \param bias Tensor with shape [batch].
+*
+* \return Tensor with shape [batch, in_dim]
+*/
+inline tvm::Tensor bias_add(const tvm::Tensor& data, const tvm::Tensor& bias, int axis) {
+  int data_ndim = data->shape.size();
+  if (axis < 0) {
+    axis += data_ndim;
+  }
+  int num_newaxis = data_ndim - axis - 1;
+  return add(data, (num_newaxis ? expand_dims(bias, 1, num_newaxis) : bias));
+}
+}  // namespace nn
+}  // namespace topi
+#endif  // TOPI_NN_BIAS_ADD_H_
index 47e999c..c583f1c 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -36,6 +36,7 @@
 #include <topi/reduction.h>
 #include <topi/transform.h>
 
+#include <topi/nn/bias_add.h>
 #include <topi/nn/bnn.h>
 #include <topi/nn/dense.h>
 #include <topi/nn/dilate.h>
@@ -400,6 +401,12 @@ TVM_REGISTER_GLOBAL("topi.nn.dense")
   *rv = nn::dense(args[0], args[1], args[2]);
   });
 
+/* Ops from nn/bias_add.h */
+TVM_REGISTER_GLOBAL("topi.nn.bias_add")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = nn::bias_add(args[0], args[1], args[2]);
+  });
+
 /* Ops from nn/batch_matmul.h */
 TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
 .set_body([](TVMArgs args, TVMRetValue *rv) {