reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
# bias_add
-@reg.register_compute("nn.bias_add")
-def compute_bias_add(attrs, inputs, out_dtype, target):
- """Compute definition of conv2d_transpose"""
- axis = attrs.axis
- bias = inputs[1]
- data_ndim = len(inputs[0].shape)
- if axis < 0:
- axis = axis + data_ndim
- num_newaxis = data_ndim - axis - 1
-
- if num_newaxis:
- bias = topi.expand_dims(bias, axis=1, num_newaxis=num_newaxis)
- return [topi.add(inputs[0], bias)]
-
reg.register_schedule("nn.bias_add", schedule_injective)
reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h>
#include <topi/nn.h>
+#include <topi/nn/bias_add.h>
#include <topi/nn/softmax.h>
#include <topi/nn/flatten.h>
#include <vector>
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
-.add_type_rel("BiasAdd", BiasAddRel);
+.add_type_rel("BiasAdd", BiasAddRel)
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<Tensor>& inputs,
+ const Type& out_type, const Target& target) {
+ const auto* param = attrs.as<BiasAddAttrs>();
+ return tvm::Array<tvm::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
+});
// relay.nn.dense
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \brief bias_add op constructions
+ * \file nn/bias_add.h
+ */
+#ifndef TOPI_NN_BIAS_ADD_H_
+#define TOPI_NN_BIAS_ADD_H_
+
+#include <string>
+
+#include "topi/tags.h"
+#include "topi/broadcast.h"
+#include "topi/transform.h"
+#include "tvm/tvm.h"
+
+namespace topi {
+namespace nn {
+
+/*!
+* \brief Creates an operation that calculates data + bias
+*
+* \param data Tensor with shape [batch, in_dim]
+* \param bias Tensor with shape [batch].
+*
+* \return Tensor with shape [batch, in_dim]
+*/
+inline tvm::Tensor bias_add(const tvm::Tensor& data, const tvm::Tensor& bias, int axis) {
+ int data_ndim = data->shape.size();
+ if (axis < 0) {
+ axis += data_ndim;
+ }
+ int num_newaxis = data_ndim - axis - 1;
+ return add(data, (num_newaxis ? expand_dims(bias, 1, num_newaxis) : bias));
+}
+} // namespace nn
+} // namespace topi
+#endif // TOPI_NN_BIAS_ADD_H_
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <topi/reduction.h>
#include <topi/transform.h>
+#include <topi/nn/bias_add.h>
#include <topi/nn/bnn.h>
#include <topi/nn/dense.h>
#include <topi/nn/dilate.h>
*rv = nn::dense(args[0], args[1], args[2]);
});
+/* Ops from nn/bias_add.h */
+TVM_REGISTER_GLOBAL("topi.nn.bias_add")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = nn::bias_add(args[0], args[1], args[2]);
+ });
+
/* Ops from nn/batch_matmul.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {