From 00b23049f3cd0509f37d7f0d3c4cbee4cacf51c7 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 6 Apr 2020 14:30:53 -0700 Subject: [PATCH] [Topi] Breakdown topi.cc into smaller files (#5253) * [Topi] Breakdown topi.cc into smaller files * add missing file --- topi/include/topi/broadcast.h | 6 +- topi/include/topi/util.h | 53 +++ topi/src/broadcast.cc | 84 ++++ topi/src/elemwise.cc | 154 ++++++++ topi/src/nn.cc | 185 +++++++++ topi/src/reduction.cc | 75 ++++ topi/src/schedule.cc | 357 +++++++++++++++++ topi/src/topi.cc | 893 ------------------------------------------ topi/src/transform.cc | 177 +++++++++ topi/src/vision.cc | 39 ++ 10 files changed, 1127 insertions(+), 896 deletions(-) create mode 100644 topi/include/topi/util.h create mode 100644 topi/src/broadcast.cc create mode 100644 topi/src/elemwise.cc create mode 100644 topi/src/nn.cc create mode 100644 topi/src/reduction.cc create mode 100644 topi/src/schedule.cc delete mode 100644 topi/src/topi.cc create mode 100644 topi/src/transform.cc create mode 100644 topi/src/vision.cc diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index 30bc584..c9b12d3 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -45,9 +45,9 @@ namespace topi { * \return A Tensor whose op member is a broadcast operation */ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, - const tvm::Array& output_shape, - std::string name = "T_broadcast_to", - std::string tag = kBroadcast) { + const tvm::Array& output_shape, + std::string name = "T_broadcast_to", + std::string tag = kBroadcast) { CHECK_GE(output_shape.size(), t->shape.size()) << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape << "\nvs\ninput: " << t; diff --git a/topi/include/topi/util.h b/topi/include/topi/util.h new file mode 100644 index 0000000..95c5c55 --- /dev/null +++ b/topi/include/topi/util.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Topi utility function + * \file topi/util.h + */ +#ifndef TOPI_UTIL_H_ +#define TOPI_UTIL_H_ + +#include +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +/*! \brief Canonicalize an argument that may be Array or int to Array */ +inline Array ArrayOrInt(TVMArgValue arg) { + if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) { + Array result; + result.push_back(arg.operator int()); + return result; + } else { + return arg; + } +} + +inline bool IsTensorType(TVMArgValue arg) { + return (arg.type_code() == kTVMObjectHandle && + static_cast( + arg.value().v_handle)->IsInstance()); +} + +} // namespace topi +#endif // TOPI_UTIL_H_ diff --git a/topi/src/broadcast.cc b/topi/src/broadcast.cc new file mode 100644 index 0000000..3ae3dae --- /dev/null +++ b/topi/src/broadcast.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of broadcast operators +* \file broadcast.cc +*/ +#include +#include + +#include +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_REGISTER_GLOBAL(OpName) \ + .set_body([](TVMArgs args, TVMRetValue *rv) { \ + bool lhs_is_tensor = IsTensorType(args[0]); \ + bool rhs_is_tensor = IsTensorType(args[1]); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), \ + args[1].operator tvm::te::Tensor()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), \ + args[1].operator tvm::te::Tensor()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), \ + args[1].operator tvm::PrimExpr()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), \ + args[1].operator tvm::PrimExpr()); \ + } \ + }); \ + +TOPI_REGISTER_BCAST_OP("topi.add", topi::add); +TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); +TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply); +TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide); +TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide); +TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod); +TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod); +TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum); +TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum); +TOPI_REGISTER_BCAST_OP("topi.power", topi::power); +TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); +TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); +TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); +TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and); +TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or); +TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor); +TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); +TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); +TOPI_REGISTER_BCAST_OP("topi.less", topi::less); +TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal); +TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); +TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); +TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); + +TVM_REGISTER_GLOBAL("topi.broadcast_to") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = broadcast_to(args[0], args[1]); + }); + +} // namespace topi diff --git a/topi/src/elemwise.cc b/topi/src/elemwise.cc new file mode 100644 index 0000000..ab9f6fd --- /dev/null +++ b/topi/src/elemwise.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of elemwise operators +* \file elemwise.cc +*/ +#include +#include + +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +TVM_REGISTER_GLOBAL("topi.exp") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = exp(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.fast_exp") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = fast_exp(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.erf") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = erf(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.tan") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = tan(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.cos") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = cos(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.sin") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = sin(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.tanh") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = tanh(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.fast_tanh") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = fast_tanh(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.atan") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = atan(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.sigmoid") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = sigmoid(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.sqrt") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = sqrt(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.rsqrt") +.set_body([](TVMArgs args, TVMRetValue *rv) { +*rv = rsqrt(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.log") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = log(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.identity") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = identity(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.negative") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = negative(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.clip") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = clip(args[0], args[1], args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.cast") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = cast(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.reinterpret") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = reinterpret(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.elemwise_sum") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = elemwise_sum(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.sign") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = sign(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.full") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = full(args[0], args[1], args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.full_like") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = full_like(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.logical_not") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = logical_not(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.bitwise_not") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = bitwise_not(args[0]); + }); + +} // namespace topi diff --git a/topi/src/nn.cc b/topi/src/nn.cc new file mode 100644 index 0000000..77b208d --- /dev/null +++ b/topi/src/nn.cc @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of NN operators +* \file nn.cc +*/ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +/* Ops from nn.h */ +TVM_REGISTER_GLOBAL("topi.nn.relu") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = relu(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = leaky_relu(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.prelu") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = prelu(args[0], args[1], args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.pad") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = pad(args[0], args[1], args[2], args[3]); + }); + +/* Ops from nn/dense.h */ +TVM_REGISTER_GLOBAL("topi.nn.dense") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::dense(args[0], args[1], args[2], args[3]); + }); + +/* Ops from nn/bias_add.h */ +TVM_REGISTER_GLOBAL("topi.nn.bias_add") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::bias_add(args[0], args[1], args[2]); + }); + +/* Ops from nn/batch_matmul.h */ +TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::batch_matmul(args[0], args[1]); + }); + +/* Ops from nn/dilate.h */ +TVM_REGISTER_GLOBAL("topi.nn.dilate") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::dilate(args[0], args[1]); + }); + +/* Ops from nn/flatten.h */ +TVM_REGISTER_GLOBAL("topi.nn.flatten") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::flatten(args[0]); + }); + +/* Ops from nn/mapping.h */ +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::scale_shift_nchw(args[0], args[1], args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::scale_shift_nhwc(args[0], args[1], args[2]); + }); + +/* Ops from nn/pooling.h */ +TVM_REGISTER_GLOBAL("topi.nn.pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::pool(args[0], args[1], args[2], args[3], + static_cast(static_cast(args[4])), + args[5], args[6], args[7]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.pool_grad") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4], + static_cast(static_cast(args[5])), + args[6], args[7], args[8]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.global_pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::global_pool(args[0], + static_cast(static_cast(args[1])), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::adaptive_pool(args[0], args[1], + static_cast(static_cast(args[2])), + args[3]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::adaptive_pool3d(args[0], args[1], + static_cast(static_cast(args[2])), + args[3]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.pool1d") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::pool1d(args[0], args[1], args[2], args[3], + static_cast(static_cast(args[4])), + args[5], args[6], args[7]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.pool3d") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::pool3d(args[0], args[1], args[2], args[3], + static_cast(static_cast(args[4])), + args[5], args[6], args[7]); + }); + +/* Ops from nn/softmax.h */ +TVM_REGISTER_GLOBAL("topi.nn.softmax") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::softmax(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.log_softmax") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::log_softmax(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.lrn") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::lrn(args[0], args[1], args[2], + static_cast(args[3]), + static_cast(args[4]), + static_cast(args[5])); + }); + +/* Ops from nn/bnn.h */ +TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::binarize_pack(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.binary_dense") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::binary_dense(args[0], args[1]); + }); + +} // namespace topi diff --git a/topi/src/reduction.cc b/topi/src/reduction.cc new file mode 100644 index 0000000..e1fdada --- /dev/null +++ b/topi/src/reduction.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of reduction operators +* \file reduction.cc +*/ +#include +#include + +#include +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +TVM_REGISTER_GLOBAL("topi.sum") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.min") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.max") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.argmin") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.argmax") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.prod") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.all") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.any") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); + }); + +} // namespace topi diff --git a/topi/src/schedule.cc b/topi/src/schedule.cc new file mode 100644 index 0000000..936f390 --- /dev/null +++ b/topi/src/schedule.cc @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of TVM schedules +* \file schedule.cc +*/ +#define TOPI_REDUCE_ATLEAST1D 0 + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +TVM_REGISTER_GLOBAL("topi.TEST_create_target") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = tvm::Target::Create(args[0]); + }); + +/* Generic schedules */ +TVM_REGISTER_GLOBAL("topi.generic.default_schedule") +.set_body([](TVMArgs args, TVMRetValue *rv) { + if (args[2]) { + *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]); + } else { + *rv = topi::generic::default_schedule(args[0], args[1]); + } + }); + +TVM_REGISTER_GLOBAL("topi.generic.schedule_extern") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::generic::schedule_extern(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); + }); + +/* x86 schedules */ +TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::x86::schedule_binarize_pack(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::x86::schedule_binary_dense(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.x86.default_schedule") +.set_body([](TVMArgs args, TVMRetValue *rv) { + if (args[2]) { + *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]); + } else { + *rv = topi::x86::default_schedule(args[0], args[1]); + } + }); + +TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::x86::schedule_injective(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); + }); + +/* ROCm schedules */ +TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_dense(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_injective(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_pool(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_global_pool(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_reduce(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_softmax(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_lrn(args[0]); + }); + +/* CUDA schedules */ +TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_dense(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_injective(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_pool(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_global_pool(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_reduce(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_softmax(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_lrn(args[0]); + }); + +/* Utility functions */ +TVM_REGISTER_GLOBAL("topi.util.is_empty_shape") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::detail::is_empty_shape(args[0]); + }); + +TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); + }); + +/*! \brief Builder function for instantiating schedules. */ +using FTVMScheduleBuilder = std::function< + tvm::te::Schedule(const tvm::Target& target, const tvm::Array& outs)>; + +/*! + * \brief Helper function for registering generic functions matching the + * FTVMScheduleBuilder signature. The schedule builder function is wrapped + * with a PackedFunc suitable for passing to a tvm::GenericFunc. + * + * \param builder The schedule builder to wrap. + * + * \return The wrapped schedule builder + */ +inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { + return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { + auto target = Target::Current(false); + Array outs; + ObjectRef argNodeRef = args[0]; + if (argNodeRef->type_index() == outs->type_index()) { + outs = args[0]; + } else { + outs = Array { args[0] }; + } + + *ret = builder(target, outs); + }); +} + +TVM_REGISTER_GENERIC_FUNC(schedule_injective) +.set_default(WrapSchedule(topi::generic::schedule_injective)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective)) +.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective)); + +TVM_REGISTER_GENERIC_FUNC(schedule_softmax) +.set_default(WrapSchedule(topi::generic::default_schedule)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) +.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax)); + +TVM_REGISTER_GENERIC_FUNC(schedule_dense) +.set_default(WrapSchedule(topi::generic::default_schedule)) +.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) +.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); + +TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul) +.set_default(WrapSchedule(topi::generic::default_schedule)); + +TVM_REGISTER_GENERIC_FUNC(schedule_pool) +.set_default(WrapSchedule(topi::generic::default_schedule)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) +.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool)); + +TVM_REGISTER_GENERIC_FUNC(schedule_global_pool) +.set_default(WrapSchedule(topi::generic::default_schedule)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) +.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool)); + +TVM_REGISTER_GENERIC_FUNC(schedule_reduce) +.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline)) +.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce)); + +TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack) +.set_default(WrapSchedule(topi::generic::default_schedule)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack)); + +TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) +.set_default(WrapSchedule(topi::generic::default_schedule)) +.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense)); + +/*! \brief Builder function for instantiating schedules from existing schedules. */ +using FTVMScheduleFromExistingBuilder = std::function< + tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>; + +/*! + * \brief Helper function for registering generic functions matching the + * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped + * with a PackedFunc suitable for passing to a tvm::GenericFunc. + * + * \param builder The schedule builder to wrap. + * + * \return The wrapped schedule builder + */ +inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { + return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { + *ret = builder(args[0], args[1]); + }); +} + +TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) +.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) +.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) +.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting( + topi::cuda::schedule_injective_from_existing)); + +/*! \brief Builder function for instantiating dense ops. */ +using FTVMDenseOpBuilder = std::function; + +/*! +* \brief Helper function for registering dense ops matching the +* FTVMDenseOpBuilder signature. The op builder function is wrapped +* with a PackedFunc suitable for passing to a tvm::GenericFunc. +* +* \param builder The op builder to wrap. +* +* \return The wrapped op builder +*/ +inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { + return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { + auto target = Target::Current(false); + Tensor data = args[0]; + Tensor weight = args[1]; + Tensor bias = args[2]; + DataType out_dtype = args[3]; + + *ret = builder(target, data, weight, bias, out_dtype); + }); +} + +TVM_REGISTER_GENERIC_FUNC(dense) +.set_default(WrapDenseOp([](const Target& target, + const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, + const tvm::te::Tensor& bias, + const DataType& out_dtype) { + return topi::nn::dense(data, weight, bias, out_dtype); +})) +.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda)) +.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm)); + +} // namespace topi diff --git a/topi/src/topi.cc b/topi/src/topi.cc deleted file mode 100644 index 3a3175c..0000000 --- a/topi/src/topi.cc +++ /dev/null @@ -1,893 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! -* \brief Registration of TVM operators and schedules -* \file topi.cc -*/ -#define TOPI_REDUCE_ATLEAST1D 0 - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include - -namespace topi { - -using namespace tvm; -using namespace tvm::runtime; - -/*! \brief Canonicalize an argument that may be Array or int to Array */ -Array ArrayOrInt(TVMArgValue arg) { - if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) { - Array result; - result.push_back(arg.operator int()); - return result; - } else { - return arg; - } -} - -inline bool IsTensorType(TVMArgValue arg) { - return (arg.type_code() == kTVMObjectHandle && - static_cast( - arg.value().v_handle)->IsInstance()); -} - - -TVM_REGISTER_GLOBAL("topi.TEST_create_target") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tvm::Target::Create(args[0]); - }); - -/* Ops from broadcast.h */ -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName) \ - .set_body([](TVMArgs args, TVMRetValue *rv) { \ - bool lhs_is_tensor = IsTensorType(args[0]); \ - bool rhs_is_tensor = IsTensorType(args[1]); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::PrimExpr()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::PrimExpr()); \ - } \ - }); \ - -TOPI_REGISTER_BCAST_OP("topi.add", topi::add); -TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); -TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply); -TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide); -TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide); -TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod); -TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod); -TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum); -TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum); -TOPI_REGISTER_BCAST_OP("topi.power", topi::power); -TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); -TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); -TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); -TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and); -TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or); -TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor); -TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); -TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); -TOPI_REGISTER_BCAST_OP("topi.less", topi::less); -TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal); -TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); -TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); -TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); - -TVM_REGISTER_GLOBAL("topi.broadcast_to") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = broadcast_to(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.logical_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = logical_not(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.bitwise_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = bitwise_not(args[0]); - }); - -/* Ops from elemwise.h */ -TVM_REGISTER_GLOBAL("topi.exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = exp(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.fast_exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = fast_exp(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = erf(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.tan") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tan(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.cos") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = cos(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.sin") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sin(args[0]); - }); -TVM_REGISTER_GLOBAL("topi.tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tanh(args[0]); - }); -TVM_REGISTER_GLOBAL("topi.fast_tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = fast_tanh(args[0]); - }); -TVM_REGISTER_GLOBAL("topi.atan") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = atan(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.sigmoid") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sigmoid(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.sqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sqrt(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.rsqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { -*rv = rsqrt(args[0]); -}); - -TVM_REGISTER_GLOBAL("topi.log") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = log(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.identity") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = identity(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.negative") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = negative(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.clip") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = clip(args[0], args[1], args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.cast") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = cast(args[0], args[1]); - }); - - -TVM_REGISTER_GLOBAL("topi.reinterpret") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = reinterpret(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.elemwise_sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = elemwise_sum(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.sign") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sign(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.full") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = full(args[0], args[1], args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.full_like") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = full_like(args[0], args[1]); - }); - -/* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = relu(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = leaky_relu(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.prelu") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = prelu(args[0], args[1], args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.pad") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = pad(args[0], args[1], args[2], args[3]); - }); - -/* Ops from reduction.h */ -TVM_REGISTER_GLOBAL("topi.sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.min") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.max") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.argmin") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.argmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.prod") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.all") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.any") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); - }); - -/* Ops from transform.h */ -TVM_REGISTER_GLOBAL("topi.expand_dims") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = expand_dims(args[0], args[1], args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.transpose") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = transpose(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.flip") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = flip(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.reshape") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = reshape(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.squeeze") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = squeeze(args[0], ArrayOrInt(args[1])); - }); - -TVM_REGISTER_GLOBAL("topi.concatenate") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = concatenate(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.stack") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = stack(args[0], args[1]); -}); - -TVM_REGISTER_GLOBAL("topi.shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = shape(args[0], args[1]); -}); - -TVM_REGISTER_GLOBAL("topi.ndarray_size") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = ndarray_size(args[0], args[1]); -}); - -TVM_REGISTER_GLOBAL("topi.split") -.set_body([](TVMArgs args, TVMRetValue *rv) { - if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { - *rv = split_sections(args[0], args[1], args[2]); - } else { - *rv = split(args[0], args[1], args[2]); - } -}); - -TVM_REGISTER_GLOBAL("topi.layout_transform") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = layout_transform(args[0], args[1], args[2]); -}); - -TVM_REGISTER_GLOBAL("topi.take") -.set_body([](TVMArgs args, TVMRetValue *rv) { - if (args.size() == 3) { - std::string mode = args[2]; - *rv = take(args[0], args[1], mode); - } else { - int axis = args[2]; - std::string mode = args[3]; - *rv = take(args[0], args[1], axis, mode); - } - }); - -TVM_REGISTER_GLOBAL("topi.sequence_mask") -.set_body([](TVMArgs args, TVMRetValue *rv) { - double pad_val = args[2]; - int axis = args[3]; - *rv = sequence_mask(args[0], args[1], pad_val, axis); -}); - - -TVM_REGISTER_GLOBAL("topi.where") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = where(args[0], args[1], args[2]); -}); - -TVM_REGISTER_GLOBAL("topi.arange") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = arange(args[0], args[1], args[2], args[3]); -}); - -TVM_REGISTER_GLOBAL("topi.repeat") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = repeat(args[0], args[1], args[2]); -}); - -TVM_REGISTER_GLOBAL("topi.tile") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tile(args[0], args[1]); -}); - -TVM_REGISTER_GLOBAL("topi.gather_nd") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = gather_nd(args[0], args[1]); -}); - -TVM_REGISTER_GLOBAL("topi.unravel_index") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = unravel_index(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { - switch ( args.size() ) { - case 2: *rv = matmul(args[0], args[1]); break; - case 3: *rv = matmul(args[0], args[1], args[2]); break; - case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break; - default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; - }}); - -TVM_REGISTER_GLOBAL("topi.tensordot") -.set_body([](TVMArgs args, TVMRetValue *rv) { - if (args.size() == 2) { - *rv = tensordot(args[0], args[1]); - } else if (args.size() == 3) { - *rv = tensordot(args[0], args[1], args[2]); - } else { - Array axes = args[3]; - *rv = tensordot(args[0], args[1], args[2], axes); - } - }); - -TVM_REGISTER_GLOBAL("topi.strided_slice") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = strided_slice(args[0], args[1], args[2], args[3]); - }); - -TVM_REGISTER_GLOBAL("topi.one_hot") -.set_body([](TVMArgs args, TVMRetValue *rv) { - int depth = args[3]; - int axis = args[4]; - DataType dtype = args[5]; - *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); - }); - -/* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::binarize_pack(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::binary_dense(args[0], args[1]); - }); - -/* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::dense(args[0], args[1], args[2], args[3]); - }); - -/* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::bias_add(args[0], args[1], args[2]); - }); - -/* Ops from nn/batch_matmul.h */ -TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::batch_matmul(args[0], args[1]); - }); - -/* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::dilate(args[0], args[1]); - }); - -/* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::flatten(args[0]); - }); - -/* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::scale_shift_nchw(args[0], args[1], args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::scale_shift_nhwc(args[0], args[1], args[2]); - }); - -/* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::pool(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.pool_grad") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4], - static_cast(static_cast(args[5])), - args[6], args[7], args[8]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::global_pool(args[0], - static_cast(static_cast(args[1])), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool(args[0], args[1], - static_cast(static_cast(args[2])), - args[3]); -}); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool3d(args[0], args[1], - static_cast(static_cast(args[2])), - args[3]); -}); - -TVM_REGISTER_GLOBAL("topi.nn.pool1d") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::pool1d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::pool3d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); - -/* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::softmax(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.log_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::log_softmax(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::lrn(args[0], args[1], args[2], - static_cast(args[3]), - static_cast(args[4]), - static_cast(args[5])); - }); - -TVM_REGISTER_GLOBAL("topi.vision.reorg") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = vision::reorg(args[0], args[1]); - }); - -/* Generic schedules */ -TVM_REGISTER_GLOBAL("topi.generic.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { - if (args[2]) { - *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]); - } else { - *rv = topi::generic::default_schedule(args[0], args[1]); - } - }); - -TVM_REGISTER_GLOBAL("topi.generic.schedule_extern") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::generic::schedule_extern(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::generic::schedule_injective(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); - }); - -/* x86 schedules */ -TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_binarize_pack(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_binary_dense(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.x86.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { - if (args[2]) { - *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]); - } else { - *rv = topi::x86::default_schedule(args[0], args[1]); - } - }); - -TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_injective(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); - }); - -/* ROCm schedules */ -TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_dense(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_injective(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_pool(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_global_pool(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_reduce(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_softmax(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_lrn(args[0]); - }); - -/* CUDA schedules */ -TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_dense(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_injective(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_pool(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_global_pool(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_reduce(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_softmax(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_lrn(args[0]); - }); - -/* Utility functions */ -TVM_REGISTER_GLOBAL("topi.util.is_empty_shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::detail::is_empty_shape(args[0]); - }); - -TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); - }); - -/*! \brief Builder function for instantiating schedules. */ -using FTVMScheduleBuilder = std::function< - tvm::te::Schedule(const tvm::Target& target, const tvm::Array& outs)>; - -/*! - * \brief Helper function for registering generic functions matching the - * FTVMScheduleBuilder signature. The schedule builder function is wrapped - * with a PackedFunc suitable for passing to a tvm::GenericFunc. - * - * \param builder The schedule builder to wrap. - * - * \return The wrapped schedule builder - */ -inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { - return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - auto target = Target::Current(false); - Array outs; - ObjectRef argNodeRef = args[0]; - if (argNodeRef->type_index() == outs->type_index()) { - outs = args[0]; - } else { - outs = Array { args[0] }; - } - - *ret = builder(target, outs); - }); -} - -TVM_REGISTER_GENERIC_FUNC(schedule_injective) -.set_default(WrapSchedule(topi::generic::schedule_injective)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective)); - -TVM_REGISTER_GENERIC_FUNC(schedule_softmax) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax)); - -TVM_REGISTER_GENERIC_FUNC(schedule_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) -.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); - -TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul) -.set_default(WrapSchedule(topi::generic::default_schedule)); - -TVM_REGISTER_GENERIC_FUNC(schedule_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool)); - -TVM_REGISTER_GENERIC_FUNC(schedule_global_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool)); - -TVM_REGISTER_GENERIC_FUNC(schedule_reduce) -.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce)); - -TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack)); - -TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense)); - -/*! \brief Builder function for instantiating schedules from existing schedules. */ -using FTVMScheduleFromExistingBuilder = std::function< - tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>; - -/*! - * \brief Helper function for registering generic functions matching the - * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped - * with a PackedFunc suitable for passing to a tvm::GenericFunc. - * - * \param builder The schedule builder to wrap. - * - * \return The wrapped schedule builder - */ -inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { - return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - *ret = builder(args[0], args[1]); - }); -} - -TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) -.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) -.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) -.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting( - topi::cuda::schedule_injective_from_existing)); - -/*! \brief Builder function for instantiating dense ops. */ -using FTVMDenseOpBuilder = std::function; - -/*! -* \brief Helper function for registering dense ops matching the -* FTVMDenseOpBuilder signature. The op builder function is wrapped -* with a PackedFunc suitable for passing to a tvm::GenericFunc. -* -* \param builder The op builder to wrap. -* -* \return The wrapped op builder -*/ -inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { - return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - auto target = Target::Current(false); - Tensor data = args[0]; - Tensor weight = args[1]; - Tensor bias = args[2]; - DataType out_dtype = args[3]; - - *ret = builder(target, data, weight, bias, out_dtype); - }); -} - -TVM_REGISTER_GENERIC_FUNC(dense) -.set_default(WrapDenseOp([](const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { - return topi::nn::dense(data, weight, bias, out_dtype); -})) -.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda)) -.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm)); - -} // namespace topi diff --git a/topi/src/transform.cc b/topi/src/transform.cc new file mode 100644 index 0000000..4f0d4f8 --- /dev/null +++ b/topi/src/transform.cc @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of transform operators +* \file transform.cc +*/ +#include +#include + +#include +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +TVM_REGISTER_GLOBAL("topi.expand_dims") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = expand_dims(args[0], args[1], args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.transpose") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = transpose(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.flip") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = flip(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.reshape") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = reshape(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.squeeze") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = squeeze(args[0], ArrayOrInt(args[1])); + }); + +TVM_REGISTER_GLOBAL("topi.concatenate") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = concatenate(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.stack") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = stack(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("topi.shape") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = shape(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("topi.ndarray_size") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = ndarray_size(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("topi.split") +.set_body([](TVMArgs args, TVMRetValue *rv) { + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { + *rv = split_sections(args[0], args[1], args[2]); + } else { + *rv = split(args[0], args[1], args[2]); + } +}); + +TVM_REGISTER_GLOBAL("topi.layout_transform") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = layout_transform(args[0], args[1], args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.take") +.set_body([](TVMArgs args, TVMRetValue *rv) { + if (args.size() == 3) { + std::string mode = args[2]; + *rv = take(args[0], args[1], mode); + } else { + int axis = args[2]; + std::string mode = args[3]; + *rv = take(args[0], args[1], axis, mode); + } + }); + +TVM_REGISTER_GLOBAL("topi.sequence_mask") +.set_body([](TVMArgs args, TVMRetValue *rv) { + double pad_val = args[2]; + int axis = args[3]; + *rv = sequence_mask(args[0], args[1], pad_val, axis); +}); + +TVM_REGISTER_GLOBAL("topi.where") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = where(args[0], args[1], args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.arange") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = arange(args[0], args[1], args[2], args[3]); +}); + +TVM_REGISTER_GLOBAL("topi.repeat") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = repeat(args[0], args[1], args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.tile") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = tile(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("topi.gather_nd") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = gather_nd(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("topi.unravel_index") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = unravel_index(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.matmul") +.set_body([](TVMArgs args, TVMRetValue *rv) { + switch ( args.size() ) { + case 2: *rv = matmul(args[0], args[1]); break; + case 3: *rv = matmul(args[0], args[1], args[2]); break; + case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break; + default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; + }}); + +TVM_REGISTER_GLOBAL("topi.tensordot") +.set_body([](TVMArgs args, TVMRetValue *rv) { + if (args.size() == 2) { + *rv = tensordot(args[0], args[1]); + } else if (args.size() == 3) { + *rv = tensordot(args[0], args[1], args[2]); + } else { + Array axes = args[3]; + *rv = tensordot(args[0], args[1], args[2], axes); + } + }); + +TVM_REGISTER_GLOBAL("topi.strided_slice") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = strided_slice(args[0], args[1], args[2], args[3]); + }); + +TVM_REGISTER_GLOBAL("topi.one_hot") +.set_body([](TVMArgs args, TVMRetValue *rv) { + int depth = args[3]; + int axis = args[4]; + DataType dtype = args[5]; + *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); + }); + +} // namespace topi diff --git a/topi/src/vision.cc b/topi/src/vision.cc new file mode 100644 index 0000000..1a4884e --- /dev/null +++ b/topi/src/vision.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* \brief Registration of vision operators +* \file vision.cc +*/ +#include +#include + +#include + +namespace topi { + +using namespace tvm; +using namespace tvm::runtime; + +TVM_REGISTER_GLOBAL("topi.vision.reorg") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = vision::reorg(args[0], args[1]); + }); + +} // namespace topi -- 2.7.4